Compare commits

...

68 Commits

Author SHA1 Message Date
26c77e8522 Merge branch 'main' into refactor/model-manager2/model-metadata-store 2023-12-22 12:52:56 -05:00
fbede84405 [feature] Download Queue (#5225)
* add base definition of download manager

* basic functionality working

* add unit tests for download queue

* add documentation and FastAPI route

* fix docs

* add missing test dependency; fix import ordering

* fix file path length checking on windows

* fix ruff check error

* move release() into the __del__ method

* disable testing of stderr messages due to issues with pytest capsys fixture

* fix unsorted imports

* harmonized implementation of start() and stop() calls in download and install modules

* Update invokeai/app/services/download/download_base.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* replace test datadir fixture with tmp_path

* replace DownloadJobBase->DownloadJob in download manager documentation

* make source and dest arguments to download_queue.download() an AnyHttpURL and Path respectively

* fix pydantic typecheck errors in the download unit test

* ruff formatting

* add "job cancelled" as an event rather than an exception

* fix ruff errors

* Update invokeai/app/services/download/download_default.py

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>

* use threading.Event to stop service worker threads; handle unfinished job edge cases

* remove dangling STOP job definition

* fix ruff complaint

* fix ruff check again

* avoid race condition when start() and stop() are called simultaneously from different threads

* avoid race condition in stop() when a job becomes active while shutting down

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
2023-12-22 12:35:57 -05:00
756cb9c27e fix(tests): remove graph library from test fixtures 2023-12-23 00:04:48 +11:00
78b29db458 feat(backend): disable graph library
The graph library occasionally causes issues when the default graph changes substantially between versions and pydantic validation fails. See #5289 for an example.

We are not currently using the graph library, so we can disable it until we are ready to use it. It's possible that the workflow library will supersede it anyways.
2023-12-23 00:04:48 +11:00
1225c3fb47 addresses #5224 (#5332)
Co-authored-by: Lincoln Stein <lstein@gmail.com>
2023-12-22 12:30:51 +00:00
4957a360ff close #5209 2023-12-21 23:02:57 -05:00
32ad742f3e Ti trigger from prompt util (#5294)
* Pull logic for extracting TI triggers into a util function

* Remove duplicate regex for ti triggers

* Fix linting for ruff

* Remove unused imports
2023-12-22 03:04:44 +00:00
2d11d97dad remove MacOS Sonoma check in devices.py (#5312)
* remove MacOS Sonoma check in devices.py

As of pytorch 2.1.0, float16 works with our MPS fixes on Sonoma, so the check is no longer needed.

* remove unused platform import
2023-12-22 00:42:47 +00:00
64858b2523 Update contributingToFrontend.md (#5329)
The project is no longer using yarn as a package manager and has moved
to pnpm, so I wanted to update the documentation on the contribution
page.

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [x] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:
I spoke with user: imic in the #dev-chat on discord.
      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Merge Plan
- "This PR can be merged when approved"
2023-12-22 08:38:34 +11:00
d5134325f6 Merge branch 'main' into patch-1 2023-12-22 08:37:15 +11:00
702d0f68af remove (Unsaved) if workflow library is disabled 2023-12-22 07:39:17 +11:00
a0d0e9f474 Update contributingToFrontend.md
The project is no longer using yarn as a package manager and has moved to pnpm, so I wanted to update the documentation on the contribution page.
2023-12-21 14:51:17 -05:00
475823835f Update communityNodes.md
Addition of my Adapters-Linked and Metadata-linked nodes
2023-12-21 13:51:59 -05:00
b95d547ccc Add more default workflows (#5325)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [X] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description
Added more default workflows to the workflow library

## Related Tickets & Documents

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

## Merge Plan

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-12-21 14:40:19 +11:00
19be196d50 remove redundant fetch; add modified/published dates; updated docs 2023-12-20 20:48:33 -05:00
f123ee61d0 add missing dependency for pytests 2023-12-20 20:20:27 -05:00
9b4758f02f Merge branch 'main' into feat/default_workflows 2023-12-21 10:35:02 +11:00
a626ca3e1c add unit tests and documentation 2023-12-20 17:58:34 -05:00
8d2952695d translationBot(ui): update translation (Chinese (Simplified))
Currently translated at 99.8% (1363 of 1365 strings)

Co-authored-by: Surisen <zhonghx0804@outlook.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2023-12-21 09:56:06 +11:00
8dd55cc45e t2i with LoRA 2023-12-21 09:54:12 +11:00
562fb1f3a1 add authToastMiddleware back and fix parsing 2023-12-20 14:59:33 -05:00
1940169925 Merge branch 'main' into refactor/model-manager2/model-metadata-store 2023-12-20 09:49:37 -05:00
29b049b9d9 start unit tests 2023-12-20 09:48:05 -05:00
79cf3ec9a5 Add facedetailer workflow 2023-12-20 18:53:49 +11:00
37b76caccf Added default workflows 2023-12-20 17:42:14 +11:00
Sam
a4f9bfc8f7 Update Dockerfile 2023-12-19 18:38:36 -05:00
Sam
9afdd0f4a8 Update Dockerfile 2023-12-19 18:38:36 -05:00
bee6ad1547 fix(pnpm): replace npm with pnpm in dockerfile 2023-12-19 18:38:36 -05:00
87a5b771c4 merge with main 2023-12-19 17:04:30 -05:00
fa3f1b6e41 [Feat] reimport model config records after schema migration (#5281)
* add code to repopulate model config records after schema update

* reformat for ruff

* migrate model records using db cursor rather than the ModelRecordConfigService

* ruff fixes

* tweak exception reporting

* fix: build frontend in pypi-release workflow

This was missing, resulting in the 3.5.0rc1 having no frontend.

* fix: use node 18, set working directory

- Node 20 has a problem with `pnpm`; set it to Node 18
- Set the working directory for the frontend commands

* Don't copy extraneous paths into installer .zip

* feat(installer): delete frontend build after creating installer

This prevents an empty `dist/` from breaking the app on startup.

* feat: add python dist as release artifact, as input to enable publish to pypi

- The release workflow never runs automatically. It must be manually kicked off.
- The release workflow has an input. When running it from the GH actions UI, you will see a "Publish build on PyPi" prompt. If this value is "true", the workflow will upload the build to PyPi, releasing it. If this is anything else (e.g. "false", the default), the workflow will build but not upload to PyPi.
- The `dist/` folder (where the python package is built) is uploaded as a workflow artifact as a zip file. This can be downloaded and inspected. This allows "dry" runs of the workflow.
- The workflow job and some steps have been renamed to clarify what they do

* translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* freeze yaml migration logic at upgrade to 3.5

* moved migration code to migration_3

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Co-authored-by: Hosted Weblate <hosted@weblate.org>
2023-12-19 17:01:47 -05:00
e86f3fe29e add storage 2023-12-19 17:00:49 -05:00
d0fa131010 (feat) updater installs from PyPi instead of GitHub releases (#5316)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [X] No


## Description
Updater script pulls from PyPI instead of GitHub releases (this is why
the RC packages are having issues when updating through the launcher
script)

## Related Tickets & Documents

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

## Merge Plan

## Added/updated tests?

- [ ] Yes
- [X] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-12-19 13:15:38 +11:00
2f438431bd (fix) update logic for installing specific version 2023-12-19 11:05:15 +11:00
bbeb5cb477 Merge branch 'main' into feat/updater_use_pypi 2023-12-19 10:09:03 +11:00
cd3111c324 fix ruff errors 2023-12-19 09:58:10 +11:00
16b7246412 (feat) updater installs from PyPi instead of GitHub releases 2023-12-19 09:30:40 +11:00
42be78d328 translationBot(ui): update translation (Italian)
Currently translated at 97.2% (1327 of 1365 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2023-12-19 07:20:14 +11:00
e469e24a58 Update model_probe to work with diffuser-format SD TI embeddings. (#5301)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

      
## Have you updated all relevant documentation?
- [x] Yes (N/A)
- [ ] No


## Description

This change enables the model probe to work with TI embeddings that have
the follow state_dict structure:

```python
{
    "<any_key>": torch.Tensor(...), # where the tensor has shape (N, embedding_dim)
}
```

## QA Instructions, Screenshots, Recordings

I can't imagine an embedding format that would previously have passed
the model probe, and would now fail after this change. That being said,
I'll exercise a bunch of existing TIs before merging.

- [x] Exercise existing TI formats


## Added/updated tests?

- [ ] Yes
- [x] No : _We could really benefit from tests for all of the supported
TI formats... but I'm not taking on that project right now._
2023-12-18 10:01:04 -05:00
cb698ff1fb Update model_probe to work with diffuser-format SD TI embeddings. 2023-12-18 09:51:16 -05:00
0e738c4290 Tag model manager v2 api as unstable (#5311)
## What type of PR is this? (check all applicable)

- [X] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description

As discussed with @psychedelicious , this PR changes the swagger label
on the model manager V2 routes to `model_manager_v2_unstable` in order
to warn community members that the API is liable to change.

## Related Tickets & Documents

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

## Merge Plan

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-12-18 07:09:40 -05:00
09d1bc513d Merge branch 'main' into refactor/model-manager2/mark-api-experimental 2023-12-18 07:04:00 -05:00
c610283158 add basic functionality for model metadata fetching from hf and civitai 2023-12-17 22:19:29 -05:00
aefa828237 Tiled upscaling - EvenSplit to use overlap in pixels instead tile fraction (#5309)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [x] No


## Description
Change CalculateImageTilesEvenSplitInvocation to have an overlap in
pixels rather than as a percentage of the tile. This makes it easier to
have predictable blending of the seams as you have a known overlap size.

## Related Tickets & Documents

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

## Merge Plan

## Added/updated tests?

- [x] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-12-17 21:13:45 -05:00
74ea592d02 tag model manager v2 api as unstable 2023-12-17 14:16:45 -05:00
457b0dfac0 Merge branch 'main' into tiled-upscaling-graph 2023-12-17 15:12:16 +00:00
96a717c4ba In CalculateImageTilesEvenSplitInvocation, overlap_fraction becomes just overlap, now expressed in pixels rather than as a fraction of the tile size.
Update calc_tiles_even_split() with the same change, ensuring overlap is within the allowed size

Update even_split tests
2023-12-17 15:10:50 +00:00
77b74264a8 Simplify docker compose setup (#5046)
## What type of PR is this? (check all applicable)

- [x] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes -
https://github.com/invoke-ai/InvokeAI/pull/5007#discussion_r1378792615
- [ ] No, because: 

      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Description

Simplify Docker image creation and execution to a single script that
spins up the right service in the docker compose file.
## Related Tickets & Documents

- Depends on #5007

## QA Instructions, Screenshots, Recordings
N/A

## Added/updated tests?

- [ ] Yes
- [x] No : same tests should work.

## [optional] Are there any post deployment tasks we need to perform?

Not to my knowledge.
2023-12-17 17:10:56 +11:00
351078e8aa Merge branch 'main' into simplify-docker-compose-setup 2023-12-17 17:07:55 +11:00
b8354bd1a4 Merge branch 'main' into tiled-upscaling-graph 2023-12-16 19:09:28 +00:00
3b944b8af6 fix: build frontend in pypi-release workflow (#5298)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

## Description

This was missing, resulting in the 3.5.0rc1 having no frontend.

## Related Tickets & Documents

- Discord installer thread:
https://discord.com/channels/1020123559063990373/1149513695567810630/1185200427717898260
- Comments from here in the release chat:
https://discord.com/channels/1020123559063990373/1020123559831539744/1185004017521279007

## QA Instructions, Screenshots, Recordings

I've run this locally and it works (I commented out the final steps of
the workflow that do PyPi stuff to ensure I didn't accidentally deploy
something).

You can run the workflow locally with https://github.com/nektos/act.
Suggest using the `gh` CLI version; it's very easy to set up if you have
the GitHub CLI installed. Then you can run `gh act -W
.github/workflows/pypi-release.yml` to run the workflow locally in a
docker image.

I don't know whether this local action runner would actually release to PyPi -
as mentioned, I commented those steps out when testing - but it does
successfully do both frontend and backend builds.

## Merge Plan

This needs @lstein 's approval.

## [optional] Are there any post deployment tasks we need to perform?

Cut an RC2
2023-12-16 10:40:36 -05:00
b811c037bd Merge branch 'main' into fix/pypi-release-frontend-build 2023-12-16 10:36:03 -05:00
5bf61382a4 feat: add python dist as release artifact, as input to enable publish to pypi
- The release workflow never runs automatically. It must be manually kicked off.
- The release workflow has an input. When running it from the GH actions UI, you will see a "Publish build on PyPi" prompt. If this value is "true", the workflow will upload the build to PyPi, releasing it. If this is anything else (e.g. "false", the default), the workflow will build but not upload to PyPi.
- The `dist/` folder (where the python package is built) is uploaded as a workflow artifact as a zip file. This can be downloaded and inspected. This allows "dry" runs of the workflow.
- The workflow job and some steps have been renamed to clarify what they do
2023-12-16 20:02:09 +11:00
0f1c5f382a feat(installer): delete frontend build after creating installer
This prevents an empty `dist/` from breaking the app on startup.
2023-12-16 19:39:29 +11:00
4af1695c60 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2023-12-16 13:10:47 +11:00
df9a903a50 fix(ui): do not cache VAE decode on linear
The VAE decode on linear graphs was getting cached. This caused some unexpected behaviour around image outputs.

For example, say you ran the exact same graph twice. The first time, you get an image written to disk and added to gallery. The second time, the VAE decode is cached and no image file is created. But, the UI still gets the graph complete event and selects the first image in the gallery. The second run does not add an image to the gallery.

There are probably edge cases related to this - the UI does not expect this to happen. I'm not sure how to handle it any better in the UI.

The solution is to not cache VAE decode on the linear graphs, ever. If you run a graph twice in linear, you expect two images.

This simple change disables the node cache for terminal VAE decode nodes in all linear graphs, ensuring you always get images. If the graph was fully cached, all images after the first will be created very quickly, of course.
2023-12-16 12:37:49 +11:00
311be8f97d Merge branch 'main' into fix/pypi-release-frontend-build 2023-12-16 10:15:32 +11:00
3f970c8326 Don't copy extraneous paths into installer .zip 2023-12-15 11:27:21 -05:00
fc150acde5 [feat] Make model prober recognize yet another LoRA format (#5296)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [X] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description

This adds a probe for the SDXL LoRA format found in the wild at
https://civitai.com/models/224641.

## Related Tickets & Documents

See discord message at:
https://discord.com/channels/1020123559063990373/1149510134058471514/1184982133912113182

## QA Instructions, Screenshots, Recordings

Try installing the SDXL LoRA at the URL given above.
## Merge Plan

This can be merged when approved.
## Added/updated tests?

- [ ] Yes
- [X] No : we do not yet have a comprehensive suite of models to test
probing on.

## [optional] Are there any post deployment tasks we need to perform?
2023-12-15 09:49:51 -05:00
1615df3aa1 fix: use node 18, set working directory
- Node 20 has a problem with `pnpm`; set it to Node 18
- Set the working directory for the frontend commands
2023-12-16 00:32:31 +11:00
b2a8c45553 fix: build frontend in pypi-release workflow
This was missing, resulting in the 3.5.0rc1 having no frontend.
2023-12-15 23:56:31 +11:00
212dbaf9a2 fix comment 2023-12-15 00:25:27 -05:00
ac3cf48d7f make probe recognize lora format at https://civitai.com/models/224641 2023-12-15 00:25:27 -05:00
296060db63 Add cpu and rocm profiles. Let invokeai-nvidia service be the default. 2023-12-13 23:23:43 -05:00
d1d8ee71fc Simplify docker compose setup 2023-12-13 23:23:43 -05:00
612912a6c9 updated tests with a test for tile > image for calc_tiles_min_overlap() 2023-12-12 14:12:22 +00:00
bca2372280 updated comment 2023-12-12 14:02:28 +00:00
0b860582f0 remove unneeded if else 2023-12-12 14:00:06 +00:00
87ff380fe4 fix for calc_tiles_min_overlap when tile size is bigger than image size 2023-12-12 13:40:28 +00:00
82 changed files with 10787 additions and 304 deletions

View File

@@ -21,16 +21,16 @@ jobs:
if: github.event.pull_request.draft == false
runs-on: ubuntu-22.04
steps:
- name: Setup Node 20
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '20'
node-version: '18'
- name: Checkout
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: 8
version: '8.12.1'
- name: Install dependencies
run: 'pnpm install --prefer-frozen-lockfile'
- name: Typescript

View File

@@ -1,13 +1,15 @@
name: PyPI Release
on:
push:
paths:
- 'invokeai/version/invokeai_version.py'
workflow_dispatch:
inputs:
publish_package:
description: 'Publish build on PyPi? [true/false]'
required: true
default: 'false'
jobs:
release:
build-and-release:
if: github.repository == 'invoke-ai/InvokeAI'
runs-on: ubuntu-22.04
env:
@@ -15,19 +17,43 @@ jobs:
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
TWINE_NON_INTERACTIVE: 1
steps:
- name: checkout sources
uses: actions/checkout@v3
- name: Checkout
uses: actions/checkout@v4
- name: install deps
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
working-directory: invokeai/frontend/web
- name: Build frontend
run: pnpm run build
working-directory: invokeai/frontend/web
- name: Install python dependencies
run: pip install --upgrade build twine
- name: build package
- name: Build python package
run: python3 -m build
- name: check distribution
- name: Upload build as workflow artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: dist
- name: Check distribution
run: twine check dist/*
- name: check PyPI versions
- name: Check PyPI versions
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
run: |
pip install --upgrade requests
@@ -36,6 +62,6 @@ jobs:
EXISTS=scripts.pypi_helper.local_on_pypi(); \
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
- name: upload package
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != ''
- name: Publish build on PyPi
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
run: twine upload dist/*

View File

@@ -270,7 +270,7 @@ upgrade script.** See the next section for a Windows recipe.
3. Select option [1] to upgrade to the latest release.
4. Once the upgrade is finished you will be returned to the launcher
menu. Select option [7] "Re-run the configure script to fix a broken
menu. Select option [6] "Re-run the configure script to fix a broken
install or to complete a major upgrade".
This will run the configure script against the v2.3 directory and

View File

@@ -59,14 +59,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# #### Build the Web UI ------------------------------------
FROM node:18 AS web-builder
FROM node:18-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack enable
WORKDIR /build
COPY invokeai/frontend/web/ ./
RUN --mount=type=cache,target=/usr/lib/node_modules \
npm install --include dev
RUN --mount=type=cache,target=/usr/lib/node_modules \
yarn vite build
RUN --mount=type=cache,target=/pnpm/store \
pnpm install --frozen-lockfile
RUN pnpm run build
#### Runtime stage ---------------------------------------

View File

@@ -23,7 +23,7 @@ This is done via Docker Desktop preferences
1. Make a copy of `env.sample` and name it `.env` (`cp env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
a. the desired location of the InvokeAI runtime directory, or
b. an existing, v3.0.0 compatible runtime directory.
1. `docker compose up`
1. Execute `run.sh`
The image will be built automatically if needed.
@@ -39,7 +39,7 @@ The Docker daemon on the system must be already set up to use the GPU. In case o
## Customize
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `docker compose up`, your custom values will be used.
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `run.sh`, your custom values will be used.
You can also set these values in `docker-compose.yml` directly, but `.env` will help avoid conflicts when code is updated.

View File

@@ -1,11 +0,0 @@
#!/usr/bin/env bash
set -e
build_args=""
[[ -f ".env" ]] && build_args=$(awk '$1 ~ /\=[^$]/ {print "--build-arg " $0 " "}' .env)
echo "docker compose build args:"
echo $build_args
docker compose build $build_args

View File

@@ -2,23 +2,8 @@
version: '3.8'
services:
invokeai:
x-invokeai: &invokeai
image: "local/invokeai:latest"
# edit below to run on a container runtime other than nvidia-container-runtime.
# not yet tested with rocm/AMD GPUs
# Comment out the "deploy" section to run on CPU only
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
# For AMD support, comment out the deploy section above and uncomment the devices section below:
#devices:
# - /dev/kfd:/dev/kfd
# - /dev/dri:/dev/dri
build:
context: ..
dockerfile: docker/Dockerfile
@@ -50,3 +35,27 @@ services:
# - |
# invokeai-model-install --yes --default-only --config_file ${INVOKEAI_ROOT}/config_custom.yaml
# invokeai-nodes-web --host 0.0.0.0
services:
invokeai-nvidia:
<<: *invokeai
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
invokeai-cpu:
<<: *invokeai
profiles:
- cpu
invokeai-rocm:
<<: *invokeai
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
profiles:
- rocm

View File

@@ -1,11 +1,28 @@
#!/usr/bin/env bash
set -e
# This script is provided for backwards compatibility with the old docker setup.
# It doesn't do much aside from wrapping the usual docker compose CLI.
run() {
local scriptdir=$(dirname "${BASH_SOURCE[0]}")
cd "$scriptdir" || exit 1
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
cd "$SCRIPTDIR" || exit 1
local build_args=""
local profile=""
docker compose up -d
docker compose logs -f
[[ -f ".env" ]] &&
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
local service_name="invokeai-$profile"
printf "%s\n" "docker compose build args:"
printf "%s\n" "$build_args"
docker compose build $build_args
unset build_args
printf "%s\n" "starting service $service_name"
docker compose --profile "$profile" up -d "$service_name"
docker compose logs -f
}
run

View File

@@ -0,0 +1,277 @@
# The InvokeAI Download Queue
The DownloadQueueService provides a multithreaded parallel download
queue for arbitrary URLs, with queue prioritization, event handling,
and restart capabilities.
## Simple Example
```
from invokeai.app.services.download import DownloadQueueService, TqdmProgress
download_queue = DownloadQueueService()
for url in ['https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true',
'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true',
'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png',
'https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor',
]:
# urls start downloading as soon as download() is called
download_queue.download(source=url,
dest='/tmp/downloads',
on_progress=TqdmProgress().update
)
download_queue.join() # wait for all downloads to finish
for job in download_queue.list_jobs():
print(job.model_dump_json(exclude_none=True, indent=4),"\n")
```
Output:
```
{
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true",
"dest": "/tmp/downloads",
"id": 0,
"priority": 10,
"status": "completed",
"download_path": "/tmp/downloads/a-painting-of-a-fire.png",
"job_started": "2023-12-04T05:34:41.742174",
"job_ended": "2023-12-04T05:34:42.592035",
"bytes": 666734,
"total_bytes": 666734
}
{
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true",
"dest": "/tmp/downloads",
"id": 1,
"priority": 10,
"status": "completed",
"download_path": "/tmp/downloads/birdhouse.png",
"job_started": "2023-12-04T05:34:41.741975",
"job_ended": "2023-12-04T05:34:42.652841",
"bytes": 774949,
"total_bytes": 774949
}
{
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png",
"dest": "/tmp/downloads",
"id": 2,
"priority": 10,
"status": "error",
"job_started": "2023-12-04T05:34:41.742079",
"job_ended": "2023-12-04T05:34:42.147625",
"bytes": 0,
"total_bytes": 0,
"error_type": "HTTPError(Not Found)",
"error": "Traceback (most recent call last):\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 182, in _download_next_item\n self._do_download(job)\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 206, in _do_download\n raise HTTPError(resp.reason)\nrequests.exceptions.HTTPError: Not Found\n"
}
{
"source": "https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor",
"dest": "/tmp/downloads",
"id": 3,
"priority": 10,
"status": "completed",
"download_path": "/tmp/downloads/xl_more_art-full_v1.safetensors",
"job_started": "2023-12-04T05:34:42.147645",
"job_ended": "2023-12-04T05:34:43.735990",
"bytes": 719020768,
"total_bytes": 719020768
}
```
## The API
The default download queue is `DownloadQueueService`, an
implementation of ABC `DownloadQueueServiceBase`. It juggles multiple
background download requests and provides facilities for interrogating
and cancelling the requests. Access to a current or past download task
is mediated via `DownloadJob` objects which report the current status
of a job request.
### The Queue Object
A default download queue is located in
`ApiDependencies.invoker.services.download_queue`. However, you can
create additional instances if you need to isolate your queue from the
main one.
```
queue = DownloadQueueService(event_bus=events)
```
`DownloadQueueService()` takes three optional arguments:
| **Argument** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| `max_parallel_dl` | int | 5 | Maximum number of simultaneous downloads allowed |
| `event_bus` | EventServiceBase | None | System-wide FastAPI event bus for reporting download events |
| `requests_session` | requests.sessions.Session | None | An alternative requests Session object to use for the download |
`max_parallel_dl` specifies how many download jobs are allowed to run
simultaneously. Each will run in a different thread of execution.
`event_bus` is an EventServiceBase, typically the one created at
InvokeAI startup. If present, download events are periodically emitted
on this bus to allow clients to follow download progress.
`requests_session` is a `requests` library Session object. It is
used primarily for testing.
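Below is a hedged sketch of how `requests_session` can isolate the
queue from the network. The `StubAdapter` class is an illustrative
assumption, not part of InvokeAI, and may not satisfy everything the
queue's streaming code expects:
```
import requests
from requests.adapters import HTTPAdapter

from invokeai.app.services.download import DownloadQueueService

class StubAdapter(HTTPAdapter):
    """Fabricate a tiny canned response instead of touching the network."""
    def send(self, request, **kwargs):
        resp = requests.Response()
        resp.status_code = 200
        resp._content = b"fake file body"
        resp.url = request.url
        return resp

session = requests.Session()
session.mount("http://", StubAdapter())
session.mount("https://", StubAdapter())

# All three optional arguments shown together:
queue = DownloadQueueService(
    max_parallel_dl=2,          # at most two worker threads
    event_bus=None,             # no event reporting in this sketch
    requests_session=session,   # downloads hit the stub, not the internet
)
```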
### The Job object
The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of
the progress of the download.
The only job type currently implemented is `DownloadJob`, a pydantic object with the
following fields:
| **Field** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| _Fields passed in at job creation time_ |
| `source` | AnyHttpUrl | | Where to download from |
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
| `priority` | int | 10 | Job priority. Lower priorities run before higher priorities |
| _Fields updated over the course of the download task_ |
| `status` | DownloadJobStatus| | Status code |
| `download_path` | Path | | Path to the location of the downloaded file |
| `job_started` | float | | Timestamp for when the job started running |
| `job_ended` | float | | Timestamp for when the job completed or errored out |
| `job_sequence` | int | | A counter that is incremented each time a model is dequeued |
| `bytes` | int | 0 | Bytes downloaded so far |
| `total_bytes` | int | 0 | Total size of the file at the remote site |
| `error_type` | str | | String version of the exception that caused an error during download |
| `error` | str | | String version of the traceback associated with an error |
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
When you create a job, you can assign it a `priority`. If multiple
jobs are queued, the job with the lowest priority runs first.
Every job has a `source` and a `dest`. `source` is a pydantic.networks AnyHttpUrl object.
The `dest` is a path on the local filesystem that specifies the
destination for the downloaded object. Its semantics are
described below.
When the job is submitted, it is assigned a numeric `id`. The id can
then be used to fetch the job object from the queue.
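As a minimal sketch of priorities and id lookup (the URLs and paths
are placeholders):
```
from pathlib import Path
from invokeai.app.services.download import DownloadQueueService

queue = DownloadQueueService()

# The job with the lower priority value is dequeued first.
urgent = queue.download(source="https://example.com/model-a.safetensors",
                        dest=Path("/tmp/downloads"), priority=1)
casual = queue.download(source="https://example.com/model-b.safetensors",
                        dest=Path("/tmp/downloads"), priority=20)

# The id assigned at submission time recovers the job object later.
assert queue.id_to_job(urgent.id).source == urgent.source
```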
The `status` field is updated by the queue to indicate where the job
is in its lifecycle. Values are defined in the string enum
`DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are:
| **Value** | **String Value** | **Description** |
|--------------|---------------------|-------------------|
| `WAITING` | waiting | Job is on the queue but not yet running |
| `RUNNING` | running | The download is in progress |
| `COMPLETED` | completed | Job has finished its work without an error |
| `ERROR` | error | Job encountered an error and will not run again |
`job_started` and `job_ended` indicate when the job
was started (using a python timestamp) and when it completed.
In case of an error, the job's status will be set to `DownloadJobStatus.ERROR`, the text of the
Exception that caused the error will be placed in the `error_type`
field and the traceback that led to the error will be in `error`.
A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True.
### Callbacks
Download jobs can be associated with a series of callbacks, each with
the signature `Callable[["DownloadJob"], None]`. The callbacks are assigned
using optional arguments `on_start`, `on_progress`, `on_complete` and
`on_error`. When the corresponding event occurs, the callback will be
invoked and passed the job. The callback will be run in a `try:`
context in the same thread as the download job. Any exceptions that
occur during execution of the callback will be caught and converted
into a log error message, thereby allowing the download to continue.
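A hedged example of the callback hooks, assuming `DownloadJob` is
exported from the same package as the queue:
```
from invokeai.app.services.download import DownloadJob, DownloadQueueService

def log_done(job: DownloadJob) -> None:
    # Runs in the worker thread; an exception raised here is caught by
    # the queue and logged, so the download itself is unaffected.
    print(f"finished {job.source} -> {job.download_path}")

def log_failed(job: DownloadJob) -> None:
    print(f"failed {job.source}: {job.error_type}")

queue = DownloadQueueService()
queue.download(source="https://example.com/file.bin",
               dest="/tmp/downloads",
               on_complete=log_done,
               on_error=log_failed)
queue.join()
```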
#### `TqdmProgress`
The `invokeai.app.services.download.download_default` module defines a
class named `TqdmProgress` which can be used as an `on_progress`
handler to display a completion bar in the console. Use as follows:
```
from invokeai.app.services.download import TqdmProgress
download_queue.download(source='http://some.server.somewhere/some_file',
dest='/tmp/downloads',
on_progress=TqdmProgress().update
)
```
### Events
If the queue was initialized with the InvokeAI event bus (the case
when using `ApiDependencies.invoker.services.download_queue`), then
download events will also be issued on the bus. The events are:
* `download_started` -- This is issued when a job is taken off the
queue and a request is made to the remote server for the URL headers, but before any data
has been downloaded. The event payload will contain the keys `source`
and `download_path`. The latter contains the path that the URL will be
downloaded to.
* `download_progress` -- This is issued periodically as the download
runs. The payload contains the keys `source`, `download_path`,
`current_bytes` and `total_bytes`. The latter two fields can be
used to display the percent complete.
* `download_complete` -- This is issued when the download completes
successfully. The payload contains the keys `source`, `download_path`
and `total_bytes`.
* `download_error` -- This is issued when the download stops because
of an error condition. The payload contains the fields `error_type`
and `error`. The former is the text representation of the exception,
and the latter is a traceback showing where the error occurred.
### Job control
To create a job call the queue's `download()` method. You can list all
jobs using `list_jobs()`, fetch a single job by its id with
`id_to_job()`, cancel a running job with `cancel_job()`, cancel all
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`.
#### job = queue.download(source, dest, priority, access_token)
Create a new download job and put it on the queue, returning the
DownloadJob object.
#### jobs = queue.list_jobs()
Return a list of all active and inactive `DownloadJob`s.
#### job = queue.id_to_job(id)
Return the job corresponding to the given ID.
#### queue.prune_jobs()
Remove inactive (complete or errored) jobs from the listing returned
by `list_jobs()`.
#### queue.join()
Block until all pending jobs have run to completion or errored out.
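Putting the control methods together, here is a hedged sketch of a
cancel-and-inspect flow. Whether `cancel_job()` takes the job object
or its integer id is an assumption here; the URL is a placeholder:
```
from invokeai.app.services.download import DownloadQueueService

queue = DownloadQueueService()
job = queue.download(source="https://example.com/big-model.safetensors",
                     dest="/tmp/downloads")

queue.cancel_job(job)   # assumed to accept the DownloadJob object
queue.join()            # block until every job reaches a terminal state

for j in queue.list_jobs():
    if j.cancelled:
        print(f"{j.source}: cancelled")
    elif j.error_type:
        print(f"{j.source}: {j.error_type}")

queue.prune_jobs()      # drop completed/errored jobs from list_jobs()
```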

View File

@@ -15,7 +15,12 @@ model. These are the:
their metadata, and `ModelRecordServiceBase` to store that
information. It is also responsible for managing the InvokeAI
`models` directory and its contents.
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
are able to retrieve metadata from online model repositories,
transform them into Pydantic models, and cache them to the InvokeAI
SQL database.
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
A multithreaded downloader responsible
for downloading models from a remote source to disk. The download
@@ -1184,3 +1189,248 @@ other resources that it might have been using.
This will start/pause/cancel all jobs that have been submitted to the
queue and have not yet reached a terminal state.
***
## This Meta be Good: Model Metadata Storage
The modules found under `invokeai.backend.model_manager.metadata`
provide a straightforward API for fetching model metadata from online
repositories. Currently two repositories are supported: HuggingFace
and Civitai. However, the modules are easily extended for additional
repos, provided that they have defined APIs for metadata access.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.
### Example Usage:
```
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
CivitaiMetadata,
ModelMetadataStore,
)
# to access the initialized sql database
from invokeai.app.api.dependencies import ApiDependencies
civitai = CivitaiMetadataFetch()
# fetch the metadata
model_metadata = civitai.from_url("https://civitai.com/models/215796")
# get some common metadata fields
author = model_metadata.author
tags = model_metadata.tags
# get some Civitai-specific fields
assert isinstance(model_metadata, CivitaiMetadata)
trained_words = model_metadata.trained_words
base_model = model_metadata.base_model_trained_on
thumbnail = model_metadata.thumbnail_url
# cache the metadata to the database using the key corresponding to
# an existing model config record in the `model_config` table
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
# now we can search the database by tag, author or model name
# matches will contain a list of model keys that match the search
matches = sql_cache.search_by_tag({"tool", "turbo"})
```
### Structure of the Metadata objects
There is a short class hierarchy of Metadata objects, all of which
descend from the Pydantic `BaseModel`.
#### `ModelMetadataBase`
This is the common base class for metadata:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `name` | str | Repository's name for the model |
| `author` | str | Model's author |
| `tags` | Set[str] | Model tags |
Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected
to be part of the business logic.
Descendants of the base add additional fields.
#### `HuggingFaceMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["huggingface"] | Used for the discriminated union of metadata classes|
| `id` | str | HuggingFace repo_id |
| `tag_dict` | Dict[str, Any] | A dictionary of tag/value pairs provided in addition to `tags` |
| `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
| `id` | int | Civitai model id |
| `version_name` | str | Name of this version of the model (distinct from model name) |
| `version_id` | int | Civitai model version id (distinct from model id) |
| `created` | datetime | Date this version of the model was created |
| `updated` | datetime | Date this version of the model was last updated |
| `published` | datetime | Date this version of the model was published to Civitai |
| `description` | str | Model description. Quite verbose and contains HTML tags |
| `version_description` | str | Model version description, usually describes changes to the model |
| `nsfw` | bool | Whether the model tends to generate NSFW content |
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
| `trained_words`| Set[str] | Trigger words for this model, if any |
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
| `base_model_trained_on` | str | Name of the model that this version was trained on |
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
| `weight_min` | float | For LoRA sliders, the minimum suggested weight to apply |
| `weight_max` | float | For LoRA sliders, the maximum suggested weight to apply |
Note that `weight_min` and `weight_max` are not currently populated
and take the default values of (-1.0, +2.0). The issue is that these
values aren't part of the structured data but appear in the text
description. Some regular expression or LLM coding may be able to
extract these values.
Also be aware that `base_model_trained_on` is free text and doesn't
correspond to our `ModelType` enum.
`CivitaiMetadata` also defines some convenience properties relating to
licensing restrictions: `credit_required`, `allow_commercial_use`,
`allow_derivatives` and `allow_different_license`.
#### `AnyModelRepoMetadata`
This is a discriminated Union of `CivitaiMetadata` and
`HuggingFaceMetadata`.
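A minimal sketch of narrowing the union in application code (the
`describe()` helper is hypothetical):
```
from invokeai.backend.model_manager.metadata import (
    AnyModelRepoMetadata,
    CivitaiMetadata,
    HuggingFaceMetadata,
)

def describe(meta: AnyModelRepoMetadata) -> str:
    # `type` is the discriminator field, but isinstance checks read
    # more naturally in application code.
    if isinstance(meta, CivitaiMetadata):
        return f"{meta.name} (Civitai version {meta.version_name})"
    if isinstance(meta, HuggingFaceMetadata):
        return f"{meta.name} (HuggingFace repo {meta.id})"
    raise ValueError(f"unexpected metadata type: {meta.type}")
```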
### Fetching Metadata from Online Repos
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
retrieve metadata from their corresponding repositories and return
`AnyModelRepoMetadata` objects. Their base class
`ModelMetadataFetchBase` is an abstract class that defines two
methods: `from_url()` and `from_id()`. The former accepts the type of
model URLs that the user will try to cut and paste into the model
import form. The latter accepts a string ID in the format recognized
by the repository of choice. Both methods return an
`AnyModelRepoMetadata`.
The base class also has a class method `from_json()` which will take
the JSON representation of a `ModelMetadata` object, validate it, and
return the corresponding `AnyModelRepoMetadata` object.
When initializing one of the metadata fetching classes, you may
provide a `requests.Session` argument. This allows you to customize
the low-level HTTP fetch requests and is used, for instance, in the
testing suite to avoid hitting the internet.
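A hedged sketch combining a caller-supplied session with the JSON
round trip. `model_dump_json()` is the standard Pydantic v2
serializer; its availability on these models is assumed here:
```
import requests

from invokeai.backend.model_manager.metadata import CivitaiMetadataFetch

# A session you control: add headers, proxies, or a mock transport.
session = requests.Session()
fetcher = CivitaiMetadataFetch(session)

metadata = fetcher.from_url("https://civitai.com/models/215796")

# Round-trip through JSON: persist the model, then rebuild it.
as_json = metadata.model_dump_json()
restored = CivitaiMetadataFetch.from_json(as_json)
assert restored.name == metadata.name
```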
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:
#### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a
`HuggingFaceMetadata` object directly.
#### CivitaiMetadataFetch
This adds the following methods:
`from_civitai_modelid()` This takes the ID of a model, finds the
default version of the model, and then retrieves the metadata for
that version, returning a `CivitaiMetadata` object directly.
`from_civitai_versionid()` This takes the ID of a model version and
retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returns a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
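A short usage sketch: the model id is the one from the example above,
while the version id is a made-up placeholder:
```
from invokeai.backend.model_manager.metadata import CivitaiMetadataFetch

fetcher = CivitaiMetadataFetch()

# Metadata for the default version of a model...
meta_from_model = fetcher.from_civitai_modelid(215796)

# ...or for one specific version of it (placeholder id).
meta_from_version = fetcher.from_civitai_versionid(123456)

print(meta_from_model.version_name, meta_from_version.version_name)
```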
### Metadata Storage
The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.
When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing
model key in the model_config table_. There is a foreign key integrity
constraint between the `model_config.id` field and the
`model_metadata.id` field such that if you attempt to save metadata
under an unknown key, the attempt will result in an
`UnknownModelException`. Likewise, when a model is deleted from
`model_config`, the deletion of the corresponding metadata record will
be triggered.
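As a hedged sketch of the integrity constraint, continuing the
earlier example's `sql_cache` and `model_metadata` (the import path
for `UnknownModelException` is an assumption):
```
from invokeai.app.services.model_records import UnknownModelException  # path assumed

try:
    sql_cache.add_metadata("not-a-real-model-key", model_metadata)
except UnknownModelException:
    print("metadata must be keyed to an existing model_config record")
```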
Tags are stored in a normalized fashion in the tables `model_tags` and
`tags`. Triggers keep the tag table in sync with the `model_metadata`
table.
To create the storage object, initialize it with the InvokeAI
`SqliteDatabase` object. This is often done this way:
```
from invokeai.app.api.dependencies import ApiDependencies
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
```
You can then access the storage with the following methods:
#### `add_metadata(key, metadata)`
Add the metadata using a previously-defined model key.
There is currently no `delete_metadata()` method. The metadata will
persist until the matching config is deleted from the `model_config`
table.
#### `get_metadata(key) -> AnyModelRepoMetadata`
Retrieve the metadata corresponding to the model key.
#### `update_metadata(key, new_metadata)`
Update an existing metadata record with new metadata.
#### `search_by_tag(tags: Set[str]) -> Set[str]`
Given a set of tags, find models that are tagged with them. If
multiple tags are provided then a matching model must be tagged with
*all* the tags in the set. This method returns a set of model keys and
is intended to be used in conjunction with the `ModelRecordService`:
```
model_config_store = ApiDependencies.invoker.services.model_records
matches = metadata_store.search_by_tag({'license:other'})
models = [model_config_store.get(x) for x in matches]
```
#### `search_by_name(name: str) -> Set[str]`
Find all model metadata records that have the given name and return a
set of keys to the corresponding model config objects.
#### `search_by_author(author: str) -> Set[str]`
Find all model metadata records that have the given author and return
a set of keys to the corresponding model config objects.
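A hedged sketch that unions the three searches (the name and author
strings are placeholders):
```
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.model_manager.metadata import ModelMetadataStore

store = ModelMetadataStore(ApiDependencies.invoker.services.db)
records = ApiDependencies.invoker.services.model_records

# Each search returns a set of model keys; union them and resolve
# the keys to config records.
keys = (store.search_by_name("xl_more_art-full_v1")
        | store.search_by_author("some-author")
        | store.search_by_tag({"tool"}))
models = [records.get(k) for k in keys]
```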

View File

@@ -46,17 +46,18 @@ We encourage you to ping @psychedelicious and @blessedcoolant on [Discord](http
```bash
node --version
```
2. Install [yarn classic](https://classic.yarnpkg.com/lang/en/) and confirm it is installed by running this:
2. Install [pnpm](https://pnpm.io/) and confirm it is installed by running this:
```bash
npm install --global yarn
yarn --version
npm install --global pnpm
pnpm --version
```
From `invokeai/frontend/web/` run `yarn install` to get everything set up.
From `invokeai/frontend/web/` run `pnpm install` to get everything set up.
Start everything in dev mode:
1. Ensure your virtual environment is running
2. Start the dev server: `yarn dev`
2. Start the dev server: `pnpm dev`
3. Start the InvokeAI Nodes backend: `python scripts/invokeai-web.py # run from the repo root`
4. Point your browser to the dev server address e.g. [http://localhost:5173/](http://localhost:5173/)
@@ -72,4 +73,4 @@ For a number of technical and logistical reasons, we need to commit UI build art
If you submit a PR, there is a good chance we will ask you to include a separate commit with a build of the app.
To build for production, run `yarn build`.
To build for production, run `pnpm build`.

View File

@@ -13,6 +13,7 @@ If you'd prefer, you can also just download the whole node folder from the linke
To use a community workflow, download the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
- Community Nodes
+ [Adapters-Linked](#adapters-linked-nodes)
+ [Average Images](#average-images)
+ [Clean Image Artifacts After Cut](#clean-image-artifacts-after-cut)
+ [Close Color Mask](#close-color-mask)
@@ -32,8 +33,9 @@ To use a community workflow, download the the `.json` node graph file and load i
+ [Image Resize Plus](#image-resize-plus)
+ [Load Video Frame](#load-video-frame)
+ [Make 3D](#make-3d)
+ [Mask Operations](#mask-operations)
+ [Match Histogram](#match-histogram)
+ [Metadata-Linked](#metadata-linked-nodes)
+ [Negative Image](#negative-image)
+ [Oobabooga](#oobabooga)
+ [Prompt Tools](#prompt-tools)
@@ -51,6 +53,19 @@ To use a community workflow, download the the `.json` node graph file and load i
- [Help](#help)
--------------------------------
### Adapters Linked Nodes
**Description:** A set of nodes for linked adapters (ControlNet, IP-Adapter & T2I-Adapter). This allows multiple adapters to be chained together without using a `collect` node, which means they can be used inside an `iterate` node without the issues that collecting on every iteration would cause.
- `ControlNet-Linked` - Collects ControlNet info to pass to other nodes.
- `IP-Adapter-Linked` - Collects IP-Adapter info to pass to other nodes.
- `T2I-Adapter-Linked` - Collects T2I-Adapter info to pass to other nodes.
Note: These are inherited from the core nodes so any update to the core nodes should be reflected in these.
**Node Link:** https://github.com/skunkworxdark/adapters-linked-nodes
--------------------------------
### Average Images
@ -307,6 +322,20 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
<img src="https://github.com/skunkworxdark/match_histogram/assets/21961335/ed12f329-a0ef-444a-9bae-129ed60d6097" width="300" />
--------------------------------
### Metadata Linked Nodes
**Description:** A set of nodes for working with metadata: collect metadata from within an `iterate` node and extract metadata from an image.
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
- `Metadata From Image` - Provides Metadata from an image.
- `Metadata To String` - Extracts a String value of a label from metadata.
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
- `Metadata To Float` - Extracts a Float value of a label from metadata.
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes
--------------------------------
### Negative Image

View File

@ -91,9 +91,11 @@ rm -rf InvokeAI-Installer
# copy content
mkdir InvokeAI-Installer
for f in templates lib *.txt *.reg; do
for f in templates *.txt *.reg; do
cp -r ${f} InvokeAI-Installer/
done
mkdir InvokeAI-Installer/lib
cp lib/*.py InvokeAI-Installer/lib
# Move the wheel
mv dist/*.whl InvokeAI-Installer/lib/
@ -111,6 +113,6 @@ cp WinLongPathsEnabled.reg InvokeAI-Installer/
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
# clean up
rm -rf InvokeAI-Installer tmp dist
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
exit 0

View File

@ -11,6 +11,7 @@ from ..services.board_images.board_images_default import BoardImagesService
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from ..services.boards.boards_default import BoardService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
@ -29,8 +30,7 @@ from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.graph import GraphExecutionState
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -80,13 +80,13 @@ class ApiDependencies:
boards = BoardService()
events = FastAPIEventService(event_handler_id)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
download_queue_service = DownloadQueueService(event_bus=events)
model_install_service = ModelInstallService(
app_config=config, record_store=model_record_service, event_bus=events
)
@ -107,7 +107,6 @@ class ApiDependencies:
configuration=configuration,
events=events,
graph_execution_manager=graph_execution_manager,
graph_library=graph_library,
image_files=image_files,
image_records=image_records,
images=images,
@ -116,6 +115,7 @@ class ApiDependencies:
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
download_queue=download_queue_service,
model_install=model_install_service,
names=names,
performance_statistics=performance_statistics,
@ -127,8 +127,6 @@ class ApiDependencies:
workflow_records=workflow_records,
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
db.clean()

View File

@ -0,0 +1,111 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for the download queue."""
from typing import List, Optional
from fastapi import Body, Path, Response
from fastapi.routing import APIRouter
from pydantic.networks import AnyHttpUrl
from starlette.exceptions import HTTPException
from invokeai.app.services.download import (
DownloadJob,
UnknownJobIDException,
)
from ..dependencies import ApiDependencies
download_queue_router = APIRouter(prefix="/v1/download_queue", tags=["download_queue"])
@download_queue_router.get(
"/",
operation_id="list_downloads",
)
async def list_downloads() -> List[DownloadJob]:
"""Get a list of active and inactive jobs."""
queue = ApiDependencies.invoker.services.download_queue
return queue.list_jobs()
@download_queue_router.patch(
"/",
operation_id="prune_downloads",
responses={
204: {"description": "All completed jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_downloads():
"""Prune completed and errored jobs."""
queue = ApiDependencies.invoker.services.download_queue
queue.prune_jobs()
return Response(status_code=204)
@download_queue_router.post(
"/i/",
operation_id="download",
)
async def download(
source: AnyHttpUrl = Body(description="download source"),
dest: str = Body(description="download destination"),
priority: int = Body(default=10, description="queue priority"),
access_token: Optional[str] = Body(default=None, description="token for authorization to download"),
) -> DownloadJob:
"""Download the source URL to the file or directory indicted in dest."""
queue = ApiDependencies.invoker.services.download_queue
return queue.download(source, dest, priority, access_token)
@download_queue_router.get(
"/i/{id}",
operation_id="get_download_job",
responses={
200: {"description": "Success"},
404: {"description": "The requested download JobID could not be found"},
},
)
async def get_download_job(
id: int = Path(description="ID of the download job to fetch."),
) -> DownloadJob:
"""Get a download job using its ID."""
try:
job = ApiDependencies.invoker.services.download_queue.id_to_job(id)
return job
except UnknownJobIDException as e:
raise HTTPException(status_code=404, detail=str(e))
@download_queue_router.delete(
"/i/{id}",
operation_id="cancel_download_job",
responses={
204: {"description": "Job has been cancelled"},
404: {"description": "The requested download JobID could not be found"},
},
)
async def cancel_download_job(
id: int = Path(description="ID of the download job to cancel."),
):
"""Cancel a download job using its ID."""
try:
queue = ApiDependencies.invoker.services.download_queue
job = queue.id_to_job(id)
queue.cancel_job(job)
return Response(status_code=204)
except UnknownJobIDException as e:
raise HTTPException(status_code=404, detail=str(e))
@download_queue_router.delete(
"/i",
operation_id="cancel_all_download_jobs",
responses={
204: {"description": "Download jobs have been cancelled"},
},
)
async def cancel_all_download_jobs():
"""Cancel all download jobs."""
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
return Response(status_code=204)
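For illustration, the new endpoints can be exercised from Python. This is a minimal sketch, assuming an InvokeAI server listening at its default local address and the `requests` library; the source URL is hypothetical:
```
import requests

BASE = "http://localhost:9090/api/v1/download_queue"  # adjust to your server address

# enqueue a download job
job = requests.post(
    f"{BASE}/i/",
    json={"source": "https://example.com/model.safetensors", "dest": "/tmp"},  # hypothetical URL
).json()

# fetch the job by its ID to check progress
status = requests.get(f"{BASE}/i/{job['id']}").json()["status"]

# cancel it (returns 204 on success, 404 for an unknown ID)
requests.delete(f"{BASE}/i/{job['id']}")
```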

View File

@ -26,7 +26,7 @@ from invokeai.backend.model_manager.config import (
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2"])
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
class ModelsList(BaseModel):

View File

@ -45,6 +45,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
app_info,
board_images,
boards,
download_queue,
images,
model_records,
models,
@ -116,6 +117,7 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")

View File

@ -1,4 +1,3 @@
import re
from dataclasses import dataclass
from typing import List, Optional, Union
@ -17,6 +16,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.util.devices import torch_dtype
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -87,7 +87,7 @@ class CompelInvocation(BaseInvocation):
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
@ -210,7 +210,7 @@ class SDXLPromptInvocationBase:
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_list.append(

View File

@ -1,7 +1,6 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
import inspect
import re
# from contextlib import ExitStack
from typing import List, Literal, Union
@ -21,6 +20,7 @@ from invokeai.backend import BaseModelType, ModelType, SubModelType
from ...backend.model_management import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -78,7 +78,7 @@ class ONNXPromptInvocation(BaseInvocation):
]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
ti_list.append(

View File

@ -77,7 +77,7 @@ class CalculateImageTilesInvocation(BaseInvocation):
title="Calculate Image Tiles Even Split",
tags=["tiles"],
category="tiles",
version="1.0.0",
version="1.1.0",
classification=Classification.Beta,
)
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
@ -97,11 +97,11 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
ge=1,
description="Number of tiles to divide image into on the y axis",
)
overlap_fraction: float = InputField(
default=0.25,
overlap: int = InputField(
default=128,
ge=0,
lt=1,
description="Overlap between adjacent tiles as a fraction of the tile's dimensions (0-1)",
multiple_of=8,
description="The overlap, in pixels, between adjacent tiles.",
)
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
@ -110,7 +110,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
image_width=self.image_width,
num_tiles_x=self.num_tiles_x,
num_tiles_y=self.num_tiles_y,
overlap_fraction=self.overlap_fraction,
overlap=self.overlap,
)
return CalculateImageTilesOutput(tiles=tiles)

View File

@ -356,7 +356,7 @@ class InvokeAIAppConfig(InvokeAISettings):
else:
root = self.find_root().expanduser().absolute()
self.root = root # insulate ourselves from relative paths that may change
return root
return root.resolve()
@property
def root_dir(self) -> Path:

View File

@ -0,0 +1,12 @@
"""Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_default import DownloadQueueService, TqdmProgress
__all__ = [
"DownloadJob",
"DownloadQueueServiceBase",
"DownloadQueueService",
"TqdmProgress",
"DownloadJobStatus",
"UnknownJobIDException",
]

View File

@ -0,0 +1,217 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""Model download service."""
from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
class DownloadJobStatus(str, Enum):
"""State of a download job."""
WAITING = "waiting" # not enqueued, will not run
RUNNING = "running" # actively downloading
COMPLETED = "completed" # finished running
CANCELLED = "cancelled" # user cancelled
ERROR = "error" # terminated with an error message
class DownloadJobCancelledException(Exception):
"""This exception is raised when a download job is cancelled."""
class UnknownJobIDException(Exception):
"""This exception is raised when an invalid job id is referened."""
class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started."""
DownloadEventHandler = Callable[["DownloadJob"], None]
@total_ordering
class DownloadJob(BaseModel):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)")
# set when an error occurs
error_type: Optional[str] = Field(default=None, description="Name of exception that caused an error")
error: Optional[str] = Field(default=None, description="Traceback of the exception that caused an error")
# internal flag
_cancelled: bool = PrivateAttr(default=False)
# optional event handlers passed in on creation
_on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadEventHandler] = PrivateAttr(default=None)
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
def cancel(self) -> None:
"""Call to cancel the job."""
self._cancelled = True
# cancelled and the callbacks are private attributes in order to prevent
# them from being serialized and/or used in the JSON schema
@property
def cancelled(self) -> bool:
"""Call to cancel the job."""
return self._cancelled
@property
def on_start(self) -> Optional[DownloadEventHandler]:
"""Return the on_start event handler."""
return self._on_start
@property
def on_progress(self) -> Optional[DownloadEventHandler]:
"""Return the on_progress event handler."""
return self._on_progress
@property
def on_complete(self) -> Optional[DownloadEventHandler]:
"""Return the on_complete event handler."""
return self._on_complete
@property
def on_error(self) -> Optional[DownloadEventHandler]:
"""Return the on_error event handler."""
return self._on_error
@property
def on_cancelled(self) -> Optional[DownloadEventHandler]:
"""Return the on_cancelled event handler."""
return self._on_cancelled
def set_callbacks(
self,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadEventHandler] = None,
) -> None:
"""Set the callbacks for download events."""
self._on_start = on_start
self._on_progress = on_progress
self._on_complete = on_complete
self._on_error = on_error
self._on_cancelled = on_cancelled
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL."""
@abstractmethod
def start(self, *args: Any, **kwargs: Any) -> None:
"""Start the download worker threads."""
@abstractmethod
def stop(self, *args: Any, **kwargs: Any) -> None:
"""Stop the download worker threads."""
@abstractmethod
def download(
self,
source: AnyHttpUrl,
dest: Path,
priority: int = 10,
access_token: Optional[str] = None,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadEventHandler] = None,
) -> DownloadJob:
"""
Create a download job.
:param source: Source of the download as a URL.
:param dest: Path to download to. See below.
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
:returns: A DownloadJob object for monitoring the state of the download.
The `dest` argument is a Path object. Its behavior is:
1. If the path exists and is a directory, then the URL contents will be downloaded
into that directory using the filename indicated in the response's `Content-Disposition` field.
If no content-disposition is present, then the last component of the URL will be used (similar to
wget's behavior).
2. If the path does not exist, then it is taken as the name of a new file to create with the downloaded
content.
3. If the path exists and is an existing file, then the downloader will try to resume the download from
the end of the existing file.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJob]:
"""
List active download jobs.
:returns List[DownloadJob]: List of download jobs whose state is not "completed."
"""
pass
@abstractmethod
def id_to_job(self, id: int) -> DownloadJob:
"""
Return the DownloadJob corresponding to the integer ID.
:param id: ID of the DownloadJob.
Exceptions:
* UnknownJobIDException
"""
pass
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all active and enquedjobs."""
pass
@abstractmethod
def prune_jobs(self):
"""Prune completed and errored queue items from the job list."""
pass
@abstractmethod
def cancel_job(self, job: DownloadJob):
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def join(self):
"""Wait until all jobs are off the queue."""
pass
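To make the `dest` semantics described in `download()` concrete, here is a sketch assuming `queue` is a started implementation of this base class; the URLs are hypothetical:
```
from pathlib import Path

# 1. dest is an existing directory: the filename comes from the
#    Content-Disposition header, falling back to the last URL component.
job1 = queue.download(source="https://example.com/v1-5.safetensors", dest=Path("/tmp"))

# 2. dest does not exist: it names the new file to create.
job2 = queue.download(source="https://example.com/v1-5.safetensors", dest=Path("/tmp/models/v1-5.safetensors"))

# 3. dest is an existing file: the base contract allows resuming from its end
#    (note that the default implementation instead refuses to clobber it).
queue.join()  # wait until all jobs are off the queue
```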

View File

@ -0,0 +1,418 @@
# Copyright (c) 2023, Lincoln D. Stein
"""Implementation of multithreaded download queue for invokeai."""
import os
import re
import threading
import traceback
from logging import Logger
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger
from .download_base import (
DownloadEventHandler,
DownloadJob,
DownloadJobCancelledException,
DownloadJobStatus,
DownloadQueueServiceBase,
ServiceInactiveException,
UnknownJobIDException,
)
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
class DownloadQueueService(DownloadQueueServiceBase):
"""Class for queued download of models."""
_jobs: Dict[int, DownloadJob]
_max_parallel_dl: int = 5
_worker_pool: Set[threading.Thread]
_queue: PriorityQueue[DownloadJob]
_stop_event: threading.Event
_lock: threading.Lock
_logger: Logger
_events: Optional[EventServiceBase] = None
_next_job_id: int = 0
_accept_download_requests: bool = False
_requests: requests.sessions.Session
def __init__(
self,
max_parallel_dl: int = 5,
event_bus: Optional[EventServiceBase] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._jobs = {}
self._next_job_id = 0
self._queue = PriorityQueue()
self._stop_event = threading.Event()
self._worker_pool = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
self._event_bus = event_bus
self._requests = requests_session or requests.Session()
self._accept_download_requests = False
self._max_parallel_dl = max_parallel_dl
def start(self, *args: Any, **kwargs: Any) -> None:
"""Start the download worker threads."""
with self._lock:
if self._worker_pool:
raise Exception("Attempt to start the download service twice")
self._stop_event.clear()
self._start_workers(self._max_parallel_dl)
self._accept_download_requests = True
def stop(self, *args: Any, **kwargs: Any) -> None:
"""Stop the download worker threads."""
with self._lock:
if not self._worker_pool:
raise Exception("Attempt to stop the download service before it was started")
self._accept_download_requests = False # reject attempts to add new jobs to queue
queued_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.WAITING]
active_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.RUNNING]
if queued_jobs:
self._logger.warning(f"Cancelling {len(queued_jobs)} queued downloads")
if active_jobs:
self._logger.info(f"Waiting for {len(active_jobs)} active download jobs to complete")
with self._queue.mutex:
self._queue.queue.clear()
self.join() # wait for all active jobs to finish
self._stop_event.set()
self._worker_pool.clear()
def download(
self,
source: AnyHttpUrl,
dest: Path,
priority: int = 10,
access_token: Optional[str] = None,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadEventHandler] = None,
) -> DownloadJob:
"""Create a download job and return its ID."""
if not self._accept_download_requests:
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
with self._lock:
id = self._next_job_id
self._next_job_id += 1
job = DownloadJob(
id=id,
source=source,
dest=dest,
priority=priority,
access_token=access_token,
)
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[id] = job
self._queue.put(job)
return job
def join(self) -> None:
"""Wait for all jobs to complete."""
self._queue.join()
def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs."""
return list(self._jobs.values())
def prune_jobs(self) -> None:
"""Prune completed and errored queue items from the job list."""
with self._lock:
to_delete = set()
for job_id, job in self._jobs.items():
if self._in_terminal_state(job):
to_delete.add(job_id)
for job_id in to_delete:
del self._jobs[job_id]
def id_to_job(self, id: int) -> DownloadJob:
"""Translate a job ID into a DownloadJob object."""
try:
return self._jobs[id]
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJob) -> None:
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
job.cancel()
def cancel_all_jobs(self, preserve_partial: bool = False) -> None:
"""Cancel all jobs (those not in enqueued, running or paused state)."""
for job in self._jobs.values():
if not self._in_terminal_state(job):
self.cancel_job(job)
def _in_terminal_state(self, job: DownloadJob) -> bool:
return job.status in [
DownloadJobStatus.COMPLETED,
DownloadJobStatus.CANCELLED,
DownloadJobStatus.ERROR,
]
def _start_workers(self, max_workers: int) -> None:
"""Start the requested number of worker threads."""
self._stop_event.clear()
for i in range(0, max_workers): # noqa B007
worker = threading.Thread(target=self._download_next_item, daemon=True)
self._logger.debug(f"Download queue worker thread {worker.name} starting.")
worker.start()
self._worker_pool.add(worker)
def _download_next_item(self) -> None:
"""Worker thread gets next job on priority queue."""
done = False
while not done:
if self._stop_event.is_set():
done = True
continue
try:
job = self._queue.get(timeout=1)
except Empty:
continue
try:
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
finally:
job.job_ended = get_iso_timestamp()
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
# Make a streaming request. This will retrieve headers including
# content-length and content-disposition, but not fetch any content itself
resp = self._requests.get(str(url), headers=header, stream=True)
if not resp.ok:
raise HTTPError(resp.reason)
content_length = int(resp.headers.get("content-length", 0))
job.total_bytes = content_length
if job.dest.is_dir():
file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL
if match := re.search('filename="(.+)"', resp.headers.get("Content-Disposition", "")):
remote_name = match.group(1)
if self._validate_filename(job.dest.as_posix(), remote_name):
file_name = remote_name
job.download_path = job.dest / file_name
else:
job.dest.parent.mkdir(parents=True, exist_ok=True)
job.download_path = job.dest
assert job.download_path
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():
raise OSError(f"[Errno 17] File {job.download_path} exists")
# append ".downloading" to the path
in_progress_path = self._in_progress_path(job.download_path)
# signal caller that the download is starting. At this point, key fields such as
# download_path and total_bytes will be populated. We call it here because we might
# discover that the local file is already complete and generate a COMPLETED status.
self._signal_job_started(job)
# "range not satisfiable" - local file is at least as large as the remote file
if resp.status_code == 416 or (content_length > 0 and job.bytes >= content_length):
self._logger.warning(f"{job.download_path}: complete file found. Skipping.")
return
# "partial content" - local file is smaller than remote file
elif resp.status_code == 206 or job.bytes > 0:
self._logger.warning(f"{job.download_path}: partial file found. Resuming")
# some other error
elif resp.status_code != 200:
raise HTTPError(resp.reason)
self._logger.debug(f"{job.source}: Downloading {job.download_path}")
report_delta = job.total_bytes / 100 # report every 1% change
last_report_bytes = 0
# DOWNLOAD LOOP
with open(in_progress_path, open_mode) as file:
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
if job.cancelled:
raise DownloadJobCancelledException("Job was cancelled at caller's request")
job.bytes += file.write(data)
if (job.bytes - last_report_bytes >= report_delta) or (job.bytes >= job.total_bytes):
last_report_bytes = job.bytes
self._signal_job_progress(job)
# if we get here we are done and can rename the file to the original dest
in_progress_path.rename(job.download_path)
def _validate_filename(self, directory: str, filename: str) -> bool:
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows
pc_path_max = (
os.pathconf(directory, "PC_PATH_MAX") if hasattr(os, "pathconf") else 32767
) # hardcoded for windows with long names enabled
if "/" in filename:
return False
if filename.startswith(".."):
return False
if len(filename) > pc_name_max:
return False
if len(os.path.join(directory, filename)) > pc_path_max:
return False
return True
def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading")
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
if job.on_start:
try:
job.on_start(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
try:
job.on_progress(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
if job.on_complete:
try:
job.on_complete(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.CANCELLED
if job.on_cancelled:
try:
job.on_cancelled(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source))
def _signal_job_error(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.ERROR
if job.on_error:
try:
job.on_error(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.download_path}")
try:
if job.download_path:
partial_file = self._in_progress_path(job.download_path)
partial_file.unlink()
except OSError as excp:
self._logger.warning(excp)
# Example on_progress event handler to display a TQDM status bar
# Activate with:
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update)
class TqdmProgress(object):
"""TQDM-based progress bar object to use in on_progress handlers."""
_bars: Dict[int, tqdm] # the tqdm object
_last: Dict[int, int] # last bytes downloaded
def __init__(self) -> None: # noqa D107
self._bars = {}
self._last = {}
def update(self, job: DownloadJob) -> None: # noqa D102
job_id = job.id
# new job
if job_id not in self._bars:
assert job.download_path
dest = Path(job.download_path).name
self._bars[job_id] = tqdm(
desc=dest,
initial=0,
total=job.total_bytes,
unit="iB",
unit_scale=True,
)
self._last[job_id] = 0
self._bars[job_id].update(job.bytes - self._last[job_id])
self._last[job_id] = job.bytes
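Putting the pieces together, a minimal lifecycle sketch for the concrete service (event bus omitted; the source URL is hypothetical):
```
from pathlib import Path

queue = DownloadQueueService(max_parallel_dl=2)
queue.start()  # downloads are rejected with ServiceInactiveException before this

job = queue.download(
    source="https://example.com/v1-5.safetensors",  # hypothetical URL
    dest=Path("/tmp"),
    on_progress=TqdmProgress().update,  # live progress bar, as noted above
)

queue.join()  # block until the queue drains
queue.stop()  # shut down the worker threads
```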

View File

@ -17,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
class EventServiceBase:
queue_event: str = "queue_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
@ -32,6 +33,13 @@ class EventServiceBase:
payload={"event": event_name, "data": payload},
)
def __emit_download_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
@ -323,6 +331,79 @@ class EventServiceBase:
payload={"queue_id": queue_id},
)
def emit_download_started(self, source: str, download_path: str) -> None:
"""
Emit when a download job is started.
:param source: The source URL
:param download_path: The local path of the downloaded file
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
"""
Emit "download_progress" events at regular intervals during a download job.
:param source: The downloaded source
:param download_path: The local downloaded file
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
"""
Emit a "download_complete" event at the end of a successful download.
:param source: Source URL
:param download_path: Path to the locally downloaded file
:param total_bytes: The size of the downloaded file
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_cancelled(self, source: str) -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
self.__emit_download_event(
event_name="download_cancelled",
payload={
"source": source,
},
)
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
"""
Emit a "download_error" event when an download job encounters an exception.
:param source: Source URL
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
def emit_model_install_started(self, source: str) -> None:
"""
Emitted when an install job is started.

View File

@ -11,6 +11,7 @@ if TYPE_CHECKING:
from .board_records.board_records_base import BoardRecordStorageBase
from .boards.boards_base import BoardServiceABC
from .config import InvokeAIAppConfig
from .download import DownloadQueueServiceBase
from .events.events_base import EventServiceBase
from .image_files.image_files_base import ImageFileStorageBase
from .image_records.image_records_base import ImageRecordStorageBase
@ -27,7 +28,7 @@ if TYPE_CHECKING:
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState, LibraryGraph
from .shared.graph import GraphExecutionState
from .urls.urls_base import UrlServiceBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@ -43,7 +44,6 @@ class InvocationServices:
configuration: "InvokeAIAppConfig"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
graph_library: "ItemStorageABC[LibraryGraph]"
images: "ImageServiceABC"
image_records: "ImageRecordStorageBase"
image_files: "ImageFileStorageBase"
@ -51,6 +51,7 @@ class InvocationServices:
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
download_queue: "DownloadQueueServiceBase"
model_install: "ModelInstallServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
@ -71,7 +72,6 @@ class InvocationServices:
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
graph_library: "ItemStorageABC[LibraryGraph]",
images: "ImageServiceABC",
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
@ -79,6 +79,7 @@ class InvocationServices:
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
download_queue: "DownloadQueueServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
@ -97,7 +98,6 @@ class InvocationServices:
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.graph_library = graph_library
self.images = images
self.image_files = image_files
self.image_records = image_records
@ -105,6 +105,7 @@ class InvocationServices:
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.download_queue = download_queue
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics

View File

@ -11,7 +11,6 @@ from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig
@ -157,12 +156,12 @@ class ModelInstallServiceBase(ABC):
:param event_bus: InvokeAI event bus for reporting events to.
"""
def start(self, invoker: Invoker) -> None:
"""Call at InvokeAI startup time."""
self.sync_to_config()
@abstractmethod
def start(self, *args: Any, **kwarg: Any) -> None:
"""Start the installer service."""
@abstractmethod
def stop(self) -> None:
def stop(self, *args: Any, **kwarg: Any) -> None:
"""Stop the model install service. After this the objection can be safely deleted."""
@property

View File

@ -71,7 +71,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_queue = Queue()
self._cached_model_paths = set()
self._models_installed = set()
self._start_installer_thread()
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
@ -85,8 +84,13 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
def stop(self, *args, **kwargs) -> None:
"""Stop the install thread; after this the object can be deleted and garbage collected."""
def start(self, *args: Any, **kwarg: Any) -> None:
"""Start the installer thread."""
self._start_installer_thread()
self.sync_to_config()
def stop(self, *args: Any, **kwarg: Any) -> None:
"""Stop the installer thread; after this the object can be deleted and garbage collected."""
self._install_queue.put(STOP_JOB)
def _start_installer_thread(self) -> None:

View File

@ -5,6 +5,8 @@ from invokeai.app.services.image_files.image_files_base import ImageFileStorageB
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -27,6 +29,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator = SqliteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
migrator.register_migration(build_migration_4())
migrator.run_migrations()
return db

View File

@ -23,7 +23,6 @@ class Migration2Callback:
self._drop_old_workflow_tables(cursor)
self._add_workflow_library(cursor)
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_embedded_workflows(cursor)
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
@ -97,40 +96,6 @@ class Migration2Callback:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract in 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in

View File

@ -0,0 +1,79 @@
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration3Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_model_config_records(cursor)
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract in 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
self._logger.info("Migrating model config records from models.yaml to database")
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
model_record_migrator.migrate()
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""
Build the migration from database version 2 to 3.
This migration does the following:
- Drops the `model_config` table, recreating it
- Migrates data from `models.yaml` into the `model_config` table
"""
migration_3 = Migration(
from_version=2,
to_version=3,
callback=Migration3Callback(app_config=app_config, logger=logger),
)
return migration_3

View File

@ -0,0 +1,94 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration4Callback:
"""Callback to do step 4 of migration."""
def __call__(self, cursor: sqlite3.Cursor) -> None: # noqa D102
self._create_model_metadata(cursor)
self._create_model_tags(cursor)
self._create_tags(cursor)
self._create_triggers(cursor)
def _create_model_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Create the table used to store model metadata downloaded from remote sources."""
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_metadata (
id TEXT NOT NULL,
name TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.name')) VIRTUAL NOT NULL,
author TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.author')) VIRTUAL NOT NULL,
-- Serialized JSON representation of the whole metadata object,
-- which will contain additional fields from subclasses
metadata TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
FOREIGN KEY(id) REFERENCES model_config(id)
);
"""
)
def _create_model_tags(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_tags (
id TEXT NOT NULL,
tag_id INTEGER NOT NULL,
FOREIGN KEY(id) REFERENCES model_config(id),
FOREIGN KEY(tag_id) REFERENCES tags(tag_id),
UNIQUE(id,tag_id)
);
"""
)
def _create_tags(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tags (
tag_id INTEGER NOT NULL PRIMARY KEY,
tag_text TEXT NOT NULL UNIQUE
);
"""
)
def _create_triggers(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_metadata_updated_at
AFTER UPDATE
ON model_metadata FOR EACH ROW
BEGIN
UPDATE model_metadata SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_deleted
AFTER DELETE
ON model_config
BEGIN
DELETE from model_metadata WHERE id=old.id;
DELETE from model_tags WHERE id=old.id;
END;
"""
)
def build_migration_4() -> Migration:
"""
Build the migration from database version 3 to 4.
Adds the tables needed to store model metadata and tags.
"""
migration_4 = Migration(
from_version=3,
to_version=4,
callback=Migration4Callback(),
)
return migration_4

View File

@ -1,8 +1,12 @@
# Copyright (c) 2023 Lincoln D. Stein
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
import json
import sqlite3
from hashlib import sha1
from logging import Logger
from pathlib import Path
from typing import Optional
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
@ -10,24 +14,22 @@ from pydantic import TypeAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigFactory,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
class MigrateModelYamlToDb:
class MigrateModelYamlToDb1:
"""
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.5.0).
The class has one externally useful method, migrate(), which scans the
currently models.yaml file and imports all its entries into invokeai.db.
@ -41,17 +43,12 @@ class MigrateModelYamlToDb:
config: InvokeAIAppConfig
logger: Logger
cursor: sqlite3.Cursor
def __init__(self) -> None:
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
def get_db(self) -> ModelRecordServiceSQL:
"""Fetch the sqlite3 database for this installation."""
db_path = None if self.config.use_memory_db else self.config.db_path
db = SqliteDatabase(db_path=db_path, logger=self.logger, verbose=self.config.log_sql)
return ModelRecordServiceSQL(db)
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
self.config = config
self.logger = logger
self.cursor = cursor
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
@ -62,8 +59,10 @@ class MigrateModelYamlToDb:
def migrate(self) -> None:
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db()
yaml = self.get_yaml()
try:
yaml = self.get_yaml()
except OSError:
return
for model_key, stanza in yaml.items():
if model_key == "__metadata__":
@ -86,22 +85,62 @@ class MigrateModelYamlToDb:
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
try:
if original_record := db.search_by_path(stanza.path):
key = original_record[0].key
if original_record := self._search_by_path(stanza.path):
key = original_record.key
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
db.update_model(key, new_config)
self._update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
db.add_model(new_key, new_config)
self._add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
except UnknownModelException:
self.logger.warning(f"Model at {stanza.path} could not be found in database")
def _search_by_path(self, path: Path) -> Optional[AnyModelConfig]:
self.cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self.cursor.fetchall()]
return results[0] if results else None
def main():
MigrateModelYamlToDb().migrate()
def _update_model(self, key: str, config: AnyModelConfig) -> None:
record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config object
json_serialized = record.model_dump_json() # and turn it into a json string.
self.cursor.execute(
"""--sql
UPDATE model_config
SET
config=?
WHERE id=?;
""",
(json_serialized, key),
)
if self.cursor.rowcount == 0:
raise UnknownModelException("model not found")
if __name__ == "__main__":
main()
def _add_model(self, key: str, config: AnyModelConfig) -> None:
record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config object.
json_serialized = record.model_dump_json() # and turn it into a json string.
try:
self.cursor.execute(
"""--sql
INSERT INTO model_config (
id,
original_hash,
config
)
VALUES (?,?,?);
""",
(
key,
record.original_hash,
json_serialized,
),
)
except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc

View File

@ -0,0 +1,975 @@
{
"name": "Prompt from File",
"author": "InvokeAI",
"description": "Sample workflow using Prompt from File node",
"version": "0.1.0",
"contact": "invoke@invoke.ai",
"tags": "text2image, prompt from file, default",
"notes": "",
"exposedFields": [
{
"nodeId": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"fieldName": "model"
},
{
"nodeId": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
"fieldName": "file_path"
}
],
"meta": {
"category": "default",
"version": "2.0.0"
},
"id": "d1609af5-eb0a-4f73-b573-c9af96a8d6bf",
"nodes": [
{
"id": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"type": "invocation",
"data": {
"id": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"type": "compel",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "dcdf3f6d-9b96-4bcd-9b8d-f992fefe4f62",
"name": "prompt",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"clip": {
"id": "3f1981c9-d8a9-42eb-a739-4f120eb80745",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "46205e6c-c5e2-44cb-9c82-1cd20b95674a",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -200
}
},
{
"id": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
"type": "invocation",
"data": {
"id": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
"type": "prompt_from_file",
"label": "Prompts from File",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.1",
"nodePack": "invokeai",
"inputs": {
"file_path": {
"id": "37e37684-4f30-4ec8-beae-b333e550f904",
"name": "file_path",
"fieldKind": "input",
"label": "Prompts File Path",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"pre_prompt": {
"id": "7de02feb-819a-4992-bad3-72a30920ddea",
"name": "pre_prompt",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"post_prompt": {
"id": "95f191d8-a282-428e-bd65-de8cb9b7513a",
"name": "post_prompt",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"start_line": {
"id": "efee9a48-05ab-4829-8429-becfa64a0782",
"name": "start_line",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 1
},
"max_prompts": {
"id": "abebb428-3d3d-49fd-a482-4e96a16fff08",
"name": "max_prompts",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 1
}
},
"outputs": {
"collection": {
"id": "77d5d7f1-9877-4ab1-9a8c-33e9ffa9abf3",
"name": "collection",
"fieldKind": "output",
"type": {
"isCollection": true,
"isCollectionOrScalar": false,
"name": "StringField"
}
}
}
},
"width": 320,
"height": 580,
"position": {
"x": 475,
"y": -400
}
},
{
"id": "1b89067c-3f6b-42c8-991f-e3055789b251",
"type": "invocation",
"data": {
"id": "1b89067c-3f6b-42c8-991f-e3055789b251",
"type": "iterate",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.1.0",
"inputs": {
"collection": {
"id": "4c564bf8-5ed6-441e-ad2c-dda265d5785f",
"name": "collection",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": true,
"isCollectionOrScalar": false,
"name": "CollectionField"
}
}
},
"outputs": {
"item": {
"id": "36340f9a-e7a5-4afa-b4b5-313f4e292380",
"name": "item",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "CollectionItemField"
}
},
"index": {
"id": "1beca95a-2159-460f-97ff-c8bab7d89336",
"name": "index",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"total": {
"id": "ead597b8-108e-4eda-88a8-5c29fa2f8df9",
"name": "total",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -400
}
},
{
"id": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"type": "invocation",
"data": {
"id": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"type": "main_model_loader",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"model": {
"id": "3f264259-3418-47d5-b90d-b6600e36ae46",
"name": "model",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MainModelField"
},
"value": {
"model_name": "stable-diffusion-v1-5",
"base_model": "sd-1",
"model_type": "main"
}
}
},
"outputs": {
"unet": {
"id": "8e182ea2-9d0a-4c02-9407-27819288d4b5",
"name": "unet",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"clip": {
"id": "d67d9d30-058c-46d5-bded-3d09d6d1aa39",
"name": "clip",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
},
"vae": {
"id": "89641601-0429-4448-98d5-190822d920d8",
"name": "vae",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
}
}
},
"width": 320,
"height": 227,
"position": {
"x": 0,
"y": -375
}
},
{
"id": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "invocation",
"data": {
"id": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "compel",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "dcdf3f6d-9b96-4bcd-9b8d-f992fefe4f62",
"name": "prompt",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"clip": {
"id": "3f1981c9-d8a9-42eb-a739-4f120eb80745",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "46205e6c-c5e2-44cb-9c82-1cd20b95674a",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -275
}
},
{
"id": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "invocation",
"data": {
"id": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "noise",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.1",
"nodePack": "invokeai",
"inputs": {
"seed": {
"id": "b722d84a-eeee-484f-bef2-0250c027cb67",
"name": "seed",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"width": {
"id": "d5f8ce11-0502-4bfc-9a30-5757dddf1f94",
"name": "width",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"height": {
"id": "f187d5ff-38a5-4c3f-b780-fc5801ef34af",
"name": "height",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"use_cpu": {
"id": "12f112b8-8b76-4816-b79e-662edc9f9aa5",
"name": "use_cpu",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
}
},
"outputs": {
"noise": {
"id": "08576ad1-96d9-42d2-96ef-6f5c1961933f",
"name": "noise",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "f3e1f94a-258d-41ff-9789-bd999bd9f40d",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "6cefc357-4339-415e-a951-49b9c2be32f4",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": 25
}
},
{
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"type": "invocation",
"data": {
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"type": "rand_int",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": false,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"low": {
"id": "b9fc6cf1-469c-4037-9bf0-04836965826f",
"name": "low",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"high": {
"id": "06eac725-0f60-4ba2-b8cd-7ad9f757488c",
"name": "high",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 2147483647
}
},
"outputs": {
"value": {
"id": "df08c84e-7346-4e92-9042-9e5cb773aaff",
"name": "value",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -50
}
},
{
"id": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "invocation",
"data": {
"id": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "l2i",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.2.0",
"nodePack": "invokeai",
"inputs": {
"metadata": {
"id": "022e4b33-562b-438d-b7df-41c3fd931f40",
"name": "metadata",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MetadataField"
}
},
"latents": {
"id": "67cb6c77-a394-4a66-a6a9-a0a7dcca69ec",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"vae": {
"id": "7b3fd9ad-a4ef-4e04-89fa-3832a9902dbd",
"name": "vae",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
},
"tiled": {
"id": "5ac5680d-3add-4115-8ec0-9ef5bb87493b",
"name": "tiled",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
},
"fp32": {
"id": "db8297f5-55f8-452f-98cf-6572c2582152",
"name": "fp32",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
}
},
"outputs": {
"image": {
"id": "d8778d0c-592a-4960-9280-4e77e00a7f33",
"name": "image",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ImageField"
}
},
"width": {
"id": "c8b0a75a-f5de-4ff2-9227-f25bb2b97bec",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "83c05fbf-76b9-49ab-93c4-fa4b10e793e4",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 267,
"position": {
"x": 2037.861329274915,
"y": -329.8393457509562
}
},
{
"id": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "invocation",
"data": {
"id": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "denoise_latents",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
"id": "751fb35b-3f23-45ce-af1c-053e74251337",
"name": "positive_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"negative_conditioning": {
"id": "b9dc06b6-7481-4db1-a8c2-39d22a5eacff",
"name": "negative_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"noise": {
"id": "6e15e439-3390-48a4-8031-01e0e19f0e1d",
"name": "noise",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"steps": {
"id": "bfdfb3df-760b-4d51-b17b-0abb38b976c2",
"name": "steps",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 10
},
"cfg_scale": {
"id": "47770858-322e-41af-8494-d8b63ed735f3",
"name": "cfg_scale",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "FloatField"
},
"value": 7.5
},
"denoising_start": {
"id": "2ba78720-ee02-4130-a348-7bc3531f790b",
"name": "denoising_start",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"denoising_end": {
"id": "a874dffb-d433-4d1a-9f59-af4367bb05e4",
"name": "denoising_end",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 1
},
"scheduler": {
"id": "36e021ad-b762-4fe4-ad4d-17f0291c40b2",
"name": "scheduler",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "SchedulerField"
},
"value": "euler"
},
"unet": {
"id": "98d3282d-f9f6-4b5e-b9e8-58658f1cac78",
"name": "unet",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"control": {
"id": "f2ea3216-43d5-42b4-887f-36e8f7166d53",
"name": "control",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "ControlField"
}
},
"ip_adapter": {
"id": "d0780610-a298-47c8-a54e-70e769e0dfe2",
"name": "ip_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "IPAdapterField"
}
},
"t2i_adapter": {
"id": "fdb40970-185e-4ea8-8bb5-88f06f91f46a",
"name": "t2i_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "T2IAdapterField"
}
},
"cfg_rescale_multiplier": {
"id": "3af2d8c5-de83-425c-a100-49cb0f1f4385",
"name": "cfg_rescale_multiplier",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"latents": {
"id": "e05b538a-1b5a-4aa5-84b1-fd2361289a81",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"denoise_mask": {
"id": "463a419e-df30-4382-8ffb-b25b25abe425",
"name": "denoise_mask",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "DenoiseMaskField"
}
}
},
"outputs": {
"latents": {
"id": "559ee688-66cf-4139-8b82-3d3aa69995ce",
"name": "latents",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "0b4285c2-e8b9-48e5-98f6-0a49d3f98fd2",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "8b0881b9-45e5-47d5-b526-24b6661de0ee",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 705,
"position": {
"x": 1570.9941088179146,
"y": -407.6505491604564
}
}
],
"edges": [
{
"id": "1b89067c-3f6b-42c8-991f-e3055789b251-fc9d0e35-a6de-4a19-84e1-c72497c823f6-collapsed",
"source": "1b89067c-3f6b-42c8-991f-e3055789b251",
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "collapsed"
},
{
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77-collapsed",
"source": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"target": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "collapsed"
},
{
"id": "reactflow__edge-1b7e0df8-8589-4915-a4ea-c0088f15d642collection-1b89067c-3f6b-42c8-991f-e3055789b251collection",
"source": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
"target": "1b89067c-3f6b-42c8-991f-e3055789b251",
"type": "default",
"sourceHandle": "collection",
"targetHandle": "collection"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426clip-fc9d0e35-a6de-4a19-84e1-c72497c823f6clip",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-1b89067c-3f6b-42c8-991f-e3055789b251item-fc9d0e35-a6de-4a19-84e1-c72497c823f6prompt",
"source": "1b89067c-3f6b-42c8-991f-e3055789b251",
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "default",
"sourceHandle": "item",
"targetHandle": "prompt"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426clip-c2eaf1ba-5708-4679-9e15-945b8b432692clip",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5value-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77seed",
"source": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"target": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-fc9d0e35-a6de-4a19-84e1-c72497c823f6conditioning-2fb1577f-0a56-4f12-8711-8afcaaaf1d5epositive_conditioning",
"source": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-c2eaf1ba-5708-4679-9e15-945b8b432692conditioning-2fb1577f-0a56-4f12-8711-8afcaaaf1d5enegative_conditioning",
"source": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77noise-2fb1577f-0a56-4f12-8711-8afcaaaf1d5enoise",
"source": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426unet-2fb1577f-0a56-4f12-8711-8afcaaaf1d5eunet",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-2fb1577f-0a56-4f12-8711-8afcaaaf1d5elatents-491ec988-3c77-4c37-af8a-39a0c4e7a2a1latents",
"source": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"target": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426vae-491ec988-3c77-4c37-af8a-39a0c4e7a2a1vae",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
}
]
}


@@ -0,0 +1,903 @@
{
"name": "Text to Image with LoRA",
"author": "InvokeAI",
"description": "Simple text to image workflow with a LoRA",
"version": "1.0.0",
"contact": "invoke@invoke.ai",
"tags": "text to image, lora, default",
"notes": "",
"exposedFields": [
{
"nodeId": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"fieldName": "model"
},
{
"nodeId": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"fieldName": "lora"
},
{
"nodeId": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"fieldName": "weight"
},
{
"nodeId": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"fieldName": "prompt"
}
],
"meta": {
"version": "2.0.0",
"category": "default"
},
"id": "a9d70c39-4cdd-4176-9942-8ff3fe32d3b1",
"nodes": [
{
"id": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"type": "invocation",
"data": {
"id": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"type": "compel",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"inputs": {
"prompt": {
"id": "39fe92c4-38eb-4cc7-bf5e-cbcd31847b11",
"name": "prompt",
"fieldKind": "input",
"label": "Negative Prompt",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"clip": {
"id": "14313164-e5c4-4e40-a599-41b614fe3690",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "02140b9d-50f3-470b-a0b7-01fc6ed2dcd6",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 3425,
"y": -300
}
},
{
"id": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"type": "invocation",
"data": {
"id": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"type": "main_model_loader",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"inputs": {
"model": {
"id": "e2e1c177-ae39-4244-920e-d621fa156a24",
"name": "model",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MainModelField"
},
"value": {
"model_name": "Analog-Diffusion",
"base_model": "sd-1",
"model_type": "main"
}
}
},
"outputs": {
"vae": {
"id": "f91410e8-9378-4298-b285-f0f40ffd9825",
"name": "vae",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
},
"clip": {
"id": "928d91bf-de0c-44a8-b0c8-4de0e2e5b438",
"name": "clip",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
},
"unet": {
"id": "eacaf530-4e7e-472e-b904-462192189fc1",
"name": "unet",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
}
}
},
"width": 320,
"height": 227,
"position": {
"x": 2500,
"y": -600
}
},
{
"id": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "invocation",
"data": {
"id": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "lora_loader",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"inputs": {
"lora": {
"id": "36d867e8-92ea-4c3f-9ad5-ba05c64cf326",
"name": "lora",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LoRAModelField"
},
"value": {
"model_name": "Ink scenery",
"base_model": "sd-1"
}
},
"weight": {
"id": "8be86540-ba81-49b3-b394-2b18fa70b867",
"name": "weight",
"fieldKind": "input",
"label": "LoRA Weight",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0.75
},
"unet": {
"id": "9c4d5668-e9e1-411b-8f4b-e71115bc4a01",
"name": "unet",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"clip": {
"id": "918ec00e-e76f-4ad0-aee1-3927298cf03b",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"unet": {
"id": "c63f7825-1bcf-451d-b7a7-aa79f5c77416",
"name": "unet",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"clip": {
"id": "6f79ef2d-00f7-4917-bee3-53e845bf4192",
"name": "clip",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
}
},
"width": 320,
"height": 252,
"position": {
"x": 2975,
"y": -600
}
},
{
"id": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"type": "invocation",
"data": {
"id": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"type": "compel",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"inputs": {
"prompt": {
"id": "39fe92c4-38eb-4cc7-bf5e-cbcd31847b11",
"name": "prompt",
"fieldKind": "input",
"label": "Positive Prompt",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": "cute tiger cub"
},
"clip": {
"id": "14313164-e5c4-4e40-a599-41b614fe3690",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "02140b9d-50f3-470b-a0b7-01fc6ed2dcd6",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 3425,
"y": -575
}
},
{
"id": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "invocation",
"data": {
"id": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "denoise_latents",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.0",
"inputs": {
"positive_conditioning": {
"id": "025ff44b-c4c6-4339-91b4-5f461e2cadc5",
"name": "positive_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"negative_conditioning": {
"id": "2d92b45a-a7fb-4541-9a47-7c7495f50f54",
"name": "negative_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"noise": {
"id": "4d0deeff-24ed-4562-a1ca-7833c0649377",
"name": "noise",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"steps": {
"id": "c9907328-aece-4af9-8a95-211b4f99a325",
"name": "steps",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 10
},
"cfg_scale": {
"id": "7cf0f031-2078-49f4-9273-bb3a64ad7130",
"name": "cfg_scale",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "FloatField"
},
"value": 7.5
},
"denoising_start": {
"id": "44cec3ba-b404-4b51-ba98-add9d783279e",
"name": "denoising_start",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"denoising_end": {
"id": "3e7975f3-e438-4a13-8a14-395eba1fb7cd",
"name": "denoising_end",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 1
},
"scheduler": {
"id": "a6f6509b-7bb4-477d-b5fb-74baefa38111",
"name": "scheduler",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "SchedulerField"
},
"value": "euler"
},
"unet": {
"id": "5a87617a-b09f-417b-9b75-0cea4c255227",
"name": "unet",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"control": {
"id": "db87aace-ace8-4f2a-8f2b-1f752389fa9b",
"name": "control",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "ControlField"
}
},
"ip_adapter": {
"id": "f0c133ed-4d6d-4567-bb9a-b1779810993c",
"name": "ip_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "IPAdapterField"
}
},
"t2i_adapter": {
"id": "59ee1233-887f-45e7-aa14-cbad5f6cb77f",
"name": "t2i_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "T2IAdapterField"
}
},
"cfg_rescale_multiplier": {
"id": "1a12e781-4b30-4707-b432-18c31866b5c3",
"name": "cfg_rescale_multiplier",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"latents": {
"id": "d0e593ae-305c-424b-9acd-3af830085832",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"denoise_mask": {
"id": "b81b5a79-fc2b-4011-aae6-64c92bae59a7",
"name": "denoise_mask",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "DenoiseMaskField"
}
}
},
"outputs": {
"latents": {
"id": "9ae4022a-548e-407e-90cf-cc5ca5ff8a21",
"name": "latents",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "730ba4bd-2c52-46bb-8c87-9b3aec155576",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "52b98f0b-b5ff-41b5-acc7-d0b1d1011a6f",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 705,
"position": {
"x": 3975,
"y": -575
}
},
{
"id": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "invocation",
"data": {
"id": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "noise",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.1",
"inputs": {
"seed": {
"id": "446ac80c-ba0a-4fea-a2d7-21128f52e5bf",
"name": "seed",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"width": {
"id": "779831b3-20b4-4f5f-9de7-d17de57288d8",
"name": "width",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"height": {
"id": "08959766-6d67-4276-b122-e54b911f2316",
"name": "height",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"use_cpu": {
"id": "53b36a98-00c4-4dc5-97a4-ef3432c0a805",
"name": "use_cpu",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
}
},
"outputs": {
"noise": {
"id": "eed95824-580b-442f-aa35-c073733cecce",
"name": "noise",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "7985a261-dfee-47a8-908a-c5a8754f5dc4",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "3d00f6c1-84b0-4262-83d9-3bf755babeea",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 3425,
"y": 75
}
},
{
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"type": "invocation",
"data": {
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"type": "rand_int",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": false,
"version": "1.0.0",
"inputs": {
"low": {
"id": "d25305f3-bfd6-446c-8e2c-0b025ec9e9ad",
"name": "low",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"high": {
"id": "10376a3d-b8fe-4a51-b81a-ea46d8c12c78",
"name": "high",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 2147483647
}
},
"outputs": {
"value": {
"id": "c64878fa-53b1-4202-b88a-cfb854216a57",
"name": "value",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 3425,
"y": 0
}
},
{
"id": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "invocation",
"data": {
"id": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "l2i",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": false,
"useCache": true,
"version": "1.2.0",
"inputs": {
"metadata": {
"id": "b1982e8a-14ad-4029-a697-beb30af8340f",
"name": "metadata",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MetadataField"
}
},
"latents": {
"id": "f7669388-9f91-46cc-94fc-301fa7041c3e",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"vae": {
"id": "c6f2d4db-4d0a-4e3d-acb4-b5c5a228a3e2",
"name": "vae",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
},
"tiled": {
"id": "19ef7d31-d96f-4e94-b7e5-95914e9076fc",
"name": "tiled",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
},
"fp32": {
"id": "a9454533-8ab7-4225-b411-646dc5e76d00",
"name": "fp32",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
}
},
"outputs": {
"image": {
"id": "4f81274e-e216-47f3-9fb6-f97493a40e6f",
"name": "image",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ImageField"
}
},
"width": {
"id": "61a9acfb-1547-4f1e-8214-e89bd3855ee5",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "b15cc793-4172-4b07-bcf4-5627bbc7d0d7",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 267,
"position": {
"x": 4450,
"y": -550
}
}
],
"edges": [
{
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953-ea18915f-2c5b-4569-b725-8e9e9122e8d3-collapsed",
"source": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"target": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "collapsed"
},
{
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818clip-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip",
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"target": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip-c3fa6872-2599-4a82-a596-b3446a66cf8bclip",
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"target": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818unet-c41e705b-f2e3-4d1a-83c4-e34bb9344966unet",
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"target": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966unet-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63unet",
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-85b77bb2-c67a-416a-b3e8-291abe746c44conditioning-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63negative_conditioning",
"source": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-c3fa6872-2599-4a82-a596-b3446a66cf8bconditioning-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63positive_conditioning",
"source": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-ea18915f-2c5b-4569-b725-8e9e9122e8d3noise-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63noise",
"source": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-6fd74a17-6065-47a5-b48b-f4e2b8fa7953value-ea18915f-2c5b-4569-b725-8e9e9122e8d3seed",
"source": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"target": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63latents-a9683c0a-6b1f-4a5e-8187-c57e764b3400latents",
"source": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"target": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818vae-a9683c0a-6b1f-4a5e-8187-c57e764b3400vae",
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"target": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip-85b77bb2-c67a-416a-b3e8-291abe746c44clip",
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"target": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
}
]
}


@@ -0,0 +1,8 @@
import re


def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
    """Return every textual-inversion trigger token (e.g. "<my-embedding>") found in the prompt."""
    ti_triggers = []
    for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
        ti_triggers.append(trigger)
    return ti_triggers
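
For instance (the prompt text below is hypothetical):

    >>> extract_ti_triggers_from_prompt("a portrait of a knight <easynegative>, detailed <charturner>")
    ['<easynegative>', '<charturner>']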


@@ -28,7 +28,7 @@ def check_invokeai_root(config: InvokeAIAppConfig):
         print("== STARTUP ABORTED ==")
         print("** One or more necessary files is missing from your InvokeAI root directory **")
         print("** Please rerun the configuration script to fix this problem. **")
-        print("** From the launcher, selection option [7]. **")
+        print("** From the launcher, selection option [6]. **")
         print(
             '** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
         )


@@ -389,7 +389,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
         elif "clip_g" in checkpoint:
             token_dim = checkpoint["clip_g"].shape[-1]
         else:
-            token_dim = list(checkpoint.values())[0].shape[0]
+            token_dim = list(checkpoint.values())[0].shape[-1]
         if token_dim == 768:
             return BaseModelType.StableDiffusion1
         elif token_dim == 1024:


@@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
     :param checkpoint: The checkpoint
     """

-    def _get_shape_1(key, tensor, checkpoint):
+    def _get_shape_1(key: str, tensor, checkpoint) -> int:
         lora_token_vector_length = None

         if "." not in key:
@@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int:
     for key, tensor in checkpoint.items():
         if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
             lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
+        elif key.startswith("lora_unet_") and (
+            "time_emb_proj.lora_down" in key
+        ):  # recognizes format at https://civitai.com/models/224641
+            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
         elif key.startswith("lora_te") and "_self_attn_" in key:
             tmp_length = _get_shape_1(key, tensor, checkpoint)
             if key.startswith("lora_te_"):


@@ -0,0 +1,43 @@
"""
Initialization file for invokeai.backend.model_manager.metadata
Usage:
from invokeai.backend.model_manager.metadata import(
AnyModelRepoMetadata,
CommercialUsage,
LicenseRestrictions,
HuggingFaceMetadata,
CivitaiMetadata,
)
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
from .metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
CivitaiMetadata,
CommercialUsage,
HuggingFaceMetadata,
LicenseRestrictions,
)
from .metadata_store import ModelMetadataStore
__all__ = [
"AnyModelRepoMetadata",
"AnyModelRepoMetadataValidator",
"CommercialUsage",
"LicenseRestrictions",
"HuggingFaceMetadata",
"CivitaiMetadata",
"ModelMetadataStore",
"CivitaiMetadataFetch",
"HuggingFaceMetadataFetch",
]


@@ -0,0 +1,21 @@
"""
Initialization file for invokeai.backend.model_manager.metadata.fetch
Usage:
from invokeai.backend.model_manager.metadata.fetch import (
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
)
from invokeai.backend.model_manager.metadata import CivitaiMetadata
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch_base import ModelMetadataFetchBase
from .fetch_civitai import CivitaiMetadataFetch
from .fetch_huggingface import HuggingFaceMetadataFetch
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]


@@ -0,0 +1,63 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module is the base class for subclasses that fetch metadata from model repositories
Usage:
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
from abc import ABC, abstractmethod
from typing import Optional
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
class ModelMetadataFetchBase(ABC):
"""Fetch metadata from remote generative model repositories."""
@abstractmethod
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
pass
@abstractmethod
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Given a URL to a model repository, return a ModelMetadata object.
This method will raise an `invokeai.app.services.model_records.UnknownModelException`
in the event that the requested model metadata is not found at the provided location.
"""
pass
@abstractmethod
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""
Given an ID for a model, return a ModelMetadata object.
This method will raise an `invokeai.app.services.model_records.UnknownModelException`
in the event that the requested model's metadata is not found at the provided id.
"""
pass
@classmethod
    def from_json(cls, json: str) -> AnyModelRepoMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
return (
metadata # mypy complains that metadata is a <typing special form> and issues a type checking error. Why?
)
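
To make the contract concrete, here is a minimal sketch of a fetcher built on this base class. The `CannedMetadataFetch` class and all of its field values are invented for illustration; it returns a fixed record instead of querying a real repository, which is exactly the kind of substitution the injectable Session is meant to enable for tests.

    from datetime import datetime, timezone
    from typing import Optional

    import requests
    from pydantic.networks import AnyHttpUrl
    from requests.sessions import Session

    from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, HuggingFaceMetadata
    from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase


    class CannedMetadataFetch(ModelMetadataFetchBase):
        """Toy fetcher that returns a fixed record; it only demonstrates the ABC's shape."""

        def __init__(self, session: Optional[Session] = None):
            # the injectable Session is what lets the real fetchers run offline in unit tests
            self._requests = session or requests.Session()

        def from_id(self, id: str) -> AnyModelRepoMetadata:
            return HuggingFaceMetadata(
                id=id,
                name=id.split("/")[-1],
                author="example-author",
                tags={"example"},
                tag_dict={},
                last_modified=datetime(2023, 12, 1, tzinfo=timezone.utc),
            )

        def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
            # a real fetcher parses the URL; the canned one ignores it
            return self.from_id("example/canned-model")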


@@ -0,0 +1,157 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches model metadata objects from the Civitai model repository.
In addition to the `from_url()` and `from_id()` methods inherited from the
`ModelMetadataFetchBase` base class, it provides `from_civitai_modelid()` and
`from_civitai_versionid()` methods.
Civitai has two separate ID spaces: a model ID and a version ID. The
version ID corresponds to a specific model, and is the ID accepted by
`from_id()`. The model ID corresponds to a family of related models,
such as different training checkpoints or 16 vs 32-bit versions. The
`from_civitai_modelid()` method will accept a model ID and return the
metadata from the default version within this model set. The default
version is the same as what the user sees when they click on a model's
thumbnail.
Usage:
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import re
from datetime import datetime
from typing import Any, Dict, Optional
import requests
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.app.services.model_records import UnknownModelException
from ..metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
CivitaiMetadata,
CommercialUsage,
LicenseRestrictions,
)
from .fetch_base import ModelMetadataFetchBase
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
_requests: Session
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Given a URL to a CivitAI model or version page, return a ModelMetadata object.
In the event that the URL points to a model page without the particular version
indicated, the default model version is returned. Otherwise, the requested version
is returned.
"""
        if match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
            # test the more specific version-page pattern first; the model-page
            # pattern is a prefix of it and would otherwise shadow the version ID
            version_id = match.group(2)
            return self.from_civitai_versionid(int(version_id))
        elif match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)):
            model_id = match.group(1)
            return self.from_civitai_modelid(int(model_id))
        elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
            version_id = match.group(1)
            return self.from_civitai_versionid(int(version_id))
        raise UnknownModelException(f"The url '{url}' does not match any known Civitai URL patterns")
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""
Given a Civitai model version ID, return a ModelRepoMetadata object.
May raise an `UnknownModelException`.
"""
return self.from_civitai_versionid(int(id))
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
"""
Return metadata from the default version of the indicated model.
May raise an `UnknownModelException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json)
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
version_id = version_id or model_json["modelVersions"][0]["id"]
# loop till we find the section containing the version requested
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
if not version_sections:
raise UnknownModelException(f"Version {version_id} not found in model metadata")
version_json = version_sections[0]
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
return CivitaiMetadata(
id=model_json["id"],
name=model_json["name"],
version_id=version_json["id"],
version_name=version_json["name"],
created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["createdAt"])),
updated=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["updatedAt"])),
published=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["publishedAt"])),
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
download_url=version_json["downloadUrl"],
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model_json["creator"]["username"],
description=model_json["description"],
version_description=version_json["description"] or "",
tags=model_json["tags"],
trained_words=version_json["trainedWords"],
nsfw=model_json["nsfw"],
restrictions=LicenseRestrictions(
AllowNoCredit=model_json["allowNoCredit"],
AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
)
def from_civitai_versionid(self, version_id: int) -> CivitaiMetadata:
"""
Return a CivitaiMetadata object given a model version id.
May raise an `UnknownModelException`.
"""
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
model_url = CIVITAI_MODEL_ENDPOINT + str(version["modelId"])
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json, version_id)
@classmethod
def from_json(cls, json: str) -> CivitaiMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
assert isinstance(metadata, CivitaiMetadata)
return metadata
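
A short sketch of how the recognized Civitai URL shapes map onto these methods (the version ID 232584 and the query-string URL are hypothetical; the model-page URL is the one from the docstring above):

    from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch

    fetcher = CivitaiMetadataFetch()

    # model page -> metadata for the default version of the model family
    meta = fetcher.from_url("https://civitai.com/models/206883/split")

    # version page or direct download link -> that specific version
    meta = fetcher.from_url("https://civitai.com/models/206883?modelVersionId=232584")
    meta = fetcher.from_url("https://civitai.com/api/download/models/232584")

    # a bare ID passed to from_id() is always treated as a *version* ID
    meta = fetcher.from_id("232584")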


@@ -0,0 +1,84 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches model metadata objects from the HuggingFace model repository,
using either a `repo_id` or the model page URL.
Usage:
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
fetcher = HuggingFaceMetadataFetch()
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""
import re
from pathlib import Path
from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend
from huggingface_hub.utils._errors import RepositoryNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.app.services.model_records import UnknownModelException
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata
from .fetch_base import ModelMetadataFetchBase
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from HuggingFace."""
_requests: Session
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
configure_http_backend(backend_factory=lambda: self._requests)
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""Return a HuggingFaceMetadata object given the model's repo_id."""
try:
model_info = HfApi().model_info(repo_id=id, files_metadata=True)
except RepositoryNotFoundError as excp:
raise UnknownModelException(f"'{id}' not found. See trace for details.") from excp
_, name = id.split("/")
return HuggingFaceMetadata(
id=model_info.modelId,
author=model_info.author,
name=name,
last_modified=model_info.lastModified,
tag_dict=model_info.card_data.to_dict(),
tags=model_info.tags,
files=[Path(x.rfilename) for x in model_info.siblings],
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Return a HuggingFaceMetadata object given the model's web page URL.
In the case of an invalid or missing URL, raises an UnknownModelException.
"""
if match := re.match(HF_MODEL_RE, str(url)):
repo_id = match.group(1)
return self.from_id(repo_id)
else:
raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")
@classmethod
def from_json(cls, json: str) -> HuggingFaceMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
assert isinstance(metadata, HuggingFaceMetadata)
return metadata
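
Both entry points resolve to the same record, since `from_url()` simply extracts the repo_id with HF_MODEL_RE and delegates to `from_id()`. A quick sketch (requires network access; the repo is the one from the docstring above):

    from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch

    fetcher = HuggingFaceMetadataFetch()
    by_url = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
    by_id = fetcher.from_id("stabilityai/sdxl-turbo")
    assert by_url.id == by_id.id == "stabilityai/sdxl-turbo"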


@@ -0,0 +1,119 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""This module defines core text-to-image model metadata fields.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in invokeai.backend.model_manager.config.
Note that the "name" and "description" are also present in `config`
records. This is intentional. The config record fields are intended to
be editable by the user as a form of customization. The metadata
versions of these fields are intended to be kept in sync with the
remote repo.
"""
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
class CommercialUsage(str, Enum):
"""Type of commercial usage allowed."""
No = "None"
Image = "Image"
Rent = "Rent"
RentCivit = "RentCivit"
Sell = "Sell"
class LicenseRestrictions(BaseModel):
"""Broad categories of licensing restrictions."""
AllowNoCredit: bool = Field(
description="if true, model can be redistributed without crediting author", default=False
)
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False
)
    AllowCommercialUse: CommercialUsage = Field(
        description="Type of commercial use allowed or 'No' if no commercial use is allowed.",
        default=CommercialUsage.No,
    )
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
class HuggingFaceMetadata(ModelMetadataBase):
"""Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="huggingface model id")
tag_dict: Dict[str, Any]
last_modified: datetime = Field(description="date of last commit to repo")
files: List[Path] = Field(description="sibling files that belong to this model", default_factory=list)
class CivitaiMetadata(ModelMetadataBase):
"""Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai"
id: int = Field(description="Civitai model identifier")
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was created")
updated: datetime = Field(description="date the model was last modified")
published: datetime = Field(description="date the model was published to Civitai")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
    # note: for future use; not currently easily recoverable from metadata
    weight_min: float = Field(description="minimum suggested value for a LoRA or other secondary model", default=-1.0)
    weight_max: float = Field(description="maximum suggested value for a LoRA or other secondary model", default=+2.0)
@property
def credit_required(self) -> bool:
"""Return True if you must give credit for derivatives of this model and images generated from it."""
return not self.restrictions.AllowNoCredit
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
@property
def allow_derivatives(self) -> bool:
"""Return True if derivatives of this model can be redistributed."""
return self.restrictions.AllowDerivatives
@property
def allow_different_license(self) -> bool:
"""Return true if derivatives of this model can use a different license."""
return self.restrictions.AllowDifferentLicense
AnyModelRepoMetadata = Annotated[Union[HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
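
The `type` literal is the discriminator that routes JSON back to the right subclass. A round-trip sketch (all field values below are illustrative):

    from datetime import datetime, timezone
    from pathlib import Path

    from invokeai.backend.model_manager.metadata import (
        AnyModelRepoMetadataValidator,
        HuggingFaceMetadata,
    )

    meta = HuggingFaceMetadata(
        id="stabilityai/sdxl-turbo",
        name="sdxl-turbo",
        author="stabilityai",
        tags={"text-to-image"},
        tag_dict={"license": "other"},
        last_modified=datetime(2023, 12, 1, tzinfo=timezone.utc),
        files=[Path("sd_xl_turbo_1.0.safetensors")],
    )

    # '"type": "huggingface"' in the dumped JSON selects the subclass on the way back in
    restored = AnyModelRepoMetadataValidator.validate_json(meta.model_dump_json())
    assert isinstance(restored, HuggingFaceMetadata)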


@@ -0,0 +1,208 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import Optional, Set
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata
class ModelMetadataStore:
"""Store, search and fetch model metadata retrieved from remote repositories."""
_db: SqliteDatabase
_cursor: sqlite3.Cursor
    def __init__(self, db: SqliteDatabase):
        """
        Initialize a new metadata store.

        :param db: SqliteDatabase object wrapping the shared sqlite3 connection and its threading lock
        """
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
self._enable_foreign_key_constraints()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownModelException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
        return matches or set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
def _enable_foreign_key_constraints(self) -> None:
self._cursor.execute("PRAGMA foreign_keys = ON;")


@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
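
In condensed form, the probe above infers the base model from the LoRA's token-vector length; a sketch of that mapping as a lookup table. The dict form and the 768 entry are assumptions of mine, the 768 branch falling outside this hunk.

# Hypothetical condensation of the branches above; both 1280 and 2048 indicate SDXL,
# with 1280 covering the variant recognized at https://civitai.com/models/224641.
TOKEN_VECTOR_LENGTH_TO_BASE = {
    768: BaseModelType.StableDiffusion1,  # assumed: its condition is outside this hunk
    1024: BaseModelType.StableDiffusion2,
    1280: BaseModelType.StableDiffusionXL,
    2048: BaseModelType.StableDiffusionXL,
}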


@ -102,7 +102,7 @@ def calc_tiles_with_overlap(
def calc_tiles_even_split(
image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap_fraction: float = 0
image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap: int = 0
) -> list[Tile]:
"""Calculate the tile coordinates for a given image shape with the number of tiles requested.
@ -111,31 +111,35 @@ def calc_tiles_even_split(
image_width (int): The image width in px.
num_tiles_x (int): The number of tiles to split the image into on the X-axis.
num_tiles_y (int): The number of tiles to split the image into on the Y-axis.
overlap_fraction (float, optional): The target overlap as fraction of the tiles size. Defaults to 0.
overlap (int, optional): The overlap between adjacent tiles in pixels. Defaults to 0.
Returns:
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
"""
# Ensure tile size is divisible by 8
# Ensure the image is divisible by LATENT_SCALE_FACTOR
if image_width % LATENT_SCALE_FACTOR != 0 or image_height % LATENT_SCALE_FACTOR != 0:
raise ValueError(f"image size (({image_width}, {image_height})) must be divisible by {LATENT_SCALE_FACTOR}")
# Calculate the overlap size based on the percentage and adjust it to be divisible by 8 (rounding up)
overlap_x = LATENT_SCALE_FACTOR * math.ceil(
int((image_width / num_tiles_x) * overlap_fraction) / LATENT_SCALE_FACTOR
)
overlap_y = LATENT_SCALE_FACTOR * math.ceil(
int((image_height / num_tiles_y) * overlap_fraction) / LATENT_SCALE_FACTOR
)
# Calculate the tile size based on the number of tiles and overlap, and ensure it's divisible by 8 (rounding down)
tile_size_x = LATENT_SCALE_FACTOR * math.floor(
((image_width + overlap_x * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR
)
tile_size_y = LATENT_SCALE_FACTOR * math.floor(
((image_height + overlap_y * (num_tiles_y - 1)) // num_tiles_y) / LATENT_SCALE_FACTOR
)
if num_tiles_x > 1:
# Ensure the overlap does not exceed the maximum possible; with only one tile, overlap is irrelevant.
assert overlap <= image_width - (LATENT_SCALE_FACTOR * (num_tiles_x - 1))
tile_size_x = LATENT_SCALE_FACTOR * math.floor(
((image_width + overlap * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR
)
assert overlap < tile_size_x
else:
tile_size_x = image_width
if num_tiles_y > 1:
# Ensure the overlap does not exceed the maximum possible; with only one tile, overlap is irrelevant.
assert overlap <= image_height - (LATENT_SCALE_FACTOR * (num_tiles_y - 1))
tile_size_y = LATENT_SCALE_FACTOR * math.floor(
((image_height + overlap * (num_tiles_y - 1)) // num_tiles_y) / LATENT_SCALE_FACTOR
)
assert overlap < tile_size_y
else:
tile_size_y = image_height
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
tiles: list[Tile] = []
@ -143,7 +147,7 @@ def calc_tiles_even_split(
# Calculate tile coordinates. (Ignore overlap values for now.)
for tile_idx_y in range(num_tiles_y):
# Calculate the top and bottom of the row
top = tile_idx_y * (tile_size_y - overlap_y)
top = tile_idx_y * (tile_size_y - overlap)
bottom = min(top + tile_size_y, image_height)
# For the last row adjust bottom to be the height of the image
if tile_idx_y == num_tiles_y - 1:
@ -151,7 +155,7 @@ def calc_tiles_even_split(
for tile_idx_x in range(num_tiles_x):
# Calculate the left & right coordinate of each tile
left = tile_idx_x * (tile_size_x - overlap_x)
left = tile_idx_x * (tile_size_x - overlap)
right = min(left + tile_size_x, image_width)
# For the last tile in the row adjust right to be the width of the image
if tile_idx_x == num_tiles_x - 1:
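
To summarize the signature change: overlap is now an absolute pixel count rather than a fraction of the tile size. A quick sketch of the new call shape, with numbers borrowed from the updated tests near the end of this changeset:

# Both dimensions must be divisible by LATENT_SCALE_FACTOR (8), and the pixel
# overlap must be smaller than the computed tile size or an assertion fires.
tiles = calc_tiles_even_split(
    image_height=576,
    image_width=1600,
    num_tiles_x=3,
    num_tiles_y=2,
    overlap=64,  # pixels shared by adjacent tiles
)
# tiles[y * num_tiles_x + x] addresses the tile in row y, column x.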


@ -1,11 +1,9 @@
from __future__ import annotations
import platform
from contextlib import nullcontext
from typing import Union
import torch
from packaging import version
from torch import autocast
from invokeai.app.services.config import InvokeAIAppConfig
@ -37,7 +35,7 @@ def choose_precision(device: torch.device) -> str:
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
return "float16"
elif device.type == "mps" and version.parse(platform.mac_ver()[0]) < version.parse("14.0.0"):
elif device.type == "mps":
return "float16"
return "float32"


@ -4,6 +4,7 @@ pip install <path_to_git_source>.
"""
import os
import platform
from distutils.version import LooseVersion
import pkg_resources
import psutil
@ -31,10 +32,6 @@ else:
console = Console(style=Style(color="grey74", bgcolor="grey19"))
def get_versions() -> dict:
return requests.get(url=INVOKE_AI_REL).json()
def invokeai_is_running() -> bool:
for p in psutil.process_iter():
try:
@ -50,6 +47,20 @@ def invokeai_is_running() -> bool:
return False
def get_pypi_versions():
url = "https://pypi.org/pypi/invokeai/json"
try:
data = requests.get(url).json()
except Exception as e:
    raise Exception("Unable to fetch version information from PyPI") from e
versions = list(data["releases"].keys())
versions.sort(key=LooseVersion, reverse=True)
latest_version = [v for v in versions if "rc" not in v][0]
latest_release_candidate = [v for v in versions if "rc" in v][0]
return latest_version, latest_release_candidate, versions
def welcome(latest_release: str, latest_prerelease: str):
@group()
def text():
@ -63,8 +74,7 @@ def welcome(latest_release: str, latest_prerelease: str):
yield "[bold yellow]Options:"
yield f"""[1] Update to the latest [bold]official release[/bold] ([italic]{latest_release}[/italic])
[2] Update to the latest [bold]pre-release[/bold] (may be buggy; caveat emptor!) ([italic]{latest_prerelease}[/italic])
[3] Manually enter the [bold]tag name[/bold] for the version you wish to update to
[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to"""
[3] Manually enter the [bold]version[/bold] you wish to update to"""
console.rule()
print(
@ -92,44 +102,35 @@ def get_extras():
def main():
versions = get_versions()
released_versions = [x for x in versions if not (x["draft"] or x["prerelease"])]
prerelease_versions = [x for x in versions if not x["draft"] and x["prerelease"]]
latest_release = released_versions[0]["tag_name"] if len(released_versions) else None
latest_prerelease = prerelease_versions[0]["tag_name"] if len(prerelease_versions) else None
if invokeai_is_running():
print(":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
input("Press any key to continue...")
return
latest_release, latest_prerelease, versions = get_pypi_versions()
welcome(latest_release, latest_prerelease)
tag = None
branch = None
release = None
choice = Prompt.ask("Choice:", choices=["1", "2", "3", "4"], default="1")
release = latest_release
choice = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
if choice == "1":
release = latest_release
elif choice == "2":
release = latest_prerelease
elif choice == "3":
while not tag:
tag = Prompt.ask("Enter an InvokeAI tag name")
elif choice == "4":
while not branch:
branch = Prompt.ask("Enter an InvokeAI branch name")
while True:
release = Prompt.ask("Enter an InvokeAI version")
release = release.strip()
if release in versions:
break
print(f":exclamation: [bold red]'{release}' is not a recognized InvokeAI release.[/red bold]")
extras = get_extras()
print(f":crossed_fingers: Upgrading to [yellow]{tag or release or branch}[/yellow]")
if release:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
elif tag:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade'
else:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade'
print(f":crossed_fingers: Upgrading to [yellow]{release}[/yellow]")
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade'
print("")
print("")
if os.system(cmd) == 0:
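
Putting the new updater flow together: versions now come from PyPI, and the install pins an exact version instead of pulling a GitHub tag or branch archive. A sketch of the resulting command construction (the extras value is illustrative):

# Hypothetical walk-through of the code above.
latest_release, latest_prerelease, versions = get_pypi_versions()
release = latest_release  # as if the user chose option [1]
extras = ""               # get_extras() may return e.g. "[xformers]"
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade'
# A zero exit status from os.system(cmd) signals a successful upgrade.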


@ -727,9 +727,6 @@
"showMinimapnodes": "Mostrar el minimapa",
"reloadNodeTemplates": "Recargar las plantillas de nodos",
"loadWorkflow": "Cargar el flujo de trabajo",
"resetWorkflow": "Reiniciar e flujo de trabajo",
"resetWorkflowDesc": "¿Está seguro de que deseas restablecer este flujo de trabajo?",
"resetWorkflowDesc2": "Al reiniciar el flujo de trabajo se borrarán todos los nodos, aristas y detalles del flujo de trabajo.",
"downloadWorkflow": "Descargar el flujo de trabajo en un archivo JSON"
}
}


@ -898,11 +898,8 @@
"zoomInNodes": "Ingrandire",
"fitViewportNodes": "Adatta vista",
"showGraphNodes": "Mostra sovrapposizione grafico",
"resetWorkflowDesc2": "Il ripristino dell'editor del flusso di lavoro cancellerà tutti i nodi, le connessioni e i dettagli del flusso di lavoro. I flussi di lavoro salvati non saranno interessati.",
"reloadNodeTemplates": "Ricarica i modelli di nodo",
"loadWorkflow": "Importa flusso di lavoro JSON",
"resetWorkflow": "Reimposta l'editor del flusso di lavoro",
"resetWorkflowDesc": "Sei sicuro di voler reimpostare l'editor del flusso di lavoro?",
"downloadWorkflow": "Esporta flusso di lavoro JSON",
"scheduler": "Campionatore",
"addNode": "Aggiungi nodo",
@ -1112,7 +1109,10 @@
"deletedInvalidEdge": "Eliminata connessione non valida {{source}} -> {{target}}",
"unknownInput": "Input sconosciuto: {{name}}",
"prototypeDesc": "Questa invocazione è un prototipo. Potrebbe subire modifiche sostanziali durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento.",
"betaDesc": "Questa invocazione è in versione beta. Fino a quando non sarà stabile, potrebbe subire modifiche importanti durante gli aggiornamenti dell'app. Abbiamo intenzione di supportare questa invocazione a lungo termine."
"betaDesc": "Questa invocazione è in versione beta. Fino a quando non sarà stabile, potrebbe subire modifiche importanti durante gli aggiornamenti dell'app. Abbiamo intenzione di supportare questa invocazione a lungo termine.",
"newWorkflow": "Nuovo flusso di lavoro",
"newWorkflowDesc": "Creare un nuovo flusso di lavoro?",
"newWorkflowDesc2": "Il flusso di lavoro attuale presenta modifiche non salvate."
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@ -1619,7 +1619,6 @@
"saveWorkflow": "Salva flusso di lavoro",
"openWorkflow": "Apri flusso di lavoro",
"clearWorkflowSearchFilter": "Cancella il filtro di ricerca del flusso di lavoro",
"workflowEditorReset": "Reimpostazione dell'editor del flusso di lavoro",
"workflowLibrary": "Libreria",
"noRecentWorkflows": "Nessun flusso di lavoro recente",
"workflowSaved": "Flusso di lavoro salvato",
@ -1633,7 +1632,10 @@
"deleteWorkflow": "Elimina flusso di lavoro",
"workflows": "Flussi di lavoro",
"noDescription": "Nessuna descrizione",
"userWorkflows": "I miei flussi di lavoro"
"userWorkflows": "I miei flussi di lavoro",
"newWorkflowCreated": "Nuovo flusso di lavoro creato",
"downloadWorkflow": "Salva su file",
"uploadWorkflow": "Carica da file"
},
"app": {
"storeNotInitialized": "Il negozio non è inizializzato"


@ -844,9 +844,6 @@
"hideLegendNodes": "Typelegende veld verbergen",
"reloadNodeTemplates": "Herlaad knooppuntsjablonen",
"loadWorkflow": "Laad werkstroom",
"resetWorkflow": "Herstel werkstroom",
"resetWorkflowDesc": "Weet je zeker dat je deze werkstroom wilt herstellen?",
"resetWorkflowDesc2": "Herstel van een werkstroom haalt alle knooppunten, randen en werkstroomdetails weg.",
"downloadWorkflow": "Download JSON van werkstroom",
"booleanPolymorphicDescription": "Een verzameling Booleanse waarden.",
"scheduler": "Planner",


@ -909,9 +909,6 @@
"hideLegendNodes": "Скрыть тип поля",
"showMinimapnodes": "Показать миникарту",
"loadWorkflow": "Загрузить рабочий процесс",
"resetWorkflowDesc2": "Сброс рабочего процесса очистит все узлы, ребра и детали рабочего процесса.",
"resetWorkflow": "Сбросить рабочий процесс",
"resetWorkflowDesc": "Вы уверены, что хотите сбросить этот рабочий процесс?",
"reloadNodeTemplates": "Перезагрузить шаблоны узлов",
"downloadWorkflow": "Скачать JSON рабочего процесса",
"booleanPolymorphicDescription": "Коллекция логических значений.",
@ -1599,7 +1596,6 @@
"saveWorkflow": "Сохранить рабочий процесс",
"openWorkflow": "Открытый рабочий процесс",
"clearWorkflowSearchFilter": "Очистить фильтр поиска рабочих процессов",
"workflowEditorReset": "Сброс редактора рабочих процессов",
"workflowLibrary": "Библиотека",
"downloadWorkflow": "Скачать рабочий процесс",
"noRecentWorkflows": "Нет недавних рабочих процессов",


@ -892,11 +892,8 @@
},
"nodes": {
"zoomInNodes": "放大",
"resetWorkflowDesc": "是否确定要重置工作流编辑器?",
"resetWorkflow": "重置工作流编辑器",
"loadWorkflow": "加载工作流",
"zoomOutNodes": "缩小",
"resetWorkflowDesc2": "重置工作流编辑器将清除所有节点、边际和节点图详情。不影响已保存的工作流。",
"reloadNodeTemplates": "重载节点模板",
"hideGraphNodes": "隐藏节点图信息",
"fitViewportNodes": "自适应视图",
@ -1122,7 +1119,10 @@
"deletedInvalidEdge": "已删除无效的边缘 {{source}} -> {{target}}",
"unknownInput": "未知输入:{{name}}",
"prototypeDesc": "此调用是一个原型 (prototype)。它可能会在本项目更新期间发生破坏性更改,并且随时可能被删除。",
"betaDesc": "此调用尚处于测试阶段。在稳定之前,它可能会在项目更新期间发生破坏性更改。本项目计划长期支持这种调用。"
"betaDesc": "此调用尚处于测试阶段。在稳定之前,它可能会在项目更新期间发生破坏性更改。本项目计划长期支持这种调用。",
"newWorkflow": "新建工作流",
"newWorkflowDesc": "是否创建一个新的工作流?",
"newWorkflowDesc2": "当前工作流有未保存的更改。"
},
"controlnet": {
"resize": "直接缩放",
@ -1637,9 +1637,8 @@
"saveWorkflow": "保存工作流",
"openWorkflow": "打开工作流",
"clearWorkflowSearchFilter": "清除工作流检索过滤器",
"workflowEditorReset": "工作流编辑器重置",
"workflowLibrary": "工作流库",
"downloadWorkflow": "下载工作流",
"downloadWorkflow": "保存到文件",
"noRecentWorkflows": "无最近工作流",
"workflowSaved": "已保存工作流",
"workflowIsOpen": "工作流已打开",
@ -1652,8 +1651,9 @@
"deleteWorkflow": "删除工作流",
"workflows": "工作流",
"noDescription": "无描述",
"uploadWorkflow": "上传工作流",
"userWorkflows": "我的工作流"
"uploadWorkflow": "从文件中加载",
"userWorkflows": "我的工作流",
"newWorkflowCreated": "已创建新的工作流"
},
"app": {
"storeNotInitialized": "商店尚未初始化"


@ -34,6 +34,7 @@ import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
const allReducers = {
canvas: canvasReducer,
@ -96,6 +97,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
})
.concat(api.middleware)
.concat(dynamicMiddlewares)
.concat(authToastMiddleware)
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());


@ -2,11 +2,14 @@ import { Text } from '@chakra-ui/layout';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
const TopCenterPanel = () => {
const { t } = useTranslation();
const name = useAppSelector((state) => state.workflow.name);
const isTouched = useAppSelector((state) => state.workflow.isTouched);
const isWorkflowLibraryEnabled =
useFeatureStatus('workflowLibrary').isFeatureEnabled;
return (
<Text
@ -19,7 +22,7 @@ const TopCenterPanel = () => {
opacity={0.8}
>
{name || t('workflows.unnamedWorkflow')}
{isTouched ? ` (${t('common.unsaved')})` : ''}
{isTouched && isWorkflowLibraryEnabled ? ` (${t('common.unsaved')})` : ''}
</Text>
);
};


@ -144,6 +144,7 @@ export const buildCanvasImageToImageGraph = (
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
use_cache: false,
},
},
edges: [
@ -255,6 +256,7 @@ export const buildCanvasImageToImageGraph = (
is_intermediate,
width: width,
height: height,
use_cache: false,
};
graph.edges.push(
@ -295,6 +297,7 @@ export const buildCanvasImageToImageGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
fp32,
use_cache: false,
};
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =


@ -191,6 +191,7 @@ export const buildCanvasInpaintGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
reference: canvasInitImage,
use_cache: false,
},
},
edges: [


@ -199,6 +199,7 @@ export const buildCanvasOutpaintGraph = (
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
use_cache: false,
},
},
edges: [


@ -266,6 +266,7 @@ export const buildCanvasSDXLImageToImageGraph = (
is_intermediate,
width: width,
height: height,
use_cache: false,
};
graph.edges.push(
@ -306,6 +307,7 @@ export const buildCanvasSDXLImageToImageGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
fp32,
use_cache: false,
};
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =


@ -196,6 +196,7 @@ export const buildCanvasSDXLInpaintGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
reference: canvasInitImage,
use_cache: false,
},
},
edges: [


@ -204,6 +204,7 @@ export const buildCanvasSDXLOutpaintGraph = (
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
use_cache: false,
},
},
edges: [


@ -258,6 +258,7 @@ export const buildCanvasSDXLTextToImageGraph = (
is_intermediate,
width: width,
height: height,
use_cache: false,
};
graph.edges.push(
@ -288,6 +289,7 @@ export const buildCanvasSDXLTextToImageGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
fp32,
use_cache: false,
};
graph.edges.push({


@ -246,6 +246,7 @@ export const buildCanvasTextToImageGraph = (
is_intermediate,
width: width,
height: height,
use_cache: false,
};
graph.edges.push(
@ -276,6 +277,7 @@ export const buildCanvasTextToImageGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
fp32,
use_cache: false,
};
graph.edges.push({


@ -143,6 +143,7 @@ export const buildLinearImageToImageGraph = (
// },
fp32,
is_intermediate,
use_cache: false,
},
},
edges: [


@ -154,6 +154,7 @@ export const buildLinearSDXLImageToImageGraph = (
// },
fp32,
is_intermediate,
use_cache: false,
},
},
edges: [


@ -127,6 +127,7 @@ export const buildLinearSDXLTextToImageGraph = (
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
use_cache: false,
},
},
edges: [


@ -146,6 +146,7 @@ export const buildLinearTextToImageGraph = (
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
use_cache: false,
},
},
edges: [


@ -5,12 +5,10 @@ import { t } from 'i18next';
import { z } from 'zod';
const zRejectedForbiddenAction = z.object({
action: z.object({
payload: z.object({
status: z.literal(403),
data: z.object({
detail: z.string(),
}),
payload: z.object({
status: z.literal(403),
data: z.object({
detail: z.string(),
}),
}),
});
@ -22,8 +20,8 @@ export const authToastMiddleware: Middleware =
const parsed = zRejectedForbiddenAction.parse(action);
const { dispatch } = api;
const customMessage =
parsed.action.payload.data.detail !== 'Forbidden'
? parsed.action.payload.data.detail
parsed.payload.data.detail !== 'Forbidden'
? parsed.payload.data.detail
: undefined;
dispatch(
addToast({
@ -32,7 +30,7 @@ export const authToastMiddleware: Middleware =
description: customMessage,
})
);
} catch {
} catch (error) {
// no-op
}
}


@ -172,6 +172,8 @@ nav:
- Adding Tests: 'contributing/TESTS.md'
- Documentation: 'contributing/contribution_guides/documentation.md'
- Nodes: 'contributing/INVOCATIONS.md'
- Model Manager: 'contributing/MODEL_MANAGER.md'
- Download Queue: 'contributing/DOWNLOAD_QUEUE.md'
- Translation: 'contributing/contribution_guides/translation.md'
- Tutorials: 'contributing/contribution_guides/tutorials.md'
- Changelog: 'CHANGELOG.md'


@ -105,6 +105,7 @@ dependencies = [
"pytest>6.0.0",
"pytest-cov",
"pytest-datadir",
"requests_testadapter",
]
"xformers" = [
"xformers==0.0.23; sys_platform!='darwin'",
@ -138,7 +139,6 @@ dependencies = [
"invokeai-node-web" = "invokeai.app.api_app:invoke_api"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"
"invokeai-migrate-models-to-db" = "invokeai.backend.model_manager.migrate_to_db:main"
[project.urls]
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"


@ -26,7 +26,6 @@ from invokeai.app.services.shared.graph import (
Graph,
GraphExecutionState,
IterateInvocation,
LibraryGraph,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@ -61,7 +60,6 @@ def mock_services() -> InvocationServices:
configuration=configuration,
events=TestEventService(),
graph_execution_manager=graph_execution_manager,
graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=None, # type: ignore
@ -70,6 +68,7 @@ def mock_services() -> InvocationServices:
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),


@ -24,7 +24,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_default import Invo
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
@pytest.fixture
@ -66,7 +66,6 @@ def mock_services() -> InvocationServices:
configuration=configuration,
events=TestEventService(),
graph_execution_manager=graph_execution_manager,
graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=None, # type: ignore
@ -75,6 +74,7 @@ def mock_services() -> InvocationServices:
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),


@ -0,0 +1,223 @@
"""Test the queued download facility"""
import re
import time
from pathlib import Path
from typing import Any, Dict, List
import pytest
import requests
from pydantic import BaseModel
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
# Prevent pytest deprecation warnings
TestAdapter.__test__ = False
@pytest.fixture
def session() -> requests.sessions.Session:
sess = requests.Session()
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess
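
The fixture above relies on requests_testadapter to serve canned responses, so the queue can be exercised without network access; a condensed sketch of the pattern for one URL (URL and body are hypothetical):

import requests
from requests_testadapter import TestAdapter

sess = requests.Session()
body = b"canned response"
sess.mount(
    "http://example.test/file",  # requests whose URL starts with this prefix are intercepted
    TestAdapter(body, headers={"Content-Length": len(body)}),
)
assert sess.get("http://example.test/file").content == body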
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
# A dummy event service for testing event issuing
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
events = set()
def event_handler(job: DownloadJob) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=session,
)
queue.start()
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
on_start=event_handler,
on_progress=event_handler,
on_complete=event_handler,
on_error=event_handler,
)
assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJob"
assert isinstance(job.id, int), "expected the job id to be numeric"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()
def test_errors(tmp_path: Path, session: Session) -> None:
queue = DownloadQueueService(
requests_session=session,
)
queue.start()
for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]:
queue.download(AnyHttpUrl(bad_url), dest=tmp_path)
queue.join()
jobs = queue.list_jobs()
print(jobs)
assert len(jobs) == 2
jobs_dict = {str(x.source): x for x in jobs}
assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR
assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)"
assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED
assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0
queue.stop()
def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
)
queue.join()
events = event_bus.events
assert len(events) == 3
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
assert events[0].event_name == "download_started"
assert events[1].event_name == "download_progress"
assert events[1].payload["total_bytes"] > 0
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
assert events[2].event_name == "download_complete"
assert events[2].payload["total_bytes"] == 32029
# test a failure
event_bus.events = [] # reset our accumulator
queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path)
queue.join()
events = event_bus.events
print("\n".join([x.model_dump_json() for x in events]))
assert len(events) == 1
assert events[0].event_name == "download_error"
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
assert events[0].payload["error"] is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
queue.stop()
def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, capsys) -> None:
queue = DownloadQueueService(
requests_session=session,
)
queue.start()
callback_ran = False
def broken_callback(job: DownloadJob) -> None:
nonlocal callback_ran
callback_ran = True
print(1 / 0) # deliberate error here
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
on_progress=broken_callback,
)
queue.join()
assert job.status == DownloadJobStatus.COMPLETED # should complete even though the callback is borked
assert Path(tmp_path, "mock12345.safetensors").exists()
assert callback_ran
# LS: The pytest capsys fixture does not seem to be working. I can see the
# correct stderr message in the pytest log, but it is not appearing in
# capsys.readouterr().
# captured = capsys.readouterr()
# assert re.search("division by zero", captured.err)
queue.stop()
def test_cancel(tmp_path: Path, session: requests.sessions.Session) -> None:
event_bus = DummyEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
cancelled = False
def slow_callback(job: DownloadJob) -> None:
time.sleep(2)
def cancelled_callback(job: DownloadJob) -> None:
nonlocal cancelled
cancelled = True
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
on_start=slow_callback,
on_cancelled=cancelled_callback,
)
queue.cancel_job(job)
queue.join()
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert events[-1].event_name == "download_cancelled"
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop()


@ -48,11 +48,13 @@ def store(
@pytest.fixture
def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
return ModelInstallService(
installer = ModelInstallService(
app_config=app_config,
record_store=store,
event_bus=DummyEventService(),
)
installer.start()
return installer
class DummyEvent(BaseModel):

File diff suppressed because one or more lines are too long


@ -0,0 +1,244 @@
"""
Test model metadata fetching and storage.
"""
import datetime
from pathlib import Path
import pytest
import requests
from pydantic.networks import HttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL, UnknownModelException
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import (
CivitaiMetadata,
CivitaiMetadataFetch,
CommercialUsage,
HuggingFaceMetadata,
HuggingFaceMetadataFetch,
ModelMetadataStore,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
RepoCivitaiModelMetadata1,
RepoCivitaiVersionMetadata1,
RepoHFMetadata1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
def app_config(datadir: Path) -> InvokeAIAppConfig:
return InvokeAIAppConfig(
root=datadir / "root",
models_dir=datadir / "root/models",
)
@pytest.fixture
def record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
logger = InvokeAILogger.get_logger(config=app_config)
db = create_mock_sqlite_database(app_config, logger)
store = ModelRecordServiceSQL(db)
# add three simple config records to the database
raw1 = {
"path": "/tmp/foo1",
"format": ModelFormat("diffusers"),
"name": "test2",
"base": BaseModelType("sd-2"),
"type": ModelType("vae"),
"original_hash": "111222333444",
"source": "stabilityai/sdxl-vae",
}
raw2 = {
"path": "/tmp/foo2.ckpt",
"name": "model1",
"format": ModelFormat("checkpoint"),
"base": BaseModelType("sd-1"),
"type": "main",
"config": "/tmp/foo.yaml",
"variant": "normal",
"original_hash": "111222333444",
"source": "https://civitai.com/models/206883/split",
}
raw3 = {
"path": "/tmp/foo3",
"format": ModelFormat("diffusers"),
"name": "test3",
"base": BaseModelType("sdxl"),
"type": ModelType("main"),
"original_hash": "111222333444",
"source": "author3/model3",
}
store.add_model("test_config_1", raw1)
store.add_model("test_config_2", raw2)
store.add_model("test_config_3", raw3)
return store
@pytest.fixture
def session() -> Session:
sess = requests.Session()
sess.mount(
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
TestAdapter(
RepoHFMetadata1,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
),
)
sess.mount(
"https://civitai.com/api/v1/model-versions/242807",
TestAdapter(
RepoCivitaiVersionMetadata1,
headers={
"Content-Length": len(RepoCivitaiVersionMetadata1),
},
),
)
sess.mount(
"https://civitai.com/api/v1/models/215485",
TestAdapter(
RepoCivitaiModelMetadata1,
headers={
"Content-Length": len(RepoCivitaiModelMetadata1),
},
),
)
return sess
@pytest.fixture
def metadata_store(record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
db = record_store._db # to ensure we are sharing the same database
return ModelMetadataStore(db)
def test_metadata_store_put_get(metadata_store: ModelMetadataStore) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata("test_config_1", input_metadata)
output_metadata = metadata_store.get_metadata("test_config_1")
assert input_metadata == output_metadata
with pytest.raises(UnknownModelException):
metadata_store.add_metadata("unknown_key", input_metadata)
def test_metadata_store_update(metadata_store: ModelMetadataStore) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata("test_config_1", input_metadata)
input_metadata.name = "new-name"
metadata_store.update_metadata("test_config_1", input_metadata)
output_metadata = metadata_store.get_metadata("test_config_1")
assert output_metadata.name == "new-name"
assert input_metadata == output_metadata
def test_metadata_search(metadata_store: ModelMetadataStore) -> None:
metadata1 = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata2 = HuggingFaceMetadata(
name="model2",
author="stabilityai",
tags={"text-to-image", "diffusers", "community-contributed"},
id="author2/model2",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata3 = HuggingFaceMetadata(
name="model3",
author="author3",
tags={"text-to-image", "checkpoint", "community-contributed"},
id="author3/model3",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata("test_config_1", metadata1)
metadata_store.add_metadata("test_config_2", metadata2)
metadata_store.add_metadata("test_config_3", metadata3)
matches = metadata_store.search_by_author("stabilityai")
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = metadata_store.search_by_author("Sherlock Holmes")
assert not matches
matches = metadata_store.search_by_name("model3")
assert len(matches) == 1
assert "test_config_3" in matches
matches = metadata_store.search_by_tag({"text-to-image"})
assert len(matches) == 3
matches = metadata_store.search_by_tag({"text-to-image", "diffusers"})
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = metadata_store.search_by_tag({"checkpoint", "community-contributed"})
assert len(matches) == 1
assert "test_config_3" in matches
# does the tag table update correctly?
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert not matches
metadata3.tags.add("licensed-for-commercial-use")
metadata_store.update_metadata("test_config_3", metadata3)
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert len(matches) == 1
def test_metadata_civitai_fetch(session: Session) -> None:
fetcher = CivitaiMetadataFetch(session)
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
assert isinstance(metadata, CivitaiMetadata)
assert metadata.id == 215485
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
assert metadata.version_id == 242807
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
def test_metadata_hf_fetch(session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(session)
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)
assert metadata.author == "test_author" # this is not the same as the original
assert metadata.files
assert metadata.tags == {
"diffusers",
"onnx",
"safetensors",
"text-to-image",
"license:other",
"has_space",
"diffusers:StableDiffusionXLPipeline",
"region:us",
}


@ -305,9 +305,7 @@ def test_calc_tiles_min_overlap_input_validation(
def test_calc_tiles_even_split_single_tile():
"""Test calc_tiles_even_split() behavior when a single tile covers the image."""
tiles = calc_tiles_even_split(
image_height=512, image_width=1024, num_tiles_x=1, num_tiles_y=1, overlap_fraction=0.25
)
tiles = calc_tiles_even_split(image_height=512, image_width=1024, num_tiles_x=1, num_tiles_y=1, overlap=64)
expected_tiles = [
Tile(
@ -322,36 +320,34 @@ def test_calc_tiles_even_split_single_tile():
def test_calc_tiles_even_split_evenly_divisible():
"""Test calc_tiles_even_split() behavior when the image is evenly covered by multiple tiles."""
# Parameters are chosen to roughly mimic the output of the original tile generation test of the same name
tiles = calc_tiles_even_split(
image_height=576, image_width=1600, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25
)
tiles = calc_tiles_even_split(image_height=576, image_width=1600, num_tiles_x=3, num_tiles_y=2, overlap=64)
expected_tiles = [
# Row 0
Tile(
coords=TBLR(top=0, bottom=320, left=0, right=624),
overlap=TBLR(top=0, bottom=72, left=0, right=136),
coords=TBLR(top=0, bottom=320, left=0, right=576),
overlap=TBLR(top=0, bottom=64, left=0, right=64),
),
Tile(
coords=TBLR(top=0, bottom=320, left=488, right=1112),
overlap=TBLR(top=0, bottom=72, left=136, right=136),
coords=TBLR(top=0, bottom=320, left=512, right=1088),
overlap=TBLR(top=0, bottom=64, left=64, right=64),
),
Tile(
coords=TBLR(top=0, bottom=320, left=976, right=1600),
overlap=TBLR(top=0, bottom=72, left=136, right=0),
coords=TBLR(top=0, bottom=320, left=1024, right=1600),
overlap=TBLR(top=0, bottom=64, left=64, right=0),
),
# Row 1
Tile(
coords=TBLR(top=248, bottom=576, left=0, right=624),
overlap=TBLR(top=72, bottom=0, left=0, right=136),
coords=TBLR(top=256, bottom=576, left=0, right=576),
overlap=TBLR(top=64, bottom=0, left=0, right=64),
),
Tile(
coords=TBLR(top=248, bottom=576, left=488, right=1112),
overlap=TBLR(top=72, bottom=0, left=136, right=136),
coords=TBLR(top=256, bottom=576, left=512, right=1088),
overlap=TBLR(top=64, bottom=0, left=64, right=64),
),
Tile(
coords=TBLR(top=248, bottom=576, left=976, right=1600),
overlap=TBLR(top=72, bottom=0, left=136, right=0),
coords=TBLR(top=256, bottom=576, left=1024, right=1600),
overlap=TBLR(top=64, bottom=0, left=64, right=0),
),
]
assert tiles == expected_tiles
@ -360,36 +356,34 @@ def test_calc_tiles_even_split_evenly_divisible():
def test_calc_tiles_even_split_not_evenly_divisible():
"""Test calc_tiles_even_split() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
# Parameters are chosen to roughly mimic the output of the original tile generation test of the same name
tiles = calc_tiles_even_split(
image_height=400, image_width=1200, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25
)
tiles = calc_tiles_even_split(image_height=400, image_width=1200, num_tiles_x=3, num_tiles_y=2, overlap=64)
expected_tiles = [
# Row 0
Tile(
coords=TBLR(top=0, bottom=224, left=0, right=464),
overlap=TBLR(top=0, bottom=56, left=0, right=104),
coords=TBLR(top=0, bottom=232, left=0, right=440),
overlap=TBLR(top=0, bottom=64, left=0, right=64),
),
Tile(
coords=TBLR(top=0, bottom=224, left=360, right=824),
overlap=TBLR(top=0, bottom=56, left=104, right=104),
coords=TBLR(top=0, bottom=232, left=376, right=816),
overlap=TBLR(top=0, bottom=64, left=64, right=64),
),
Tile(
coords=TBLR(top=0, bottom=224, left=720, right=1200),
overlap=TBLR(top=0, bottom=56, left=104, right=0),
coords=TBLR(top=0, bottom=232, left=752, right=1200),
overlap=TBLR(top=0, bottom=64, left=64, right=0),
),
# Row 1
Tile(
coords=TBLR(top=168, bottom=400, left=0, right=464),
overlap=TBLR(top=56, bottom=0, left=0, right=104),
coords=TBLR(top=168, bottom=400, left=0, right=440),
overlap=TBLR(top=64, bottom=0, left=0, right=64),
),
Tile(
coords=TBLR(top=168, bottom=400, left=360, right=824),
overlap=TBLR(top=56, bottom=0, left=104, right=104),
coords=TBLR(top=168, bottom=400, left=376, right=816),
overlap=TBLR(top=64, bottom=0, left=64, right=64),
),
Tile(
coords=TBLR(top=168, bottom=400, left=720, right=1200),
overlap=TBLR(top=56, bottom=0, left=104, right=0),
coords=TBLR(top=168, bottom=400, left=752, right=1200),
overlap=TBLR(top=64, bottom=0, left=64, right=0),
),
]
@ -399,28 +393,26 @@ def test_calc_tiles_even_split_not_evenly_divisible():
def test_calc_tiles_even_split_difficult_size():
"""Test calc_tiles_even_split() behavior when the image is a difficult size to spilt evenly and keep div8."""
# Parameters are a difficult size for other tile gen routines to calculate
tiles = calc_tiles_even_split(
image_height=1000, image_width=1000, num_tiles_x=2, num_tiles_y=2, overlap_fraction=0.25
)
tiles = calc_tiles_even_split(image_height=1000, image_width=1000, num_tiles_x=2, num_tiles_y=2, overlap=64)
expected_tiles = [
# Row 0
Tile(
coords=TBLR(top=0, bottom=560, left=0, right=560),
overlap=TBLR(top=0, bottom=128, left=0, right=128),
coords=TBLR(top=0, bottom=528, left=0, right=528),
overlap=TBLR(top=0, bottom=64, left=0, right=64),
),
Tile(
coords=TBLR(top=0, bottom=560, left=432, right=1000),
overlap=TBLR(top=0, bottom=128, left=128, right=0),
coords=TBLR(top=0, bottom=528, left=464, right=1000),
overlap=TBLR(top=0, bottom=64, left=64, right=0),
),
# Row 1
Tile(
coords=TBLR(top=432, bottom=1000, left=0, right=560),
overlap=TBLR(top=128, bottom=0, left=0, right=128),
coords=TBLR(top=464, bottom=1000, left=0, right=528),
overlap=TBLR(top=64, bottom=0, left=0, right=64),
),
Tile(
coords=TBLR(top=432, bottom=1000, left=432, right=1000),
overlap=TBLR(top=128, bottom=0, left=128, right=0),
coords=TBLR(top=464, bottom=1000, left=464, right=1000),
overlap=TBLR(top=64, bottom=0, left=64, right=0),
),
]
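
The expected coordinates above follow directly from the new pixel-overlap formula; a worked check for the 1000x1000, 2x2, overlap=64 case:

import math

LATENT_SCALE_FACTOR = 8
image_width, num_tiles_x, overlap = 1000, 2, 64

# Tile width, rounded down to a multiple of 8, per calc_tiles_even_split above:
tile_size_x = LATENT_SCALE_FACTOR * math.floor(
    ((image_width + overlap * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR
)
assert tile_size_x == 528                  # right edge of the first column
assert 1 * (tile_size_x - overlap) == 464  # left edge of the second column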
@ -428,11 +420,13 @@ def test_calc_tiles_even_split_difficult_size():
@pytest.mark.parametrize(
["image_height", "image_width", "num_tiles_x", "num_tiles_y", "overlap_fraction", "raises"],
["image_height", "image_width", "num_tiles_x", "num_tiles_y", "overlap", "raises"],
[
(128, 128, 1, 1, 0.25, False), # OK
(128, 128, 1, 1, 127, False), # OK
(128, 128, 1, 1, 0, False), # OK
(128, 128, 2, 1, 0, False), # OK
(128, 128, 2, 2, 0, False), # OK
(128, 128, 2, 1, 120, True),  # overlap equals tile_width.
(128, 128, 1, 2, 120, True),  # overlap equals tile_height.
(127, 127, 1, 1, 0, True),  # image size must be divisible by 8
],
)
@ -441,15 +435,15 @@ def test_calc_tiles_even_split_input_validation(
image_width: int,
num_tiles_x: int,
num_tiles_y: int,
overlap_fraction: float,
overlap: int,
raises: bool,
):
"""Test that calc_tiles_even_split() raises an exception if the inputs are invalid."""
if raises:
with pytest.raises(ValueError):
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction)
with pytest.raises((AssertionError, ValueError)):
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap)
else:
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction)
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap)
#############################################