Merge branch 'main' into feat/batch-graphs

This commit is contained in:
Brandon Rising 2023-07-31 13:22:11 -04:00
commit bb681a8a11
289 changed files with 13713 additions and 9327 deletions

View File

@ -20,13 +20,13 @@ def calc_images_mean_L1(image1_path, image2_path):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('image1_path') parser.add_argument("image1_path")
parser.add_argument('image2_path') parser.add_argument("image2_path")
args = parser.parse_args() args = parser.parse_args()
return args return args
if __name__ == '__main__': if __name__ == "__main__":
args = parse_args() args = parse_args()
mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path) mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path)
print(mean_L1) print(mean_L1)

View File

@ -1 +1,2 @@
b3dccfaeb636599c02effc377cdd8a87d658256c b3dccfaeb636599c02effc377cdd8a87d658256c
218b6d0546b990fc449c876fb99f44b50c4daa35

View File

@ -1,11 +1,11 @@
name: Close inactive issues name: Close inactive issues
on: on:
schedule: schedule:
- cron: "00 6 * * *" - cron: "00 4 * * *"
env: env:
DAYS_BEFORE_ISSUE_STALE: 14 DAYS_BEFORE_ISSUE_STALE: 30
DAYS_BEFORE_ISSUE_CLOSE: 28 DAYS_BEFORE_ISSUE_CLOSE: 14
jobs: jobs:
close-issues: close-issues:
@ -14,7 +14,7 @@ jobs:
issues: write issues: write
pull-requests: write pull-requests: write
steps: steps:
- uses: actions/stale@v5 - uses: actions/stale@v8
with: with:
days-before-issue-stale: ${{ env.DAYS_BEFORE_ISSUE_STALE }} days-before-issue-stale: ${{ env.DAYS_BEFORE_ISSUE_STALE }}
days-before-issue-close: ${{ env.DAYS_BEFORE_ISSUE_CLOSE }} days-before-issue-close: ${{ env.DAYS_BEFORE_ISSUE_CLOSE }}
@ -23,5 +23,6 @@ jobs:
close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue." close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
days-before-pr-stale: -1 days-before-pr-stale: -1
days-before-pr-close: -1 days-before-pr-close: -1
exempt-issue-labels: "Active Issue"
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
operations-per-run: 500 operations-per-run: 500

27
.github/workflows/style-checks.yml vendored Normal file
View File

@ -0,0 +1,27 @@
name: Black # TODO: add isort and flake8 later
on:
pull_request: {}
push:
branches: master
tags: "*"
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies with pip
run: |
pip install --upgrade pip wheel
pip install .[test]
# - run: isort --check-only .
- run: black --check .
# - run: flake8

1
.gitignore vendored
View File

@ -38,7 +38,6 @@ develop-eggs/
downloads/ downloads/
eggs/ eggs/
.eggs/ .eggs/
lib/
lib64/ lib64/
parts/ parts/
sdist/ sdist/

10
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,10 @@
# See https://pre-commit.com/ for usage and config
repos:
- repo: local
hooks:
- id: black
name: black
stages: [commit]
language: system
entry: black
types: [python]

290
LICENSE-SDXL.txt Normal file
View File

@ -0,0 +1,290 @@
Copyright (c) 2023 Stability AI
CreativeML Open RAIL++-M License dated July 26, 2023
Section I: PREAMBLE
Multimodal generative models are being widely adopted and used, and
have the potential to transform the way artists, among other
individuals, conceive and benefit from AI or ML technologies as a tool
for content creation.
Notwithstanding the current and potential benefits that these
artifacts can bring to society at large, there are also concerns about
potential misuses of them, either due to their technical limitations
or ethical considerations.
In short, this license strives for both the open and responsible
downstream use of the accompanying model. When it comes to the open
character, we took inspiration from open source permissive licenses
regarding the grant of IP rights. Referring to the downstream
responsible use, we added use-based restrictions not permitting the
use of the model in very specific scenarios, in order for the licensor
to be able to enforce the license in case potential misuses of the
Model may occur. At the same time, we strive to promote open and
responsible research on generative models for art and content
generation.
Even though downstream derivative versions of the model could be
released under different licensing terms, the latter will always have
to include - at minimum - the same use-based restrictions as the ones
in the original license (this license). We believe in the intersection
between open and responsible AI development; thus, this agreement aims
to strike a balance between both in order to enable responsible
open-science in the field of AI.
This CreativeML Open RAIL++-M License governs the use of the model
(and its derivatives) and is informed by the model card associated
with the model.
NOW THEREFORE, You and Licensor agree as follows:
Definitions
"License" means the terms and conditions for use, reproduction, and
Distribution as defined in this document.
"Data" means a collection of information and/or content extracted from
the dataset used with the Model, including to train, pretrain, or
otherwise evaluate the Model. The Data is not licensed under this
License.
"Output" means the results of operating a Model as embodied in
informational content resulting therefrom.
"Model" means any accompanying machine-learning based assemblies
(including checkpoints), consisting of learnt weights, parameters
(including optimizer states), corresponding to the model architecture
as embodied in the Complementary Material, that have been trained or
tuned, in whole or in part on the Data, using the Complementary
Material.
"Derivatives of the Model" means all modifications to the Model, works
based on the Model, or any other model which is created or initialized
by transfer of patterns of the weights, parameters, activations or
output of the Model, to the other model, in order to cause the other
model to perform similarly to the Model, including - but not limited
to - distillation methods entailing the use of intermediate data
representations or methods based on the generation of synthetic data
by the Model for training the other model.
"Complementary Material" means the accompanying source code and
scripts used to define, run, load, benchmark or evaluate the Model,
and used to prepare data for training or evaluation, if any. This
includes any accompanying documentation, tutorials, examples, etc, if
any.
"Distribution" means any transmission, reproduction, publication or
other sharing of the Model or Derivatives of the Model to a third
party, including providing the Model as a hosted service made
available by electronic or other remote means - e.g. API-based or web
access.
"Licensor" means the copyright owner or entity authorized by the
copyright owner that is granting the License, including the persons or
entities that may have rights in the Model and/or distributing the
Model.
"You" (or "Your") means an individual or Legal Entity exercising
permissions granted by this License and/or making use of the Model for
whichever purpose and in any field of use, including usage of the
Model in an end-use application - e.g. chatbot, translator, image
generator.
"Third Parties" means individuals or legal entities that are not under
common control with Licensor or You.
"Contribution" means any work of authorship, including the original
version of the Model and any modifications or additions to that Model
or Derivatives of the Model thereof, that is intentionally submitted
to Licensor for inclusion in the Model by the copyright owner or by an
individual or Legal Entity authorized to submit on behalf of the
copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent to
the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control
systems, and issue tracking systems that are managed by, or on behalf
of, the Licensor for the purpose of discussing and improving the
Model, but excluding communication that is conspicuously marked or
otherwise designated in writing by the copyright owner as "Not a
Contribution."
"Contributor" means Licensor and any individual or Legal Entity on
behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Model.
Section II: INTELLECTUAL PROPERTY RIGHTS
Both copyright and patent grants apply to the Model, Derivatives of
the Model and Complementary Material. The Model and Derivatives of the
Model are subject to additional terms as described in
Section III.
Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare, publicly display, publicly
perform, sublicense, and distribute the Complementary Material, the
Model, and Derivatives of the Model.
Grant of Patent License. Subject to the terms and conditions of this
License and where and as applicable, each Contributor hereby grants to
You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this paragraph) patent license to
make, have made, use, offer to sell, sell, import, and otherwise
transfer the Model and the Complementary Material, where such license
applies only to those patent claims licensable by such Contributor
that are necessarily infringed by their Contribution(s) alone or by
combination of their Contribution(s) with the Model to which such
Contribution(s) was submitted. If You institute patent litigation
against any entity (including a cross-claim or counterclaim in a
lawsuit) alleging that the Model and/or Complementary Material or a
Contribution incorporated within the Model and/or Complementary
Material constitutes direct or contributory patent infringement, then
any patent licenses granted to You under this License for the Model
and/or Work shall terminate as of the date such litigation is asserted
or filed.
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
Distribution and Redistribution. You may host for Third Party remote
access purposes (e.g. software-as-a-service), reproduce and distribute
copies of the Model or Derivatives of the Model thereof in any medium,
with or without modifications, provided that You meet the following
conditions: Use-based restrictions as referenced in paragraph 5 MUST
be included as an enforceable provision by You in any type of legal
agreement (e.g. a license) governing the use and/or distribution of
the Model or Derivatives of the Model, and You shall give notice to
subsequent users You Distribute to, that the Model or Derivatives of
the Model are subject to paragraph 5. This provision does not apply to
the use of Complementary Material. You must give any Third Party
recipients of the Model or Derivatives of the Model a copy of this
License; You must cause any modified files to carry prominent notices
stating that You changed the files; You must retain all copyright,
patent, trademark, and attribution notices excluding those notices
that do not pertain to any part of the Model, Derivatives of the
Model. You may add Your own copyright statement to Your modifications
and may provide additional or different license terms and conditions -
respecting paragraph 4.a. - for use, reproduction, or Distribution of
Your modifications, or for any such Derivatives of the Model as a
whole, provided Your use, reproduction, and Distribution of the Model
otherwise complies with the conditions stated in this License.
Use-based restrictions. The restrictions set forth in Attachment A are
considered Use-based restrictions. Therefore You cannot use the Model
and the Derivatives of the Model for the specified restricted
uses. You may use the Model subject to this License, including only
for lawful purposes and in accordance with the License. Use may
include creating any content with, finetuning, updating, running,
training, evaluating and/or reparametrizing the Model. You shall
require all of Your users who use the Model or a Derivative of the
Model to comply with the terms of this paragraph (paragraph 5).
The Output You Generate. Except as set forth herein, Licensor claims
no rights in the Output You generate using the Model. You are
accountable for the Output you generate and its subsequent uses. No
use of the output can contravene any provision as stated in the
License.
Section IV: OTHER PROVISIONS
Updates and Runtime Restrictions. To the maximum extent permitted by
law, Licensor reserves the right to restrict (remotely or otherwise)
usage of the Model in violation of this License.
Trademarks and related. Nothing in this License permits You to make
use of Licensors trademarks, trade names, logos or to otherwise
suggest endorsement or misrepresent the relationship between the
parties; and any rights not expressly granted herein are reserved by
the Licensors.
Disclaimer of Warranty. Unless required by applicable law or agreed to
in writing, Licensor provides the Model and the Complementary Material
(and each Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Model, Derivatives of
the Model, and the Complementary Material and assume any risks
associated with Your exercise of permissions under this License.
Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise, unless
required by applicable law (such as deliberate and grossly negligent
acts) or agreed to in writing, shall any Contributor be liable to You
for damages, including any direct, indirect, special, incidental, or
consequential damages of any character arising as a result of this
License or out of the use or inability to use the Model and the
Complementary Material (including but not limited to damages for loss
of goodwill, work stoppage, computer failure or malfunction, or any
and all other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
Accepting Warranty or Additional Liability. While redistributing the
Model, Derivatives of the Model and the Complementary Material
thereof, You may choose to offer, and charge a fee for, acceptance of
support, warranty, indemnity, or other liability obligations and/or
rights consistent with this License. However, in accepting such
obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if
You agree to indemnify, defend, and hold each Contributor harmless for
any liability incurred by, or claims asserted against, such
Contributor by reason of your accepting any such warranty or
additional liability.
If any provision of this License is held to be invalid, illegal or
unenforceable, the remaining provisions shall be unaffected thereby
and remain valid as if such provision had not been set forth herein.
END OF TERMS AND CONDITIONS
Attachment A
Use Restrictions
You agree not to use the Model or Derivatives of the Model:
* In any way that violates any applicable national, federal, state,
local or international law or regulation;
* For the purpose of exploiting, harming or attempting to exploit or
harm minors in any way;
* To generate or disseminate verifiably false information and/or
content with the purpose of harming others;
* To generate or disseminate personal identifiable information that
can be used to harm an individual;
* To defame, disparage or otherwise harass others;
* For fully automated decision making that adversely impacts an
individuals legal rights or otherwise creates or modifies a
binding, enforceable obligation;
* For any use intended to or which has the effect of discriminating
against or harming individuals or groups based on online or offline
social behavior or known or predicted personal or personality
characteristics;
* To exploit any of the vulnerabilities of a specific group of persons
based on their age, social, physical or mental characteristics, in
order to materially distort the behavior of a person pertaining to
that group in a manner that causes or is likely to cause that person
or another person physical or psychological harm;
* For any use intended to or which has the effect of discriminating
against individuals or groups based on legally protected
characteristics or categories;
* To provide medical advice and medical results interpretation;
* To generate or disseminate information for the purpose to be used
for administration of justice, law enforcement, immigration or
asylum processes, such as predicting an individual will commit
fraud/crime commitment (e.g. by text profiling, drawing causal
relationships between assertions made in documents, indiscriminate
and arbitrarily-targeted use).

View File

@ -123,7 +123,7 @@ and go to http://localhost:9090.
### Command-Line Installation (for developers and users familiar with Terminals) ### Command-Line Installation (for developers and users familiar with Terminals)
You must have Python 3.9 or 3.10 installed on your machine. Earlier or You must have Python 3.9 through 3.11 installed on your machine. Earlier or
later versions are not supported. later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed) the command `npm install -g yarn` if needed)

Binary file not shown.

After

Width:  |  Height:  |  Size: 131 KiB

View File

@ -65,7 +65,6 @@ InvokeAI:
esrgan: true esrgan: true
internet_available: true internet_available: true
log_tokenization: false log_tokenization: false
nsfw_checker: false
patchmatch: true patchmatch: true
restore: true restore: true
... ...
@ -136,19 +135,16 @@ command-line options by giving the `--help` argument:
``` ```
(.venv) > invokeai-web --help (.venv) > invokeai-web --help
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials] usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials] [--allow_methods [ALLOW_METHODS ...]]
[--allow_methods [ALLOW_METHODS ...]] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan] [--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
[--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
[--nsfw_checker | --no-nsfw_checker] [--patchmatch | --no-patchmatch] [--restore | --no-restore] [--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_loaded_models MAX_LOADED_MODELS] [--max_cache_size MAX_CACHE_SIZE]
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_cache_size MAX_CACHE_SIZE] [--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--gpu_mem_reserved GPU_MEM_RESERVED] [--precision {auto,float16,float32,autocast}]
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--precision {auto,float16,float32,autocast}] [--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled] [--tiled_decode | --no-tiled_decode] [--root ROOT]
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR] [--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH]
[--tiled_decode | --no-tiled_decode] [--root ROOT] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR] [--models_dir MODELS_DIR] [--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
[--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH] [--models_dir MODELS_DIR] [--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]] [--log_format {plain,color,syslog,legacy}]
[--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE] [--log_level {debug,info,warning,error,critical}] [--version | --no-version]
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]]
[--log_format {plain,color,syslog,legacy}] [--log_level {debug,info,warning,error,critical}]
...
``` ```
## The Configuration Settings ## The Configuration Settings
@ -178,7 +174,6 @@ These configuration settings allow you to enable and disable various InvokeAI fe
| `esrgan` | `true` | Activate the ESRGAN upscaling options| | `esrgan` | `true` | Activate the ESRGAN upscaling options|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet | | `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected | | `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
| `nsfw_checker` | `true` | Activate the NSFW checker to blur out risque images |
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting | | `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) | | `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |

View File

@ -61,11 +61,13 @@ A noise scheduler (eg. DPM++ 2M Karras) schedules the subtraction of noise from
| ImageInverseLerp | Inverse linear interpolation of all pixels of an image | | ImageInverseLerp | Inverse linear interpolation of all pixels of an image |
| ImageLerp | Linear interpolation of all pixels of an image | | ImageLerp | Linear interpolation of all pixels of an image |
| ImageMultiply | Multiplies two images together using `PIL.ImageChops.Multiply()` | | ImageMultiply | Multiplies two images together using `PIL.ImageChops.Multiply()` |
| ImageNSFWBlurInvocation | Detects and blurs images that may contain sexually explicit content |
| ImagePaste | Pastes an image into another image | | ImagePaste | Pastes an image into another image |
| ImageProcessor | Base class for invocations that reprocess images for ControlNet | | ImageProcessor | Base class for invocations that reprocess images for ControlNet |
| ImageResize | Resizes an image to specific dimensions | | ImageResize | Resizes an image to specific dimensions |
| ImageScale | Scales an image by a factor | | ImageScale | Scales an image by a factor |
| ImageToLatents | Scales latents by a given factor | | ImageToLatents | Scales latents by a given factor |
| ImageWatermarkInvocation | Adds an invisible watermark to images |
| InfillColor | Infills transparent areas of an image with a solid color | | InfillColor | Infills transparent areas of an image with a solid color |
| InfillPatchMatch | Infills transparent areas of an image using the PatchMatch algorithm | | InfillPatchMatch | Infills transparent areas of an image using the PatchMatch algorithm |
| InfillTile | Infills transparent areas of an image with tiles of the image | | InfillTile | Infills transparent areas of an image with tiles of the image |

View File

@ -16,21 +16,24 @@ Output Example:
--- ---
## **Seamless Tiling** ## **Invisible Watermark**
The seamless tiling mode causes generated images to seamlessly tile In keeping with the principles for responsible AI generation, and to
with itself creating repetitive wallpaper-like patterns. To use it, help AI researchers avoid synthetic images contaminating their
activate the Seamless Tiling option in the Web GUI and then select training sets, InvokeAI adds an invisible watermark to each of the
whether to tile on the X (horizontal) and/or Y (vertical) axes. Tiling final images it generates. The watermark consists of the text
will then be active for the next set of generations. "InvokeAI" and can be viewed using the
[invisible-watermarks](https://github.com/ShieldMnt/invisible-watermark)
tool.
A nice prompt to test seamless tiling with is: Watermarking is controlled using the `invisible-watermark` setting in
`invokeai.yaml`. To turn it off, add the following line under the `Features`
category.
``` ```
pond garden with lotus by claude monet" invisible_watermark: false
``` ```
---
## **Weighted Prompts** ## **Weighted Prompts**
@ -39,34 +42,10 @@ priority to them, by adding `:<percent>` to the end of the section you wish to u
example consider this prompt: example consider this prompt:
```bash ```bash
tabby cat:0.25 white duck:0.75 hybrid (tabby cat):0.25 (white duck):0.75 hybrid
``` ```
This will tell the sampler to invest 25% of its effort on the tabby cat aspect of the image and 75% This will tell the sampler to invest 25% of its effort on the tabby cat aspect of the image and 75%
on the white duck aspect (surprisingly, this example actually works). The prompt weights can use any on the white duck aspect (surprisingly, this example actually works). The prompt weights can use any
combination of integers and floating point numbers, and they do not need to add up to 1. combination of integers and floating point numbers, and they do not need to add up to 1.
## **Thresholding and Perlin Noise Initialization Options**
Under the Noise section of the Web UI, you will find two options named
Perlin Noise and Noise Threshold. [Perlin
noise](https://en.wikipedia.org/wiki/Perlin_noise) is a type of
structured noise used to simulate terrain and other natural
textures. The slider controls the percentage of perlin noise that will
be mixed into the image at the beginning of generation. Adding a little
perlin noise to a generation will alter the image substantially.
The noise threshold limits the range of the latent values during
sampling and helps combat the oversharpening seem with higher CFG
scale values.
For better intuition into what these options do in practice:
![here is a graphic demonstrating them both](../assets/truncation_comparison.jpg)
In generating this graphic, perlin noise at initialization was
programmatically varied going across on the diagram by values 0.0,
0.1, 0.2, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0; and the threshold was varied
going down from 0, 1, 2, 3, 4, 5, 10, 20, 100. The other options are
fixed using the prompt "a portrait of a beautiful young lady" a CFG of
20, 100 steps, and a seed of 1950357039.

View File

@ -1,12 +1,40 @@
--- ---
title: The NSFW Checker title: Watermarking, NSFW Image Checking
--- ---
# :material-image-off: NSFW Checker # :material-image-off: Invisible Watermark and the NSFW Checker
## Watermarking
InvokeAI does not apply watermarking to images by default. However,
many computer scientists working in the field of generative AI worry
that a flood of computer-generated imagery will contaminate the image
data sets needed to train future generations of generative models.
InvokeAI offers an optional watermarking mode that writes a small bit
of text, **InvokeAI**, into each image that it generates using an
"invisible" watermarking library that spreads the information
throughout the image in a way that is not perceptible to the human
eye. If you are planning to share your generated images on
internet-accessible services, we encourage you to activate the
invisible watermark mode in order to help preserve the digital image
environment.
The downside of watermarking is that it increases the size of the
image moderately, and has been reported by some individuals to degrade
image quality. Your mileage may vary.
To read the watermark in an image, activate the InvokeAI virtual
environment (called the "developer's console" in the launcher) and run
the command:
```
invisible-watermark -a decode -t bytes -m dwtDct -l 64 /path/to/image.png
```
## The NSFW ("Safety") Checker ## The NSFW ("Safety") Checker
The Stable Diffusion image generation models will produce sexual Stable Diffusion 1.5-based image generation models will produce sexual
imagery if deliberately prompted, and will occasionally produce such imagery if deliberately prompted, and will occasionally produce such
images when this is not intended. Such images are colloquially known images when this is not intended. Such images are colloquially known
as "Not Safe for Work" (NSFW). This behavior is due to the nature of as "Not Safe for Work" (NSFW). This behavior is due to the nature of
@ -18,35 +46,17 @@ jurisdictions it may be illegal to publicly distribute such imagery,
including mounting a publicly-available server that provides including mounting a publicly-available server that provides
unfiltered images to the public. Furthermore, the [Stable Diffusion unfiltered images to the public. Furthermore, the [Stable Diffusion
weights weights
License](https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE-ModelWeights.txt) License](https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE-SD1+SD2.txt),
forbids the model from being used to "exploit any of the and the [Stable Diffusion XL
License][https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE-SDXL.txt]
both forbid the models from being used to "exploit any of the
vulnerabilities of a specific group of persons." vulnerabilities of a specific group of persons."
For these reasons Stable Diffusion offers a "safety checker," a For these reasons Stable Diffusion offers a "safety checker," a
machine learning model trained to recognize potentially disturbing machine learning model trained to recognize potentially disturbing
imagery. When a potentially NSFW image is detected, the checker will imagery. When a potentially NSFW image is detected, the checker will
blur the image and paste a warning icon on top. The checker can be blur the image and paste a warning icon on top. The checker can be
turned on and off on the command line using `--nsfw_checker` and turned on and off in the Web interface under Settings.
`--no-nsfw_checker`.
At installation time, InvokeAI will ask whether the checker should be
activated by default (neither argument given on the command line). The
response is stored in the InvokeAI initialization file
(`invokeai.yaml` in the InvokeAI root directory). You can change the
default at any time by opening this file in a text editor and
changing the line `nsfw_checker:` from true to false or vice-versa:
```
...
Features:
esrgan: true
internet_available: true
log_tokenization: false
nsfw_checker: true
patchmatch: true
restore: true
```
## Caveats ## Caveats
@ -84,10 +94,3 @@ are encouraged to turn **off** intermediate image rendering when you
are using the checker. Future versions of InvokeAI will apply are using the checker. Future versions of InvokeAI will apply
additional blurring to intermediate images when the checker is active. additional blurring to intermediate images when the checker is active.
### Watermarking
InvokeAI does not apply any sort of watermark to images it
generates. However, it does write metadata into the PNG data area,
including the prompt used to generate the image and relevant parameter
settings. These fields can be examined using the `sd-metadata.py`
script that comes with the InvokeAI package.

View File

@ -80,11 +80,11 @@ Q&A</a>]
!!! note !!! note
This fork is rapidly evolving. Please use the [Issues tab](https://github.com/invoke-ai/InvokeAI/issues) to report bugs and make feature requests. Be sure to use the provided templates. They will help aid diagnose issues faster. This software is rapidly evolving. Please use the [Issues tab](https://github.com/invoke-ai/InvokeAI/issues) to report bugs and make feature requests. Be sure to use the provided templates. They will help aid diagnose issues faster.
## :octicons-package-dependencies-24: Installation ## :octicons-package-dependencies-24: Installation
This fork is supported across Linux, Windows and Macintosh. Linux users can use This software is supported across Linux, Windows and Macintosh. Linux users can use
either an Nvidia-based card (with CUDA support) or an AMD card (using the ROCm either an Nvidia-based card (with CUDA support) or an AMD card (using the ROCm
driver). driver).
@ -95,6 +95,8 @@ driver).
This method is recommended for experienced users and developers This method is recommended for experienced users and developers
#### [Docker Installation](installation/040_INSTALL_DOCKER.md) #### [Docker Installation](installation/040_INSTALL_DOCKER.md)
This method is recommended for those familiar with running Docker containers This method is recommended for those familiar with running Docker containers
#### [Installation Troubleshooting](installation/010_INSTALL_AUTOMATED.md#troubleshooting)
Installation troubleshooting guide.
### Other Installation Guides ### Other Installation Guides
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md) - [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
- [XFormers](installation/070_INSTALL_XFORMERS.md) - [XFormers](installation/070_INSTALL_XFORMERS.md)
@ -148,7 +150,7 @@ images in full-precision mode:
- [Model Merging](features/MODEL_MERGING.md) - [Model Merging](features/MODEL_MERGING.md)
- [ControlNet Models](features/CONTROLNET.md) - [ControlNet Models](features/CONTROLNET.md)
- [Style/Subject Concepts and Embeddings](features/CONCEPTS.md) - [Style/Subject Concepts and Embeddings](features/CONCEPTS.md)
- [Not Safe for Work (NSFW) Checker](features/NSFW.md) - [Watermarking and the Not Safe for Work (NSFW) Checker](features/WATERMARK+NSFW.md)
<!-- seperator --> <!-- seperator -->
### Prompt Engineering ### Prompt Engineering
- [Prompt Syntax](features/PROMPTS.md) - [Prompt Syntax](features/PROMPTS.md)
@ -230,7 +232,7 @@ encouraged to do so.
## :octicons-person-24: Contributors ## :octicons-person-24: Contributors
This fork is a combined effort of various people from across the world. This software is a combined effort of various people from across the world.
[Check out the list of all these amazing people](other/CONTRIBUTORS.md). We [Check out the list of all these amazing people](other/CONTRIBUTORS.md). We
thank them for their time, hard work and effort. thank them for their time, hard work and effort.

View File

@ -40,10 +40,8 @@ experimental versions later.
this, open up a command-line window ("Terminal" on Linux and this, open up a command-line window ("Terminal" on Linux and
Macintosh, "Command" or "Powershell" on Windows) and type `python Macintosh, "Command" or "Powershell" on Windows) and type `python
--version`. If Python is installed, it will print out the version --version`. If Python is installed, it will print out the version
number. If it is version `3.9.*` or `3.10.*`, you meet number. If it is version `3.9.*`, `3.10.*` or `3.11.*` you meet
requirements. We do not recommend using Python 3.11 or higher, requirements.
as not all the libraries that InvokeAI depends on work properly
with this version.
!!! warning "What to do if you have an unsupported version" !!! warning "What to do if you have an unsupported version"
@ -215,17 +213,6 @@ experimental versions later.
Generally the defaults are fine, and you can come back to this screen at Generally the defaults are fine, and you can come back to this screen at
any time to tweak your system. Here are the options you can adjust: any time to tweak your system. Here are the options you can adjust:
- ***Output directory for images***
This is the path to a directory in which InvokeAI will store all its
generated images.
- ***NSFW checker***
If checked, InvokeAI will test images for potential sexual content
and blur them out if found. Note that the NSFW checker consumes
an additional 0.6 GB of VRAM on top of the 2-3 GB of VRAM used
by most image models. If you have a low VRAM GPU (4-6 GB), you
can reduce out of memory errors by disabling the checker.
- ***HuggingFace Access Token*** - ***HuggingFace Access Token***
InvokeAI has the ability to download embedded styles and subjects InvokeAI has the ability to download embedded styles and subjects
from the HuggingFace Concept Library on-demand. However, some of from the HuggingFace Concept Library on-demand. However, some of
@ -257,20 +244,30 @@ experimental versions later.
and graphics cards. The "autocast" option is deprecated and and graphics cards. The "autocast" option is deprecated and
shouldn't be used unless you are asked to by a member of the team. shouldn't be used unless you are asked to by a member of the team.
- ***Number of models to cache in CPU memory*** - **Size of the RAM cache used for fast model switching***
This allows you to keep models in memory and switch rapidly among This allows you to keep models in memory and switch rapidly among
them rather than having them load from disk each time. This slider them rather than having them load from disk each time. This slider
controls how many models to keep loaded at once. Each controls how many models to keep loaded at once. A typical SD-1 or SD-2 model
model will use 2-4 GB of RAM, so use this cautiously uses 2-3 GB of memory. A typical SDXL model uses 6-7 GB. Providing more
RAM will allow more models to be co-resident.
- ***Directory containing embedding/textual inversion files*** - ***Output directory for images***
This is the directory in which you can place custom embedding This is the path to a directory in which InvokeAI will store all its
files (.pt or .bin). During startup, this directory will be generated images.
scanned and InvokeAI will print out the text terms that
are available to trigger the embeddings. - ***Autoimport Folder***
This is the directory in which you can place models you have
downloaded and wish to load into InvokeAI. You can place a variety
of models in this directory, including diffusers folders, .ckpt files,
.safetensors files, as well as LoRAs, ControlNet and Textual Inversion
files (both folder and file versions). To help organize this folder,
you can create several levels of subfolders and drop your models into
whichever ones you want.
- ***Autoimport FolderLICENSE***
At the bottom of the screen you will see a checkbox for accepting At the bottom of the screen you will see a checkbox for accepting
the CreativeML Responsible AI License. You need to accept the license the CreativeML Responsible AI Licenses. You need to accept the license
in order to download Stable Diffusion models from the next screen. in order to download Stable Diffusion models from the next screen.
_You can come back to the startup options form_ as many times as you like. _You can come back to the startup options form_ as many times as you like.
@ -375,8 +372,71 @@ experimental versions later.
Once InvokeAI is installed, do not move or remove this directory." Once InvokeAI is installed, do not move or remove this directory."
<a name="troubleshooting"></a>
## Troubleshooting ## Troubleshooting
### _OSErrors on Windows while installing dependencies_
During a zip file installation or an online update, installation stops
with an error like this:
![broken-dependency-screenshot](../assets/troubleshooting/broken-dependency.png){:width="800px"}
This seems to happen particularly often with the `pydantic` and
`numpy` packages. The most reliable solution requires several manual
steps to complete installation.
Open up a Powershell window and navigate to the `invokeai` directory
created by the installer. Then give the following series of commands:
```cmd
rm .\.venv -r -force
python -mvenv .venv
.\.venv\Scripts\activate
pip install invokeai
invokeai-configure --yes --root .
```
If you see anything marked as an error during this process please stop
and seek help on the Discord [installation support
channel](https://discord.com/channels/1020123559063990373/1041391462190956654). A
few warning messages are OK.
If you are updating from a previous version, this should restore your
system to a working state. If you are installing from scratch, there
is one additional command to give:
```cmd
wget -O invoke.bat https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/installer/templates/invoke.bat.in
```
This will create the `invoke.bat` script needed to launch InvokeAI and
its related programs.
### _Stable Diffusion XL Generation Fails after Trying to Load unet_
InvokeAI is working in other respects, but when trying to generate
images with Stable Diffusion XL you get a "Server Error". The text log
in the launch window contains this log line above several more lines of
error messages:
```INFO --> Loading model:D:\LONG\PATH\TO\MODEL, type sdxl:main:unet```
This failure mode occurs when there is a network glitch during
downloading the very large SDXL model.
To address this, first go to the Web Model Manager and delete the
Stable-Diffusion-XL-base-1.X model. Then navigate to HuggingFace and
manually download the .safetensors version of the model. The 1.0
version is located at
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main
and the file is named `sd_xl_base_1.0.safetensors`.
Save this file to disk and then reenter the Model Manager. Navigate to
Import Models->Add Model, then type (or drag-and-drop) the path to the
.safetensors file. Press "Add Model".
### _Package dependency conflicts_ ### _Package dependency conflicts_
If you have previously installed InvokeAI or another Stable Diffusion If you have previously installed InvokeAI or another Stable Diffusion

View File

@ -32,7 +32,7 @@ gaming):
* **Python** * **Python**
version 3.9 or 3.10 (3.11 is not recommended). version 3.9 through 3.11
* **CUDA Tools** * **CUDA Tools**
@ -65,7 +65,7 @@ gaming):
To install InvokeAI with virtual environments and the PIP package To install InvokeAI with virtual environments and the PIP package
manager, please follow these steps: manager, please follow these steps:
1. Please make sure you are using Python 3.9 or 3.10. The rest of the install 1. Please make sure you are using Python 3.9 through 3.11. The rest of the install
procedure depends on this and will not work with other versions: procedure depends on this and will not work with other versions:
```bash ```bash

View File

@ -14,20 +14,25 @@ The nodes linked below have been developed and contributed by members of the Inv
## List of Nodes ## List of Nodes
### Face Mask ### FaceTools
**Description:** This node autodetects a face in the image using MediaPipe and masks it by making it transparent. Via outpainting you can swap faces with other faces, or invert the mask and swap things around the face with other things. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control. The node also outputs an all-white mask in the same dimensions as the input image. This is needed by the inpaint node (and unified canvas) for outpainting. **Description:** FaceTools is a collection of nodes created to manipulate faces as you would in Unified Canvas. It includes FaceMask, FaceOff, and FacePlace. FaceMask autodetects a face in the image using MediaPipe and creates a mask from it. FaceOff similarly detects a face, then takes the face off of the image by adding a square bounding box around it and cropping/scaling it. FacePlace puts the bounded face image from FaceOff back onto the original image. Using these nodes with other inpainting node(s), you can put new faces on existing things, put new things around existing faces, and work closer with a face as a bounded image. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control on FaceMask and FaceOff. See GitHub repository below for usage examples.
**Node Link:** https://github.com/ymgenesis/InvokeAI/blob/facemaskmediapipe/invokeai/app/invocations/facemask.py **Node Link:** https://github.com/ymgenesis/FaceTools/
**Example Node Graph:** https://www.mediafire.com/file/gohn5sb1bfp8use/21-July_2023-FaceMask.json/file **FaceMask Output Examples**
**Output Examples** ![5cc8abce-53b0-487a-b891-3bf94dcc8960](https://github.com/invoke-ai/InvokeAI/assets/25252829/43f36d24-1429-4ab1-bd06-a4bedfe0955e)
![b920b710-1882-49a0-8d02-82dff2cca907](https://github.com/invoke-ai/InvokeAI/assets/25252829/7660c1ed-bf7d-4d0a-947f-1fc1679557ba)
![71a91805-fda5-481c-b380-264665703133](https://github.com/invoke-ai/InvokeAI/assets/25252829/f8f6a2ee-2b68-4482-87da-b90221d5c3e2)
![2e3168cb-af6a-475d-bfac-c7b7fd67b4c2](https://github.com/ymgenesis/InvokeAI/assets/25252829/a5ad7d44-2ada-4b3c-a56e-a21f8244a1ac) <hr>
![2_annotated](https://github.com/ymgenesis/InvokeAI/assets/25252829/53416c8a-a23b-4d76-bb6d-3cfd776e0096)
![2fe2150c-fd08-4e26-8c36-f0610bf441bb](https://github.com/ymgenesis/InvokeAI/assets/25252829/b0f7ecfe-f093-4147-a904-b9f131b41dc9) ### Ideal Size
![831b6b98-4f0f-4360-93c8-69a9c1338cbe](https://github.com/ymgenesis/InvokeAI/assets/25252829/fc7b0622-e361-4155-8a76-082894d084f0)
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
**Node Link:** https://github.com/JPPhoto/ideal-size-node
-------------------------------- --------------------------------
### Super Cool Node Template ### Super Cool Node Template
@ -42,11 +47,5 @@ The nodes linked below have been developed and contributed by members of the Inv
![Invoke AI](https://invoke-ai.github.io/InvokeAI/assets/invoke_ai_banner.png) ![Invoke AI](https://invoke-ai.github.io/InvokeAI/assets/invoke_ai_banner.png)
### Ideal Size
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
**Node Link:** https://github.com/JPPhoto/ideal-size-node
## Help ## Help
If you run into any issues with a node, please post in the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy). If you run into any issues with a node, please post in the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).

25
flake.lock Normal file
View File

@ -0,0 +1,25 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1690630721,
"narHash": "sha256-Y04onHyBQT4Erfr2fc82dbJTfXGYrf4V0ysLUYnPOP8=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "d2b52322f35597c62abf56de91b0236746b2a03d",
"type": "github"
},
"original": {
"id": "nixpkgs",
"type": "indirect"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs"
}
}
},
"root": "root",
"version": 7
}

81
flake.nix Normal file
View File

@ -0,0 +1,81 @@
# Important note: this flake does not attempt to create a fully isolated, 'pure'
# Python environment for InvokeAI. Instead, it depends on local invocations of
# virtualenv/pip to install the required (binary) packages, most importantly the
# prebuilt binary pytorch packages with CUDA support.
# ML Python packages with CUDA support, like pytorch, are notoriously expensive
# to compile so it's purposefuly not what this flake does.
{
description = "An (impure) flake to develop on InvokeAI.";
outputs = { self, nixpkgs }:
let
system = "x86_64-linux";
pkgs = import nixpkgs {
inherit system;
config.allowUnfree = true;
};
python = pkgs.python310;
mkShell = { dir, install }:
let
setupScript = pkgs.writeScript "setup-invokai" ''
# This must be sourced using 'source', not executed.
${python}/bin/python -m venv ${dir}
${dir}/bin/python -m pip install ${install}
# ${dir}/bin/python -c 'import torch; assert(torch.cuda.is_available())'
source ${dir}/bin/activate
'';
in
pkgs.mkShell rec {
buildInputs = with pkgs; [
# Backend: graphics, CUDA.
cudaPackages.cudnn
cudaPackages.cuda_nvrtc
cudatoolkit
freeglut
glib
gperf
procps
libGL
libGLU
linuxPackages.nvidia_x11
python
stdenv.cc
stdenv.cc.cc.lib
xorg.libX11
xorg.libXext
xorg.libXi
xorg.libXmu
xorg.libXrandr
xorg.libXv
zlib
# Pre-commit hooks.
black
# Frontend.
yarn
nodejs
];
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs;
CUDA_PATH = pkgs.cudatoolkit;
EXTRA_LDFLAGS = "-L${pkgs.linuxPackages.nvidia_x11}/lib";
shellHook = ''
if [[ -f "${dir}/bin/activate" ]]; then
source "${dir}/bin/activate"
echo "Using Python: $(which python)"
else
echo "Use 'source ${setupScript}' to set up the environment."
fi
'';
};
in
{
devShells.${system} = rec {
develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; };
default = develop;
};
};
}

View File

@ -9,13 +9,17 @@ cd $scriptdir
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; } function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
MINIMUM_PYTHON_VERSION=3.9.0 MINIMUM_PYTHON_VERSION=3.9.0
MAXIMUM_PYTHON_VERSION=3.11.0 MAXIMUM_PYTHON_VERSION=3.11.100
PYTHON="" PYTHON=""
for candidate in python3.10 python3.9 python3 python ; do for candidate in python3.11 python3.10 python3.9 python3 python ; do
if ppath=`which $candidate`; then if ppath=`which $candidate`; then
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
# we check that this found executable can actually run
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
python_version=$($ppath -V | awk '{ print $2 }') python_version=$($ppath -V | awk '{ print $2 }')
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
PYTHON=$ppath PYTHON=$ppath
break break
fi fi

View File

@ -13,7 +13,7 @@ from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Union from typing import Union
SUPPORTED_PYTHON = ">=3.9.0,<3.11" SUPPORTED_PYTHON = ">=3.9.0,<=3.11.100"
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"] INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp" BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
@ -141,7 +141,6 @@ class Installer:
# upgrade pip in Python 3.9 environments # upgrade pip in Python 3.9 environments
if int(platform.python_version_tuple()[1]) == 9: if int(platform.python_version_tuple()[1]) == 9:
from plumbum import FG, local from plumbum import FG, local
pip = local[get_pip_from_venv(venv_dir)] pip = local[get_pip_from_venv(venv_dir)]
@ -149,7 +148,9 @@ class Installer:
return venv_dir return venv_dir
def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None: def install(
self, root: str = "~/invokeai", version: str = "latest", yes_to_all=False, find_links: Path = None
) -> None:
""" """
Install the InvokeAI application into the given runtime path Install the InvokeAI application into the given runtime path
@ -167,7 +168,8 @@ class Installer:
messages.welcome() messages.welcome()
self.dest = Path(root).expanduser().resolve() if yes_to_all else messages.dest_path(root) default_path = os.environ.get("INVOKEAI_ROOT") or Path(root).expanduser().resolve()
self.dest = default_path if yes_to_all else messages.dest_path(root)
# create the venv for the app # create the venv for the app
self.venv = self.app_venv() self.venv = self.app_venv()
@ -188,6 +190,7 @@ class Installer:
# run through the configuration flow # run through the configuration flow
self.instance.configure() self.instance.configure()
class InvokeAiInstance: class InvokeAiInstance:
""" """
Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory. Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
@ -196,7 +199,6 @@ class InvokeAiInstance:
""" """
def __init__(self, runtime: Path, venv: Path, version: str) -> None: def __init__(self, runtime: Path, venv: Path, version: str) -> None:
self.runtime = runtime self.runtime = runtime
self.venv = venv self.venv = venv
self.pip = get_pip_from_venv(venv) self.pip = get_pip_from_venv(venv)
@ -247,6 +249,9 @@ class InvokeAiInstance:
pip[ pip[
"install", "install",
"--require-virtualenv", "--require-virtualenv",
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch~=2.0.0", "torch~=2.0.0",
"torchmetrics==0.11.4", "torchmetrics==0.11.4",
"torchvision>=0.14.1", "torchvision>=0.14.1",
@ -312,7 +317,7 @@ class InvokeAiInstance:
"install", "install",
"--require-virtualenv", "--require-virtualenv",
"--use-pep517", "--use-pep517",
str(src)+(optional_modules if optional_modules else ''), str(src) + (optional_modules if optional_modules else ""),
"--find-links" if find_links is not None else None, "--find-links" if find_links is not None else None,
find_links, find_links,
"--extra-index-url" if extra_index_url is not None else None, "--extra-index-url" if extra_index_url is not None else None,
@ -331,10 +336,10 @@ class InvokeAiInstance:
new_argv = [sys.argv[0]] new_argv = [sys.argv[0]]
for i in range(1, len(sys.argv)): for i in range(1, len(sys.argv)):
el = sys.argv[i] el = sys.argv[i]
if el in ['-r','--root']: if el in ["-r", "--root"]:
new_argv.append(el) new_argv.append(el)
new_argv.append(sys.argv[i + 1]) new_argv.append(sys.argv[i + 1])
elif el in ['-y','--yes','--yes-to-all']: elif el in ["-y", "--yes", "--yes-to-all"]:
new_argv.append(el) new_argv.append(el)
sys.argv = new_argv sys.argv = new_argv
@ -353,16 +358,16 @@ class InvokeAiInstance:
invokeai_configure() invokeai_configure()
succeeded = True succeeded = True
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
print(f'\nA network error was encountered during configuration and download: {str(e)}') print(f"\nA network error was encountered during configuration and download: {str(e)}")
except OSError as e: except OSError as e:
print(f'\nAn OS error was encountered during configuration and download: {str(e)}') print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
except Exception as e: except Exception as e:
print(f'\nA problem was encountered during the configuration and download steps: {str(e)}') print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
finally: finally:
if not succeeded: if not succeeded:
print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"') print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
print('and choose option 7 to fix a broken install, optionally followed by option 5 to install models.') print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
print('Alternatively you can relaunch the installer.') print("Alternatively you can relaunch the installer.")
def install_user_scripts(self): def install_user_scripts(self):
""" """
@ -372,10 +377,10 @@ class InvokeAiInstance:
ext = "bat" if OS == "Windows" else "sh" ext = "bat" if OS == "Windows" else "sh"
# scripts = ['invoke', 'update'] # scripts = ['invoke', 'update']
scripts = ['invoke'] scripts = ["invoke"]
for script in scripts: for script in scripts:
src = Path(__file__).parent / '..' / "templates" / f"{script}.{ext}.in" src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
dest = self.runtime / f"{script}.{ext}" dest = self.runtime / f"{script}.{ext}"
shutil.copy(src, dest) shutil.copy(src, dest)
os.chmod(dest, 0o0755) os.chmod(dest, 0o0755)
@ -420,11 +425,7 @@ def set_sys_path(venv_path: Path) -> None:
# filter out any paths in sys.path that may be system- or user-wide # filter out any paths in sys.path that may be system- or user-wide
# but leave the temporary bootstrap virtualenv as it contains packages we # but leave the temporary bootstrap virtualenv as it contains packages we
# temporarily need at install time # temporarily need at install time
sys.path = list(filter( sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))
lambda p: not p.endswith("-packages")
or p.find(BOOTSTRAP_VENV_PREFIX) != -1,
sys.path
))
# determine site-packages/lib directory location for the venv # determine site-packages/lib directory location for the venv
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}" lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
@ -461,9 +462,9 @@ def get_torch_source() -> (Union[str, None],str):
elif device == "cpu": elif device == "cpu":
url = "https://download.pytorch.org/whl/cpu" url = "https://download.pytorch.org/whl/cpu"
if device == 'cuda': if device == "cuda":
url = 'https://download.pytorch.org/whl/cu117' url = "https://download.pytorch.org/whl/cu117"
optional_modules = '[xformers]' optional_modules = "[xformers]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13 # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -3,6 +3,7 @@ InvokeAI Installer
""" """
import argparse import argparse
import os
from pathlib import Path from pathlib import Path
from installer import Installer from installer import Installer
@ -15,7 +16,7 @@ if __name__ == "__main__":
dest="root", dest="root",
type=str, type=str,
help="Destination path for installation", help="Destination path for installation",
default="~/invokeai", default=os.environ.get("INVOKEAI_ROOT") or "~/invokeai",
) )
parser.add_argument( parser.add_argument(
"-y", "-y",

View File

@ -36,13 +36,15 @@ else:
def welcome(): def welcome():
@group() @group()
def text(): def text():
if (platform_specific := _platform_specific_help()) != "": if (platform_specific := _platform_specific_help()) != "":
yield platform_specific yield platform_specific
yield "" yield ""
yield Text.from_markup("Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.", justify="center") yield Text.from_markup(
"Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
justify="center",
)
console.rule() console.rule()
print( print(
@ -58,6 +60,7 @@ def welcome():
) )
console.line() console.line()
def confirm_install(dest: Path) -> bool: def confirm_install(dest: Path) -> bool:
if dest.exists(): if dest.exists():
print(f":exclamation: Directory {dest} already exists :exclamation:") print(f":exclamation: Directory {dest} already exists :exclamation:")
@ -92,7 +95,6 @@ def dest_path(dest=None) -> Path:
dest_confirmed = confirm_install(dest) dest_confirmed = confirm_install(dest)
while not dest_confirmed: while not dest_confirmed:
# if the given destination already exists, the starting point for browsing is its parent directory. # if the given destination already exists, the starting point for browsing is its parent directory.
# the user may have made a typo, or otherwise wants to place the root dir next to an existing one. # the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
# if the destination dir does NOT exist, then the user must have changed their mind about the selection. # if the destination dir does NOT exist, then the user must have changed their mind about the selection.
@ -300,15 +302,20 @@ def introduction() -> None:
) )
console.line(2) console.line(2)
def _platform_specific_help() -> str: def _platform_specific_help() -> str:
if OS == "Darwin": if OS == "Darwin":
text = Text.from_markup("""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/].""") text = Text.from_markup(
"""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
)
elif OS == "Windows": elif OS == "Windows":
text = Text.from_markup("""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following: text = Text.from_markup(
"""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to 1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
enable long path support on your system. enable long path support on your system.
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from 2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]""") [deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
)
else: else:
text = "" text = ""
return text return text

View File

@ -41,7 +41,7 @@ IF /I "%choice%" == "1" (
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
) ELSE IF /I "%choice%" == "7" ( ) ELSE IF /I "%choice%" == "7" (
echo Running invokeai-configure... echo Running invokeai-configure...
python .venv\Scripts\invokeai-configure.exe --yes --default_only python .venv\Scripts\invokeai-configure.exe --yes --skip-sd-weight
) ELSE IF /I "%choice%" == "8" ( ) ELSE IF /I "%choice%" == "8" (
echo Developer Console echo Developer Console
echo Python command is: echo Python command is:

View File

@ -82,7 +82,7 @@ do_choice() {
7) 7)
clear clear
printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n" printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only --skip-sd-weights
;; ;;
8) 8)
clear clear

View File

@ -78,9 +78,7 @@ class ApiDependencies:
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
latents = ForwardCacheLatentsStorage( latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
DiskLatentsStorage(f"{output_folder}/latents")
)
board_record_storage = SqliteBoardRecordStorage(db_location) board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location) board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
@ -125,9 +123,7 @@ class ApiDependencies:
boards=boards, boards=boards,
board_images=board_images, board_images=board_images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph]( graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
filename=db_location, table_name="graphs"
),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
configuration=config, configuration=config,

View File

@ -1,14 +1,21 @@
import typing
from enum import Enum from enum import Enum
from fastapi import Body from fastapi import Body
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pathlib import Path
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.version import __version__ from invokeai.version import __version__
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend.util.logging import logging from invokeai.backend.util.logging import logging
class LogLevel(int, Enum): class LogLevel(int, Enum):
NotSet = logging.NOTSET NotSet = logging.NOTSET
Debug = logging.DEBUG Debug = logging.DEBUG
@ -17,6 +24,12 @@ class LogLevel(int, Enum):
Error = logging.ERROR Error = logging.ERROR
Critical = logging.CRITICAL Critical = logging.CRITICAL
class Upscaler(BaseModel):
upscaling_method: str = Field(description="Name of upscaling method")
upscaling_models: list[str] = Field(description="List of upscaling models for this method")
app_router = APIRouter(prefix="/v1/app", tags=["app"]) app_router = APIRouter(prefix="/v1/app", tags=["app"])
@ -30,23 +43,42 @@ class AppConfig(BaseModel):
"""App Config Response""" """App Config Response"""
infill_methods: list[str] = Field(description="List of available infill methods") infill_methods: list[str] = Field(description="List of available infill methods")
upscaling_methods: list[Upscaler] = Field(description="List of upscaling methods")
nsfw_methods: list[str] = Field(description="List of NSFW checking methods")
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
@app_router.get( @app_router.get("/version", operation_id="app_version", status_code=200, response_model=AppVersion)
"/version", operation_id="app_version", status_code=200, response_model=AppVersion
)
async def get_version() -> AppVersion: async def get_version() -> AppVersion:
return AppVersion(version=__version__) return AppVersion(version=__version__)
@app_router.get( @app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
"/config", operation_id="get_config", status_code=200, response_model=AppConfig
)
async def get_config() -> AppConfig: async def get_config() -> AppConfig:
infill_methods = ['tile'] infill_methods = ["tile"]
if PatchMatch.patchmatch_available(): if PatchMatch.patchmatch_available():
infill_methods.append('patchmatch') infill_methods.append("patchmatch")
return AppConfig(infill_methods=infill_methods)
upscaling_models = []
for model in typing.get_args(ESRGAN_MODELS):
upscaling_models.append(str(Path(model).stem))
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
nsfw_methods = []
if SafetyChecker.safety_checker_available():
nsfw_methods.append("nsfw_checker")
watermarking_methods = []
if InvisibleWatermark.invisible_watermark_available():
watermarking_methods.append("invisible_watermark")
return AppConfig(
infill_methods=infill_methods,
upscaling_methods=[upscaler],
nsfw_methods=nsfw_methods,
watermarking_methods=watermarking_methods,
)
@app_router.get( @app_router.get(
"/logging", "/logging",
@ -54,11 +86,11 @@ async def get_config() -> AppConfig:
responses={200: {"description": "The operation was successful"}}, responses={200: {"description": "The operation was successful"}},
response_model=LogLevel, response_model=LogLevel,
) )
async def get_log_level( async def get_log_level() -> LogLevel:
) -> LogLevel:
"""Returns the log level""" """Returns the log level"""
return LogLevel(ApiDependencies.invoker.services.logger.level) return LogLevel(ApiDependencies.invoker.services.logger.level)
@app_router.post( @app_router.post(
"/logging", "/logging",
operation_id="set_log_level", operation_id="set_log_level",

View File

@ -52,4 +52,3 @@ async def remove_board_image(
return result return result
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board") raise HTTPException(status_code=500, detail="Failed to update board")

View File

@ -18,9 +18,7 @@ class DeleteBoardResult(BaseModel):
deleted_board_images: list[str] = Field( deleted_board_images: list[str] = Field(
description="The image names of the board-images relationships that were deleted." description="The image names of the board-images relationships that were deleted."
) )
deleted_images: list[str] = Field( deleted_images: list[str] = Field(description="The names of the images that were deleted.")
description="The names of the images that were deleted."
)
@boards_router.post( @boards_router.post(
@ -73,22 +71,16 @@ async def update_board(
) -> BoardDTO: ) -> BoardDTO:
"""Updates a board""" """Updates a board"""
try: try:
result = ApiDependencies.invoker.services.boards.update( result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
board_id=board_id, changes=changes
)
return result return result
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board") raise HTTPException(status_code=500, detail="Failed to update board")
@boards_router.delete( @boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
"/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult
)
async def delete_board( async def delete_board(
board_id: str = Path(description="The id of board to delete"), board_id: str = Path(description="The id of board to delete"),
include_images: Optional[bool] = Query( include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
description="Permanently delete all images on the board", default=False
),
) -> DeleteBoardResult: ) -> DeleteBoardResult:
"""Deletes a board""" """Deletes a board"""
try: try:
@ -96,9 +88,7 @@ async def delete_board(
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id=board_id board_id=board_id
) )
ApiDependencies.invoker.services.images.delete_images_on_board( ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
board_id=board_id
)
ApiDependencies.invoker.services.boards.delete(board_id=board_id) ApiDependencies.invoker.services.boards.delete(board_id=board_id)
return DeleteBoardResult( return DeleteBoardResult(
board_id=board_id, board_id=board_id,
@ -127,9 +117,7 @@ async def delete_board(
async def list_boards( async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"), all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"), offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query( limit: Optional[int] = Query(default=None, description="The number of boards per page"),
default=None, description="The number of boards per page"
),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]: ) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards""" """Gets a list of boards"""
if all: if all:

View File

@ -40,15 +40,9 @@ async def upload_image(
response: Response, response: Response,
image_category: ImageCategory = Query(description="The category of the image"), image_category: ImageCategory = Query(description="The category of the image"),
is_intermediate: bool = Query(description="Whether this is an intermediate image"), is_intermediate: bool = Query(description="Whether this is an intermediate image"),
board_id: Optional[str] = Query( board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
default=None, description="The board to add this image to, if any" session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
), crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any"
),
crop_visible: Optional[bool] = Query(
default=False, description="Whether to crop the image"
),
) -> ImageDTO: ) -> ImageDTO:
"""Uploads an image""" """Uploads an image"""
if not file.content_type.startswith("image"): if not file.content_type.startswith("image"):
@ -115,9 +109,7 @@ async def clear_intermediates() -> int:
) )
async def update_image( async def update_image(
image_name: str = Path(description="The name of the image to update"), image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body( image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
description="The changes to apply to the image"
),
) -> ImageDTO: ) -> ImageDTO:
"""Updates an image""" """Updates an image"""
@ -212,15 +204,11 @@ async def get_image_thumbnail(
"""Gets a thumbnail image file""" """Gets a thumbnail image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path( path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
response = FileResponse( response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
path, media_type="image/webp", content_disposition_type="inline"
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response return response
except Exception as e: except Exception as e:
@ -239,9 +227,7 @@ async def get_image_urls(
try: try:
image_url = ApiDependencies.invoker.services.images.get_url(image_name) image_url = ApiDependencies.invoker.services.images.get_url(image_name)
thumbnail_url = ApiDependencies.invoker.services.images.get_url( thumbnail_url = ApiDependencies.invoker.services.images.get_url(image_name, thumbnail=True)
image_name, thumbnail=True
)
return ImageUrlsDTO( return ImageUrlsDTO(
image_name=image_name, image_name=image_name,
image_url=image_url, image_url=image_url,
@ -257,15 +243,9 @@ async def get_image_urls(
response_model=OffsetPaginatedResults[ImageDTO], response_model=OffsetPaginatedResults[ImageDTO],
) )
async def list_image_dtos( async def list_image_dtos(
image_origin: Optional[ResourceOrigin] = Query( image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
default=None, description="The origin of images to list." categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
), is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
categories: Optional[list[ImageCategory]] = Query(
default=None, description="The categories of image to include."
),
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images."
),
board_id: Optional[str] = Query( board_id: Optional[str] = Query(
default=None, default=None,
description="The board id to filter by. Use 'none' to find images without a board.", description="The board id to filter by. Use 'none' to find images without a board.",

View File

@ -28,9 +28,11 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@models_router.get( @models_router.get(
"/", "/",
operation_id="list_models", operation_id="list_models",
@ -50,10 +52,12 @@ async def list_models(
models = parse_obj_as(ModelsList, {"models": models_raw}) models = parse_obj_as(ModelsList, {"models": models_raw})
return models return models
@models_router.patch( @models_router.patch(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="update_model", operation_id="update_model",
responses={200: {"description" : "The model was updated successfully"}, responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
404: {"description": "The model could not be found"}, 404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"}, 409: {"description": "There is already a model corresponding to the new name"},
@ -70,7 +74,6 @@ async def update_model(
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model( previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name, model_name=model_name,
@ -87,7 +90,7 @@ async def update_model(
new_name=info.model_name, new_name=info.model_name,
new_base=info.base_model, new_base=info.base_model,
) )
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}') logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
# update information to support an update of attributes # update information to support an update of attributes
model_name = info.model_name model_name = info.model_name
base_model = info.base_model base_model = info.base_model
@ -96,14 +99,13 @@ async def update_model(
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
) )
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it if new_info.get("path") != previous_info.get(
info.path = new_info.get('path') "path"
): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get("path")
ApiDependencies.invoker.services.model_manager.update_model( ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
base_model=base_model,
model_type=model_type,
model_attributes=info.dict()
) )
model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@ -123,6 +125,7 @@ async def update_model(
return model_response return model_response
@models_router.post( @models_router.post(
"/import", "/import",
operation_id="import_model", operation_id="import_model",
@ -134,12 +137,13 @@ async def update_model(
409: {"description": "There is already a model corresponding to this path or repo_id"}, 409: {"description": "There is already a model corresponding to this path or repo_id"},
}, },
status_code=201, status_code=201,
response_model=ImportModelResponse response_model=ImportModelResponse,
) )
async def import_model( async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"), location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), description="Prediction type for SDv2 checkpoint files", default="v_prediction"
),
) -> ImportModelResponse: ) -> ImportModelResponse:
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically""" """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
@ -149,8 +153,7 @@ async def import_model(
try: try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import, items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
) )
info = installed_models.get(location) info = installed_models.get(location)
@ -158,11 +161,9 @@ async def import_model(
logger.error("Import failed") logger.error("Import failed")
raise HTTPException(status_code=415) raise HTTPException(status_code=415)
logger.info(f'Successfully imported {location}, got {info}') logger.info(f"Successfully imported {location}, got {info}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name, model_name=info.name, base_model=info.base_model, model_type=info.model_type
base_model=info.base_model,
model_type=info.model_type
) )
return parse_obj_as(ImportModelResponse, model_raw) return parse_obj_as(ImportModelResponse, model_raw)
@ -176,6 +177,7 @@ async def import_model(
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
@models_router.post( @models_router.post(
"/add", "/add",
operation_id="add_model", operation_id="add_model",
@ -186,7 +188,7 @@ async def import_model(
409: {"description": "There is already a model corresponding to this path or repo_id"}, 409: {"description": "There is already a model corresponding to this path or repo_id"},
}, },
status_code=201, status_code=201,
response_model=ImportModelResponse response_model=ImportModelResponse,
) )
async def add_model( async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
@ -197,16 +199,11 @@ async def add_model(
try: try:
ApiDependencies.invoker.services.model_manager.add_model( ApiDependencies.invoker.services.model_manager.add_model(
info.model_name, info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
info.base_model,
info.model_type,
model_attributes = info.dict()
) )
logger.info(f'Successfully added {info.model_name}') logger.info(f"Successfully added {info.model_name}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name, model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
base_model=info.base_model,
model_type=info.model_type
) )
return parse_obj_as(ImportModelResponse, model_raw) return parse_obj_as(ImportModelResponse, model_raw)
except ModelNotFoundException as e: except ModelNotFoundException as e:
@ -220,10 +217,7 @@ async def add_model(
@models_router.delete( @models_router.delete(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="del_model", operation_id="del_model",
responses={ responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
204: { "description": "Model deleted successfully" },
404: { "description": "Model not found" }
},
status_code=204, status_code=204,
response_model=None, response_model=None,
) )
@ -236,9 +230,8 @@ async def delete_model(
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
ApiDependencies.invoker.services.model_manager.del_model(model_name, ApiDependencies.invoker.services.model_manager.del_model(
base_model = base_model, model_name, base_model=base_model, model_type=model_type
model_type = model_type
) )
logger.info(f"Deleted model: {model_name}") logger.info(f"Deleted model: {model_name}")
return Response(status_code=204) return Response(status_code=204)
@ -246,6 +239,7 @@ async def delete_model(
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@models_router.put( @models_router.put(
"/convert/{base_model}/{model_type}/{model_name}", "/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model", operation_id="convert_model",
@ -261,21 +255,24 @@ async def convert_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"), model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"), model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"), convert_dest_directory: Optional[str] = Query(
default=None, description="Save the converted model to the designated directory"
),
) -> ConvertModelResponse: ) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
logger.info(f"Converting model: {model_name}") logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(model_name, ApiDependencies.invoker.services.model_manager.convert_model(
model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
convert_dest_directory=dest, convert_dest_directory=dest,
) )
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, model_raw = ApiDependencies.invoker.services.model_manager.list_model(
base_model = base_model, model_name, base_model=base_model, model_type=model_type
model_type = model_type) )
response = parse_obj_as(ConvertModelResponse, model_raw) response = parse_obj_as(ConvertModelResponse, model_raw)
except ModelNotFoundException as e: except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
@ -283,6 +280,7 @@ async def convert_model(
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
@models_router.get( @models_router.get(
"/search", "/search",
operation_id="search_for_models", operation_id="search_for_models",
@ -291,15 +289,18 @@ async def convert_model(
404: {"description": "Invalid directory path"}, 404: {"description": "Invalid directory path"},
}, },
status_code=200, status_code=200,
response_model = List[pathlib.Path] response_model=List[pathlib.Path],
) )
async def search_for_models( async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models") search_path: pathlib.Path = Query(description="Directory path to search for models"),
) -> List[pathlib.Path]: ) -> List[pathlib.Path]:
if not search_path.is_dir(): if not search_path.is_dir():
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") raise HTTPException(
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
)
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path) return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
@models_router.get( @models_router.get(
"/ckpt_confs", "/ckpt_confs",
operation_id="list_ckpt_configs", operation_id="list_ckpt_configs",
@ -307,10 +308,9 @@ async def search_for_models(
200: {"description": "paths retrieved successfully"}, 200: {"description": "paths retrieved successfully"},
}, },
status_code=200, status_code=200,
response_model = List[pathlib.Path] response_model=List[pathlib.Path],
) )
async def list_ckpt_configs( async def list_ckpt_configs() -> List[pathlib.Path]:
)->List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@ -322,15 +322,15 @@ async def list_ckpt_configs(
201: {"description": "synchronization successful"}, 201: {"description": "synchronization successful"},
}, },
status_code=201, status_code=201,
response_model = bool response_model=bool,
) )
async def sync_to_config( async def sync_to_config() -> bool:
)->bool:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize """Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures.""" in-memory data structures with disk data structures."""
ApiDependencies.invoker.services.model_manager.sync_to_config() ApiDependencies.invoker.services.model_manager.sync_to_config()
return True return True
@models_router.put( @models_router.put(
"/merge/{base_model}", "/merge/{base_model}",
operation_id="merge_models", operation_id="merge_models",
@ -348,23 +348,30 @@ async def merge_models(
merged_model_name: Optional[str] = Body(description="Name of destination model"), merged_model_name: Optional[str] = Body(description="Name of destination model"),
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), force: Optional[bool] = Body(
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None) description="Force merging of models created with different versions of diffusers", default=False
),
merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> MergeModelResponse: ) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model""" """Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}") logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, result = ApiDependencies.invoker.services.model_manager.merge_models(
model_names,
base_model, base_model,
merged_model_name=merged_model_name or "+".join(model_names), merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha, alpha=alpha,
interp=interp, interp=interp,
force=force, force=force,
merge_dest_directory = dest merge_dest_directory=dest,
) )
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name, model_raw = ApiDependencies.invoker.services.model_manager.list_model(
result.name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Main, model_type=ModelType.Main,
) )

View File

@ -30,9 +30,7 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
}, },
) )
async def create_session( async def create_session(
graph: Optional[Graph] = Body( graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
default=None, description="The graph to initialize the session with"
)
) -> GraphExecutionState: ) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph""" """Creates a new session, optionally initializing it with an invocation graph"""
@ -56,13 +54,9 @@ async def list_sessions(
) -> PaginatedResults[GraphExecutionState]: ) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching""" """Gets a list of sessions, optionally searching"""
if query == "": if query == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list( result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
page, per_page
)
else: else:
result = ApiDependencies.invoker.services.graph_execution_manager.search( result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
query, page, per_page
)
return result return result
@ -96,9 +90,9 @@ async def get_session(
) )
async def add_node( async def add_node(
session_id: str = Path(description="The id of the session"), session_id: str = Path(description="The id of the session"),
node: Annotated[ node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore description="The node to add"
] = Body(description="The node to add"), ),
) -> str: ) -> str:
"""Adds a node to the graph""" """Adds a node to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@ -129,9 +123,9 @@ async def add_node(
async def update_node( async def update_node(
session_id: str = Path(description="The id of the session"), session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node in the graph"), node_path: str = Path(description="The path to the node in the graph"),
node: Annotated[ node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore description="The new node"
] = Body(description="The new node"), ),
) -> GraphExecutionState: ) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges""" """Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@ -235,7 +229,7 @@ async def delete_edge(
try: try:
edge = Edge( edge = Edge(
source=EdgeConnection(node_id=from_node_id, field=from_field), source=EdgeConnection(node_id=from_node_id, field=from_field),
destination=EdgeConnection(node_id=to_node_id, field=to_field) destination=EdgeConnection(node_id=to_node_id, field=to_field),
) )
session.delete_edge(edge) session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set( ApiDependencies.invoker.services.graph_execution_manager.set(
@ -260,9 +254,7 @@ async def delete_edge(
) )
async def invoke_session( async def invoke_session(
session_id: str = Path(description="The id of the session to invoke"), session_id: str = Path(description="The id of the session to invoke"),
all: bool = Query( all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
default=False, description="Whether or not to invoke all remaining invocations"
),
) -> Response: ) -> Response:
"""Invokes a session""" """Invokes a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@ -279,9 +271,7 @@ async def invoke_session(
@session_router.delete( @session_router.delete(
"/{session_id}/invoke", "/{session_id}/invoke",
operation_id="cancel_session_invoke", operation_id="cancel_session_invoke",
responses={ responses={202: {"description": "The invocation is canceled"}},
202: {"description": "The invocation is canceled"}
},
) )
async def cancel_session_invoke( async def cancel_session_invoke(
session_id: str = Path(description="The id of the session to cancel"), session_id: str = Path(description="The id of the session to cancel"),

View File

@ -16,9 +16,7 @@ class SocketIO:
self.__sio.on("subscribe", handler=self._handle_sub) self.__sio.on("subscribe", handler=self._handle_sub)
self.__sio.on("unsubscribe", handler=self._handle_unsub) self.__sio.on("unsubscribe", handler=self._handle_unsub)
local_handler.register( local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
event_name=EventServiceBase.session_event, _func=self._handle_session_event
)
async def _handle_session_event(self, event: Event): async def _handle_session_event(self, event: Event):
await self.__sio.emit( await self.__sio.emit(

View File

@ -3,6 +3,7 @@ import asyncio
import sys import sys
from inspect import signature from inspect import signature
import logging
import uvicorn import uvicorn
import socket import socket
@ -19,6 +20,7 @@ from pydantic.schema import schema
# This should come early so that modules can log their initialization properly # This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger from ..backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
app_config.parse_args() app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config) logger = InvokeAILogger.getLogger(config=app_config)
@ -27,7 +29,7 @@ from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before # we call this early so that the message appears before
# other invokeai initialization messages # other invokeai initialization messages
if app_config.version: if app_config.version:
print(f'InvokeAI version {__version__}') print(f"InvokeAI version {__version__}")
sys.exit(0) sys.exit(0)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
@ -41,13 +43,14 @@ from .invocations.baseinvocation import BaseInvocation
import torch import torch
import invokeai.backend.util.hotfixes import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes import invokeai.backend.util.mps_fixes
# fix for windows mimetypes registry entries being borked # fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type('application/javascript', '.js') mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type('text/css', '.css') mimetypes.add_type("text/css", ".css")
# Create the app # Create the app
# TODO: create this all in a method so configuration/etc. can be passed in? # TODO: create this all in a method so configuration/etc. can be passed in?
@ -57,14 +60,13 @@ app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
event_handler_id: int = id(app) event_handler_id: int = id(app)
app.add_middleware( app.add_middleware(
EventHandlerASGIMiddleware, EventHandlerASGIMiddleware,
handlers=[ handlers=[local_handler], # TODO: consider doing this in services to support different configurations
local_handler
], # TODO: consider doing this in services to support different configurations
middleware_id=event_handler_id, middleware_id=event_handler_id,
) )
socket_io = SocketIO(app) socket_io = SocketIO(app)
# Add startup event to load dependencies # Add startup event to load dependencies
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
@ -76,9 +78,7 @@ async def startup_event():
allow_headers=app_config.allow_headers, allow_headers=app_config.allow_headers,
) )
ApiDependencies.initialize( ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
config=app_config, event_handler_id=event_handler_id, logger=logger
)
# Shut down threads # Shut down threads
@ -103,7 +103,8 @@ app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api") app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix='/api') app.include_router(app_info.app_router, prefix="/api")
# Build a custom OpenAPI to include all outputs # Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow? # TODO: can outputs be included on metadata of invocation schemas somehow?
@ -144,6 +145,7 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref invoker_schema["output"] = outputs_ref
from invokeai.backend.model_management.models import get_model_config_enums from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()): for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__ name = model_config_format_enum.__qualname__
@ -166,7 +168,8 @@ def custom_openapi():
app.openapi = custom_openapi app.openapi = custom_openapi
# Override API doc favicons # Override API doc favicons
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static") app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
@app.get("/docs", include_in_schema=False) @app.get("/docs", include_in_schema=False)
def overridden_swagger(): def overridden_swagger():
@ -187,11 +190,8 @@ def overridden_redoc():
# Must mount *after* the other routes else it borks em # Must mount *after* the other routes else it borks em
app.mount("/", app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=True), name="ui")
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
html=True
), name="ui"
)
def invoke_api(): def invoke_api():
def find_port(port: int): def find_port(port: int):
@ -204,15 +204,34 @@ def invoke_api():
else: else:
return port return port
from invokeai.backend.install.check_root import check_invokeai_root
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
port = find_port(app_config.port) port = find_port(app_config.port)
if port != app_config.port: if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}") logger.warn(f"Port {app_config.port} in use, using port {port}")
# Start our own event loop for eventing usage # Start our own event loop for eventing usage
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
config = uvicorn.Config(app=app, host=app_config.host, port=port, loop=loop) config = uvicorn.Config(
# Use access_log to turn off logging app=app,
host=app_config.host,
port=port,
loop=loop,
log_level=app_config.log_level,
)
server = uvicorn.Server(config) server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
for logname in ["uvicorn.access", "uvicorn"]:
l = logging.getLogger(logname)
l.handlers.clear()
for ch in logger.handlers:
l.addHandler(ch)
loop.run_until_complete(server.serve()) loop.run_until_complete(server.serve())
if __name__ == "__main__": if __name__ == "__main__":
invoke_api() invoke_api()

View File

Before

Width:  |  Height:  |  Size: 33 KiB

After

Width:  |  Height:  |  Size: 33 KiB

View File

@ -15,7 +15,13 @@ from ..services.invoker import Invoker
def add_field_argument(command_parser, name: str, field, default_override=None): def add_field_argument(command_parser, name: str, field, default_override=None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory() default = (
default_override
if default_override is not None
else field.default
if field.default_factory is None
else field.default_factory()
)
if get_origin(field.type_) == Literal: if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_) allowed_values = get_args(field.type_)
allowed_types = set() allowed_types = set()
@ -47,7 +53,7 @@ def add_parsers(
commands: list[type], commands: list[type],
command_field: str = "type", command_field: str = "type",
exclude_fields: list[str] = ["id", "type"], exclude_fields: list[str] = ["id", "type"],
add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None,
): ):
"""Adds parsers for each command to the subparsers""" """Adds parsers for each command to the subparsers"""
@ -70,9 +76,7 @@ def add_parsers(
def add_graph_parsers( def add_graph_parsers(
subparsers, subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
graphs: list[LibraryGraph],
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
): ):
for graph in graphs: for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description) command_parser = subparsers.add_parser(graph.name, help=graph.description)
@ -128,6 +132,7 @@ class CliContext:
class ExitCli(Exception): class ExitCli(Exception):
"""Exception to exit the CLI""" """Exception to exit the CLI"""
pass pass
@ -155,7 +160,7 @@ class BaseCommand(ABC, BaseModel):
@classmethod @classmethod
def get_commands_map(cls): def get_commands_map(cls):
# Get the type strings out of the literals and into a dictionary # Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses())) return dict(map(lambda t: (get_args(get_type_hints(t)["type"])[0], t), BaseCommand.get_all_subclasses()))
@abstractmethod @abstractmethod
def run(self, context: CliContext) -> None: def run(self, context: CliContext) -> None:
@ -165,7 +170,8 @@ class BaseCommand(ABC, BaseModel):
class ExitCommand(BaseCommand): class ExitCommand(BaseCommand):
"""Exits the CLI""" """Exits the CLI"""
type: Literal['exit'] = 'exit'
type: Literal["exit"] = "exit"
def run(self, context: CliContext) -> None: def run(self, context: CliContext) -> None:
raise ExitCli() raise ExitCli()
@ -173,7 +179,8 @@ class ExitCommand(BaseCommand):
class HelpCommand(BaseCommand): class HelpCommand(BaseCommand):
"""Shows help""" """Shows help"""
type: Literal['help'] = 'help'
type: Literal["help"] = "help"
def run(self, context: CliContext) -> None: def run(self, context: CliContext) -> None:
context.parser.print_help() context.parser.print_help()
@ -183,11 +190,7 @@ def get_graph_execution_history(
graph_execution_state: GraphExecutionState, graph_execution_state: GraphExecutionState,
) -> Iterable[str]: ) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution""" """Gets the history of fully-executed invocations for a graph execution"""
return ( return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
n
for n in reversed(graph_execution_state.executed_history)
if n in graph_execution_state.graph.nodes
)
def get_invocation_command(invocation) -> str: def get_invocation_command(invocation) -> str:
@ -218,7 +221,8 @@ def get_invocation_command(invocation) -> str:
class HistoryCommand(BaseCommand): class HistoryCommand(BaseCommand):
"""Shows the invocation history""" """Shows the invocation history"""
type: Literal['history'] = 'history'
type: Literal["history"] = "history"
# Inputs # Inputs
# fmt: off # fmt: off
@ -235,7 +239,8 @@ class HistoryCommand(BaseCommand):
class SetDefaultCommand(BaseCommand): class SetDefaultCommand(BaseCommand):
"""Sets a default value for a field""" """Sets a default value for a field"""
type: Literal['default'] = 'default'
type: Literal["default"] = "default"
# Inputs # Inputs
# fmt: off # fmt: off
@ -253,7 +258,8 @@ class SetDefaultCommand(BaseCommand):
class DrawGraphCommand(BaseCommand): class DrawGraphCommand(BaseCommand):
"""Debugs a graph""" """Debugs a graph"""
type: Literal['draw_graph'] = 'draw_graph'
type: Literal["draw_graph"] = "draw_graph"
def run(self, context: CliContext) -> None: def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id) session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
@ -271,7 +277,8 @@ class DrawGraphCommand(BaseCommand):
class DrawExecutionGraphCommand(BaseCommand): class DrawExecutionGraphCommand(BaseCommand):
"""Debugs an execution graph""" """Debugs an execution graph"""
type: Literal['draw_xgraph'] = 'draw_xgraph'
type: Literal["draw_xgraph"] = "draw_xgraph"
def run(self, context: CliContext) -> None: def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id) session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
@ -286,6 +293,7 @@ class DrawExecutionGraphCommand(BaseCommand):
plt.axis("off") plt.axis("off")
plt.show() plt.show()
class SortedHelpFormatter(argparse.HelpFormatter): class SortedHelpFormatter(argparse.HelpFormatter):
def _iter_indented_subactions(self, action): def _iter_indented_subactions(self, action):
try: try:

View File

@ -19,8 +19,8 @@ from ..services.invocation_services import InvocationServices
# singleton object, class variable # singleton object, class variable
completer = None completer = None
class Completer(object):
class Completer(object):
def __init__(self, model_manager: ModelManager): def __init__(self, model_manager: ModelManager):
self.commands = self.get_commands() self.commands = self.get_commands()
self.matches = None self.matches = None
@ -78,9 +78,9 @@ class Completer(object):
else: else:
switch = t switch = t
# don't try to autocomplete switches that are already complete # don't try to autocomplete switches that are already complete
if switch and buffer.endswith(' '): if switch and buffer.endswith(" "):
switch = None switch = None
return command or '', switch or '' return command or "", switch or ""
def parse_commands(self) -> Dict[str, List[str]]: def parse_commands(self) -> Dict[str, List[str]]:
""" """
@ -90,7 +90,7 @@ class Completer(object):
result = dict() result = dict()
for command in self.commands: for command in self.commands:
hints = get_type_hints(command) hints = get_type_hints(command)
name = get_args(hints['type'])[0] name = get_args(hints["type"])[0]
result.update({name: hints}) result.update({name: hints})
return result return result
@ -105,15 +105,18 @@ class Completer(object):
# handle switches in the format "-foo=bar" # handle switches in the format "-foo=bar"
argument = None argument = None
if switch and '=' in switch: if switch and "=" in switch:
switch, argument = switch.split('=') switch, argument = switch.split("=")
parameter = switch.strip('-') parameter = switch.strip("-")
if parameter in parsed_commands[command]: if parameter in parsed_commands[command]:
if argument is None: if argument is None:
return self.get_parameter_options(parameter, parsed_commands[command][parameter]) return self.get_parameter_options(parameter, parsed_commands[command][parameter])
else: else:
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])] return [
f"--{parameter}={x}"
for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])
]
else: else:
return [f"--{x}" for x in parsed_commands[command].keys()] return [f"--{x}" for x in parsed_commands[command].keys()]
@ -123,7 +126,7 @@ class Completer(object):
""" """
if get_origin(typehint) == Literal: if get_origin(typehint) == Literal:
return get_args(typehint) return get_args(typehint)
if parameter == 'model': if parameter == "model":
return self.manager.model_names() return self.manager.model_names()
def _pre_input_hook(self): def _pre_input_hook(self):
@ -132,6 +135,7 @@ class Completer(object):
readline.redisplay() readline.redisplay()
self.linebuffer = None self.linebuffer = None
def set_autocompleter(services: InvocationServices) -> Completer: def set_autocompleter(services: InvocationServices) -> Completer:
global completer global completer
@ -162,8 +166,6 @@ def set_autocompleter(services: InvocationServices) -> Completer:
pass pass
except OSError: # file likely corrupted except OSError: # file likely corrupted
newname = f"{histfile}.old" newname = f"{histfile}.old"
logger.error( logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}")
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
histfile.replace(Path(newname)) histfile.replace(Path(newname))
atexit.register(readline.write_history_file, histfile) atexit.register(readline.write_history_file, histfile)

View File

@ -13,6 +13,7 @@ from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options # This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
config.parse_args() config.parse_args()
logger = InvokeAILogger().getLogger(config=config) logger = InvokeAILogger().getLogger(config=config)
@ -20,7 +21,7 @@ from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before other invokeai initialization messages # we call this early so that the message appears before other invokeai initialization messages
if config.version: if config.version:
print(f'InvokeAI version {__version__}') print(f"InvokeAI version {__version__}")
sys.exit(0) sys.exit(0)
from invokeai.app.services.board_image_record_storage import ( from invokeai.app.services.board_image_record_storage import (
@ -36,18 +37,21 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import (default_text_to_image_graph_id, from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
create_system_graphs)
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .cli.commands import (BaseCommand, CliContext, ExitCli, from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
SortedHelpFormatter, add_graph_parsers, add_parsers)
from .cli.completer import set_autocompleter from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase from .services.events import EventServiceBase
from .services.graph import (Edge, EdgeConnection, GraphExecutionState, from .services.graph import (
GraphInvocation, LibraryGraph, Edge,
are_connection_types_compatible) EdgeConnection,
GraphExecutionState,
GraphInvocation,
LibraryGraph,
are_connection_types_compatible,
)
from .services.image_file_storage import DiskImageFileStorage from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices from .services.invocation_services import InvocationServices
@ -58,6 +62,7 @@ from .services.sqlite import SqliteItemStorage
import torch import torch
import invokeai.backend.util.hotfixes import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes import invokeai.backend.util.mps_fixes
@ -69,6 +74,7 @@ class CliCommand(BaseModel):
class InvalidArgs(Exception): class InvalidArgs(Exception):
pass pass
def add_invocation_args(command_parser): def add_invocation_args(command_parser):
# Add linking capability # Add linking capability
command_parser.add_argument( command_parser.add_argument(
@ -113,7 +119,7 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
return parser return parser
class NodeField(): class NodeField:
alias: str alias: str
node_path: str node_path: str
field: str field: str
@ -134,7 +140,12 @@ def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) ->
"""Gets the node field for the specified field alias""" """Gets the node field for the specified field alias"""
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias) exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_input.node_path)) node_type = type(graph.graph.get_node(exposed_input.node_path))
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field]) return NodeField(
alias=exposed_input.alias,
node_path=f"{node_id}.{exposed_input.node_path}",
field=exposed_input.field,
field_type=get_type_hints(node_type)[exposed_input.field],
)
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField: def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
@ -142,7 +153,12 @@ def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias) exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_output.node_path)) node_type = type(graph.graph.get_node(exposed_output.node_path))
node_output_type = node_type.get_output_type() node_output_type = node_type.get_output_type()
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field]) return NodeField(
alias=exposed_output.alias,
node_path=f"{node_id}.{exposed_output.node_path}",
field=exposed_output.field,
field_type=get_type_hints(node_output_type)[exposed_output.field],
)
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]: def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
@ -165,9 +181,7 @@ def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[st
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs} return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
def generate_matching_edges( def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliContext) -> list[Edge]:
a: BaseInvocation, b: BaseInvocation, context: CliContext
) -> list[Edge]:
"""Generates all possible edges between two invocations""" """Generates all possible edges between two invocations"""
afields = get_node_outputs(a, context) afields = get_node_outputs(a, context)
bfields = get_node_inputs(b, context) bfields = get_node_inputs(b, context)
@ -179,12 +193,14 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields) matching_fields = matching_fields.difference(invalid_fields)
# Validate types # Validate types
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)] matching_fields = [
f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)
]
edges = [ edges = [
Edge( Edge(
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field), source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field) destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field),
) )
for alias in matching_fields for alias in matching_fields
] ]
@ -193,6 +209,7 @@ def generate_matching_edges(
class SessionError(Exception): class SessionError(Exception):
"""Raised when a session error has occurred""" """Raised when a session error has occurred"""
pass pass
@ -212,11 +229,12 @@ def invoke_all(context: CliContext):
raise SessionError() raise SessionError()
def invoke_cli(): def invoke_cli():
logger.info(f'InvokeAI version {__version__}') logger.info(f"InvokeAI version {__version__}")
# get the optional list of invocations to execute on the command line # get the optional list of invocations to execute on the command line
parser = config.get_parser() parser = config.get_parser()
parser.add_argument('commands',nargs='*') parser.add_argument("commands", nargs="*")
invocation_commands = parser.parse_args().commands invocation_commands = parser.parse_args().commands
# get the optional file to read commands from. # get the optional file to read commands from.
@ -285,21 +303,18 @@ def invoke_cli():
services = InvocationServices( services = InvocationServices(
model_manager=model_manager, model_manager=model_manager,
events=events, events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')), latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
images=images, images=images,
boards=boards, boards=boards,
board_images=board_images, board_images=board_images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph]( graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
filename=db_location, table_name="graphs"
),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
logger=logger, logger=logger,
configuration=config, configuration=config,
) )
system_graphs = create_system_graphs(services.graph_library) system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs]) system_graph_names = set([g.name for g in system_graphs])
set_autocompleter(services) set_autocompleter(services)
@ -308,7 +323,7 @@ def invoke_cli():
session: GraphExecutionState = invoker.create_execution_state() session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser(services) parser = get_command_parser(services)
re_negid = re.compile('^-[0-9]+$') re_negid = re.compile("^-[0-9]+$")
# Uncomment to print out previous sessions at startup # Uncomment to print out previous sessions at startup
# print(services.session_manager.list()) # print(services.session_manager.list())
@ -355,8 +370,8 @@ def invoke_cli():
# Parse invocation # Parse invocation
command: CliCommand = None # type:ignore command: CliCommand = None # type:ignore
system_graph: Optional[LibraryGraph] = None system_graph: Optional[LibraryGraph] = None
if args['type'] in system_graph_names: if args["type"] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs)) system_graph = next(filter(lambda g: g.name == args["type"], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id)) invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
for exposed_input in system_graph.exposed_inputs: for exposed_input in system_graph.exposed_inputs:
if exposed_input.alias in args: if exposed_input.alias in args:
@ -385,17 +400,13 @@ def invoke_cli():
# Pipe previous command output (if there was a previous command) # Pipe previous command output (if there was a previous command)
edges: list[Edge] = list() edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id: if len(history) > 0 or current_id != start_id:
from_id = ( from_id = history[0] if current_id == start_id else str(current_id - 1)
history[0] if current_id == start_id else str(current_id - 1)
)
from_node = ( from_node = (
next(filter(lambda n: n[0].id == from_id, new_invocations))[0] next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
if current_id != start_id if current_id != start_id
else context.session.graph.get_node(from_id) else context.session.graph.get_node(from_id)
) )
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(from_node, command.command, context)
from_node, command.command, context
)
edges.extend(matching_edges) edges.extend(matching_edges)
# Parse provided links # Parse provided links
@ -406,16 +417,18 @@ def invoke_cli():
node_id = str(current_id + int(node_id)) node_id = str(current_id + int(node_id))
link_node = context.session.graph.get_node(node_id) link_node = context.session.graph.get_node(node_id)
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(link_node, command.command, context)
link_node, command.command, context
)
matching_destinations = [e.destination for e in matching_edges] matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations] edges = [e for e in edges if e.destination not in matching_destinations]
edges.extend(matching_edges) edges.extend(matching_edges)
if "link" in args and args["link"]: if "link" in args and args["link"]:
for link in args["link"]: for link in args["link"]:
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]] edges = [
e
for e in edges
if e.destination.node_id != command.command.id or e.destination.field != link[2]
]
node_id = link[0] node_id = link[0]
if re_negid.match(node_id): if re_negid.match(node_id):
@ -428,7 +441,7 @@ def invoke_cli():
edges.append( edges.append(
Edge( Edge(
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field), source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field) destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field),
) )
) )

View File

@ -4,9 +4,5 @@ __all__ = []
dirname = os.path.dirname(os.path.abspath(__file__)) dirname = os.path.dirname(os.path.abspath(__file__))
for f in os.listdir(dirname): for f in os.listdir(dirname):
if ( if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
f != "__init__.py"
and os.path.isfile("%s/%s" % (dirname, f))
and f[-3:] == ".py"
):
__all__.append(f[:-3]) __all__.append(f[:-3])

View File

@ -4,8 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
get_type_hints)
from pydantic import BaseConfig, BaseModel, Field from pydantic import BaseConfig, BaseModel, Field

View File

@ -8,8 +8,7 @@ from pydantic import Field, validator
from invokeai.app.models.image import ImageField from invokeai.app.models.image import ImageField
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
InvocationConfig, InvocationContext, UIConfig)
class IntCollectionOutput(BaseInvocationOutput): class IntCollectionOutput(BaseInvocationOutput):
@ -27,8 +26,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
type: Literal["float_collection"] = "float_collection" type: Literal["float_collection"] = "float_collection"
# Outputs # Outputs
collection: list[float] = Field( collection: list[float] = Field(default=[], description="The float collection")
default=[], description="The float collection")
class ImageCollectionOutput(BaseInvocationOutput): class ImageCollectionOutput(BaseInvocationOutput):
@ -37,8 +35,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
type: Literal["image_collection"] = "image_collection" type: Literal["image_collection"] = "image_collection"
# Outputs # Outputs
collection: list[ImageField] = Field( collection: list[ImageField] = Field(default=[], description="The output images")
default=[], description="The output images")
class Config: class Config:
schema_extra = {"required": ["type", "collection"]} schema_extra = {"required": ["type", "collection"]}
@ -56,10 +53,7 @@ class RangeInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
"title": "Range",
"tags": ["range", "integer", "collection"]
},
} }
@validator("stop") @validator("stop")
@ -69,9 +63,7 @@ class RangeInvocation(BaseInvocation):
return v return v
def invoke(self, context: InvocationContext) -> IntCollectionOutput: def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput( return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
collection=list(range(self.start, self.stop, self.step))
)
class RangeOfSizeInvocation(BaseInvocation): class RangeOfSizeInvocation(BaseInvocation):
@ -86,18 +78,11 @@ class RangeOfSizeInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
"title": "Sized Range",
"tags": ["range", "integer", "size", "collection"]
},
} }
def invoke(self, context: InvocationContext) -> IntCollectionOutput: def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput( return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
collection=list(
range(
self.start, self.start + self.size,
self.step)))
class RandomRangeInvocation(BaseInvocation): class RandomRangeInvocation(BaseInvocation):
@ -107,9 +92,7 @@ class RandomRangeInvocation(BaseInvocation):
# Inputs # Inputs
low: int = Field(default=0, description="The inclusive low value") low: int = Field(default=0, description="The inclusive low value")
high: int = Field( high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
size: int = Field(default=1, description="The number of values to generate") size: int = Field(default=1, description="The number of values to generate")
seed: int = Field( seed: int = Field(
ge=0, ge=0,
@ -120,19 +103,12 @@ class RandomRangeInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
"title": "Random Range",
"tags": ["range", "integer", "random", "collection"]
},
} }
def invoke(self, context: InvocationContext) -> IntCollectionOutput: def invoke(self, context: InvocationContext) -> IntCollectionOutput:
rng = np.random.default_rng(self.seed) rng = np.random.default_rng(self.seed)
return IntCollectionOutput( return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
collection=list(
rng.integers(
low=self.low, high=self.high,
size=self.size)))
class ImageCollectionInvocation(BaseInvocation): class ImageCollectionInvocation(BaseInvocation):

View File

@ -3,27 +3,24 @@ from pydantic import BaseModel, Field
import re import re
import torch import torch
from compel import Compel, ReturnedEmbeddingsType from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import (Blend, Conjunction, from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
CrossAttentionControlSubstitute,
FlattenedPrompt, Fragment)
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ...backend.model_management import ModelType from ...backend.model_management import ModelType
from ...backend.model_management.models import ModelNotFoundException from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
InvocationConfig, InvocationContext)
from .model import ClipField from .model import ClipField
from dataclasses import dataclass from dataclasses import dataclass
class ConditioningField(BaseModel): class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field( conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
default=None, description="The name of conditioning data")
class Config: class Config:
schema_extra = {"required": ["conditioning_name"]} schema_extra = {"required": ["conditioning_name"]}
@dataclass @dataclass
class BasicConditioningInfo: class BasicConditioningInfo:
# type: Literal["basic_conditioning"] = "basic_conditioning" # type: Literal["basic_conditioning"] = "basic_conditioning"
@ -32,27 +29,29 @@ class BasicConditioningInfo:
# weight: float # weight: float
# mode: ConditioningAlgo # mode: ConditioningAlgo
@dataclass @dataclass
class SDXLConditioningInfo(BasicConditioningInfo): class SDXLConditioningInfo(BasicConditioningInfo):
# type: Literal["sdxl_conditioning"] = "sdxl_conditioning" # type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
pooled_embeds: torch.Tensor pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor add_time_ids: torch.Tensor
ConditioningInfoType = Annotated[
Union[BasicConditioningInfo, SDXLConditioningInfo], ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
Field(discriminator="type")
]
@dataclass @dataclass
class ConditioningFieldData: class ConditioningFieldData:
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]] conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
# unconditioned: Optional[torch.Tensor] # unconditioned: Optional[torch.Tensor]
# class ConditioningAlgo(str, Enum): # class ConditioningAlgo(str, Enum):
# Compose = "compose" # Compose = "compose"
# ComposeEx = "compose_ex" # ComposeEx = "compose_ex"
# PerpNeg = "perp_neg" # PerpNeg = "perp_neg"
class CompelOutput(BaseInvocationOutput): class CompelOutput(BaseInvocationOutput):
"""Compel parser output""" """Compel parser output"""
@ -74,28 +73,23 @@ class CompelInvocation(BaseInvocation):
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
"title": "Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
} }
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(), context=context, **self.clip.tokenizer.dict(),
context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(), context=context, **self.clip.text_encoder.dict(),
context=context,
) )
def _lora_loader(): def _lora_loader():
for lora in self.clip.loras: for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
@ -118,13 +112,16 @@ class CompelInvocation(BaseInvocation):
# print(e) # print(e)
# import traceback # import traceback
# print(traceback.format_exc()) # print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found") print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\
text_encoder_info as text_encoder:
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, self.clip.skipped_layers
), text_encoder_info as text_encoder:
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
@ -139,14 +136,12 @@ class CompelInvocation(BaseInvocation):
if context.services.configuration.log_tokenization: if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer) log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object( c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
prompt)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count( tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
tokenizer, conjunction), cross_attention_control_args=options.get("cross_attention_control", None),
cross_attention_control_args=options.get( )
"cross_attention_control", None),)
c = c.detach().to("cpu") c = c.detach().to("cpu")
@ -168,19 +163,21 @@ class CompelInvocation(BaseInvocation):
), ),
) )
class SDXLPromptInvocationBase: class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled): def run_clip_raw(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(), **clip_field.tokenizer.dict(),
context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(), **clip_field.text_encoder.dict(),
context=context,
) )
def _lora_loader(): def _lora_loader():
for lora in clip_field.loras: for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
@ -196,19 +193,23 @@ class SDXLPromptInvocationBase:
model_name=name, model_name=name,
base_model=clip_field.text_encoder.base_model, base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context,
).context.model ).context.model
) )
except ModelNotFoundException: except ModelNotFoundException:
# print(e) # print(e)
# import traceback # import traceback
# print(traceback.format_exc()) # print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found") print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, clip_field.skipped_layers
), text_encoder_info as text_encoder:
text_inputs = tokenizer( text_inputs = tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
@ -241,15 +242,16 @@ class SDXLPromptInvocationBase:
def run_clip_compel(self, context, clip_field, prompt, get_pooled): def run_clip_compel(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(), **clip_field.tokenizer.dict(),
context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(), **clip_field.text_encoder.dict(),
context=context,
) )
def _lora_loader(): def _lora_loader():
for lora in clip_field.loras: for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
@ -265,19 +267,23 @@ class SDXLPromptInvocationBase:
model_name=name, model_name=name,
base_model=clip_field.text_encoder.base_model, base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context,
).context.model ).context.model
) )
except ModelNotFoundException: except ModelNotFoundException:
# print(e) # print(e)
# import traceback # import traceback
# print(traceback.format_exc()) # print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found") print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, clip_field.skipped_layers
), text_encoder_info as text_encoder:
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
@ -318,6 +324,7 @@ class SDXLPromptInvocationBase:
return c, c_pooled, ec return c, c_pooled, ec
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@ -337,13 +344,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
"title": "SDXL Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
} }
@torch.no_grad() @torch.no_grad()
@ -358,9 +359,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_coords = (self.crop_top, self.crop_left) crop_coords = (self.crop_top, self.crop_left)
target_size = (self.target_height, self.target_width) target_size = (self.target_height, self.target_width)
add_time_ids = torch.tensor([ add_time_ids = torch.tensor([original_size + crop_coords + target_size])
original_size + crop_coords + target_size
])
conditioning_data = ConditioningFieldData( conditioning_data = ConditioningFieldData(
conditionings=[ conditionings=[
@ -382,6 +381,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
), ),
) )
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@ -401,9 +401,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
"ui": { "ui": {
"title": "SDXL Refiner Prompt (Compel)", "title": "SDXL Refiner Prompt (Compel)",
"tags": ["prompt", "compel"], "tags": ["prompt", "compel"],
"type_hints": { "type_hints": {"model": "model"},
"model": "model"
}
}, },
} }
@ -414,9 +412,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
original_size = (self.original_height, self.original_width) original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left) crop_coords = (self.crop_top, self.crop_left)
add_time_ids = torch.tensor([ add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
original_size + crop_coords + (self.aesthetic_score,)
])
conditioning_data = ConditioningFieldData( conditioning_data = ConditioningFieldData(
conditionings=[ conditionings=[
@ -438,6 +434,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
), ),
) )
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Pass unmodified prompt to conditioning without compel processing.""" """Pass unmodified prompt to conditioning without compel processing."""
@ -457,13 +454,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
"title": "SDXL Prompt (Raw)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
} }
@torch.no_grad() @torch.no_grad()
@ -478,9 +469,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_coords = (self.crop_top, self.crop_left) crop_coords = (self.crop_top, self.crop_left)
target_size = (self.target_height, self.target_width) target_size = (self.target_height, self.target_width)
add_time_ids = torch.tensor([ add_time_ids = torch.tensor([original_size + crop_coords + target_size])
original_size + crop_coords + target_size
])
conditioning_data = ConditioningFieldData( conditioning_data = ConditioningFieldData(
conditionings=[ conditionings=[
@ -502,6 +491,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
), ),
) )
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@ -521,9 +511,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"ui": { "ui": {
"title": "SDXL Refiner Prompt (Raw)", "title": "SDXL Refiner Prompt (Raw)",
"tags": ["prompt", "compel"], "tags": ["prompt", "compel"],
"type_hints": { "type_hints": {"model": "model"},
"model": "model"
}
}, },
} }
@ -534,9 +522,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
original_size = (self.original_height, self.original_width) original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left) crop_coords = (self.crop_top, self.crop_left)
add_time_ids = torch.tensor([ add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
original_size + crop_coords + (self.aesthetic_score,)
])
conditioning_data = ConditioningFieldData( conditioning_data = ConditioningFieldData(
conditionings=[ conditionings=[
@ -561,11 +547,14 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
class ClipSkipInvocationOutput(BaseInvocationOutput): class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output""" """Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output" type: Literal["clip_skip_output"] = "clip_skip_output"
clip: ClipField = Field(None, description="Clip with skipped layers") clip: ClipField = Field(None, description="Clip with skipped layers")
class ClipSkipInvocation(BaseInvocation): class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model.""" """Skip layers in clip text_encoder model."""
type: Literal["clip_skip"] = "clip_skip" type: Literal["clip_skip"] = "clip_skip"
clip: ClipField = Field(None, description="Clip to use") clip: ClipField = Field(None, description="Clip to use")
@ -573,10 +562,7 @@ class ClipSkipInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
"title": "CLIP Skip",
"tags": ["clip", "skip"]
},
} }
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
@ -587,46 +573,26 @@ class ClipSkipInvocation(BaseInvocation):
def get_max_token_count( def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
truncate_if_too_long=False) -> int: ) -> int:
if type(prompt) is Blend: if type(prompt) is Blend:
blend: Blend = prompt blend: Blend = prompt
return max( return max([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in blend.prompts])
[
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in blend.prompts
]
)
elif type(prompt) is Conjunction: elif type(prompt) is Conjunction:
conjunction: Conjunction = prompt conjunction: Conjunction = prompt
return sum( return sum([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in conjunction.prompts])
[
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in conjunction.prompts
]
)
else: else:
return len( return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
get_tokens_for_prompt_object(
tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object( def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> List[str]:
if type(parsed_prompt) is Blend: if type(parsed_prompt) is Blend:
raise ValueError( raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
"Blend is not supported here - you need to get tokens for each of its .children"
)
text_fragments = [ text_fragments = [
x.text x.text
if type(x) is Fragment if type(x) is Fragment
else ( else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
" ".join([f.text for f in x.original])
if type(x) is CrossAttentionControlSubstitute
else str(x)
)
for x in parsed_prompt.children for x in parsed_prompt.children
] ]
text = " ".join(text_fragments) text = " ".join(text_fragments)
@ -637,25 +603,17 @@ def get_tokens_for_prompt_object(
return tokens return tokens
def log_tokenization_for_conjunction( def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
c: Conjunction, tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or "" display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts): for i, p in enumerate(c.prompts):
if len(c.prompts) > 1: if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else: else:
this_display_label_prefix = display_label_prefix this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object( log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
p,
tokenizer,
display_label_prefix=this_display_label_prefix
)
def log_tokenization_for_prompt_object( def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or "" display_label_prefix = display_label_prefix or ""
if type(p) is Blend: if type(p) is Blend:
blend: Blend = p blend: Blend = p
@ -692,13 +650,10 @@ def log_tokenization_for_prompt_object(
) )
else: else:
text = " ".join([x.text for x in flattened_prompt.children]) text = " ".join([x.text for x in flattened_prompt.children])
log_tokenization_for_text( log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
text, tokenizer, display_label=display_label_prefix
)
def log_tokenization_for_text( def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
text, tokenizer, display_label=None, truncate_if_too_long=False):
"""shows how the prompt is tokenized """shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word, # usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' ' # but for readability it has been replaced with ' '

View File

@ -6,21 +6,30 @@ from typing import Dict, List, Literal, Optional, Union
import cv2 import cv2
import numpy as np import numpy as np
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector, from controlnet_aux import (
LeresDetector, LineartAnimeDetector, CannyDetector,
LineartDetector, MediapipeFaceDetector, ContentShuffleDetector,
MidasDetector, MLSDdetector, NormalBaeDetector, HEDdetector,
OpenposeDetector, PidiNetDetector, SamDetector, LeresDetector,
ZoeDetector) LineartAnimeDetector,
LineartDetector,
MediapipeFaceDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
SamDetector,
ZoeDetector,
)
from controlnet_aux.util import HWC3, ade_palette from controlnet_aux.util import HWC3, ade_palette
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
InvocationConfig, InvocationContext) from ..models.image import ImageOutput, PILInvocationConfig
from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [ CONTROLNET_DEFAULT_MODELS = [
########################################### ###########################################
@ -34,7 +43,6 @@ CONTROLNET_DEFAULT_MODELS = [
"lllyasviel/sd-controlnet-scribble", "lllyasviel/sd-controlnet-scribble",
"lllyasviel/sd-controlnet-normal", "lllyasviel/sd-controlnet-normal",
"lllyasviel/sd-controlnet-mlsd", "lllyasviel/sd-controlnet-mlsd",
############################################# #############################################
# lllyasviel sd v1.5, ControlNet v1.1 models # lllyasviel sd v1.5, ControlNet v1.1 models
############################################# #############################################
@ -56,7 +64,6 @@ CONTROLNET_DEFAULT_MODELS = [
"lllyasviel/control_v11e_sd15_shuffle", "lllyasviel/control_v11e_sd15_shuffle",
"lllyasviel/control_v11e_sd15_ip2p", "lllyasviel/control_v11e_sd15_ip2p",
"lllyasviel/control_v11f1e_sd15_tile", "lllyasviel/control_v11f1e_sd15_tile",
################################################# #################################################
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1? # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
################################################## ##################################################
@ -71,7 +78,6 @@ CONTROLNET_DEFAULT_MODELS = [
"thibaud/controlnet-sd21-lineart-diffusers", "thibaud/controlnet-sd21-lineart-diffusers",
"thibaud/controlnet-sd21-normalbae-diffusers", "thibaud/controlnet-sd21-normalbae-diffusers",
"thibaud/controlnet-sd21-ade20k-diffusers", "thibaud/controlnet-sd21-ade20k-diffusers",
############################################## ##############################################
# ControlNetMediaPipeface, ControlNet v1.1 # ControlNetMediaPipeface, ControlNet v1.1
############################################## ##############################################
@ -83,10 +89,17 @@ CONTROLNET_DEFAULT_MODELS = [
] ]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple( CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
["balanced", "more_prompt", "more_control", "unbalanced"])] CONTROLNET_RESIZE_VALUES = Literal[
CONTROLNET_RESIZE_VALUES = Literal[tuple( tuple(
["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])] [
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]
)
]
class ControlNetModelField(BaseModel): class ControlNetModelField(BaseModel):
@ -98,21 +111,17 @@ class ControlNetModelField(BaseModel):
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: Optional[ControlNetModelField] = Field( control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field( control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
description="When the ControlNet is first applied (% of total steps)") )
end_step_percent: float = Field( end_step_percent: float = Field(
default=1, ge=0, le=1, default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
description="When the ControlNet is last applied (% of total steps)") )
control_mode: CONTROLNET_MODE_VALUES = Field( control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
default="balanced", description="The control mode to use") resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
default="just_resize", description="The resize mode to use")
@validator("control_weight") @validator("control_weight")
def validate_control_weight(cls, v): def validate_control_weight(cls, v):
@ -120,11 +129,10 @@ class ControlField(BaseModel):
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
if i < -1 or i > 2: if i < -1 or i > 2:
raise ValueError( raise ValueError("Control weights must be within -1 to 2 range")
'Control weights must be within -1 to 2 range')
else: else:
if v < -1 or v > 2: if v < -1 or v > 2:
raise ValueError('Control weights must be within -1 to 2 range') raise ValueError("Control weights must be within -1 to 2 range")
return v return v
class Config: class Config:
@ -136,12 +144,13 @@ class ControlField(BaseModel):
"control_model": "controlnet_model", "control_model": "controlnet_model",
# "control_weight": "number", # "control_weight": "number",
} }
} },
} }
class ControlOutput(BaseInvocationOutput): class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info""" """node output for ControlNet info"""
# fmt: off # fmt: off
type: Literal["control_output"] = "control_output" type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The control info") control: ControlField = Field(default=None, description="The control info")
@ -150,6 +159,7 @@ class ControlOutput(BaseInvocationOutput):
class ControlNetInvocation(BaseInvocation): class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
# fmt: off # fmt: off
type: Literal["controlnet"] = "controlnet" type: Literal["controlnet"] = "controlnet"
# Inputs # Inputs
@ -176,7 +186,7 @@ class ControlNetInvocation(BaseInvocation):
# "cfg_scale": "float", # "cfg_scale": "float",
"cfg_scale": "number", "cfg_scale": "number",
"control_weight": "float", "control_weight": "float",
} },
}, },
} }
@ -205,10 +215,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Image Processor", "tags": ["image", "processor"]},
"title": "Image Processor",
"tags": ["image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
@ -233,7 +240,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
image_category=ImageCategory.CONTROL, image_category=ImageCategory.CONTROL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate is_intermediate=self.is_intermediate,
) )
"""Builds an ImageOutput and its ImageField""" """Builds an ImageOutput and its ImageField"""
@ -248,9 +255,9 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
) )
class CannyImageProcessorInvocation( class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Canny edge detection for ControlNet""" """Canny edge detection for ControlNet"""
# fmt: off # fmt: off
type: Literal["canny_image_processor"] = "canny_image_processor" type: Literal["canny_image_processor"] = "canny_image_processor"
# Input # Input
@ -260,22 +267,18 @@ class CannyImageProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
"title": "Canny Processor",
"tags": ["controlnet", "canny", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
canny_processor = CannyDetector() canny_processor = CannyDetector()
processed_image = canny_processor( processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
image, self.low_threshold, self.high_threshold)
return processed_image return processed_image
class HedImageProcessorInvocation( class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies HED edge detection to image""" """Applies HED edge detection to image"""
# fmt: off # fmt: off
type: Literal["hed_image_processor"] = "hed_image_processor" type: Literal["hed_image_processor"] = "hed_image_processor"
# Inputs # Inputs
@ -288,15 +291,13 @@ class HedImageProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
"title": "Softedge(HED) Processor",
"tags": ["controlnet", "softedge", "hed", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators") hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = hed_processor(image, processed_image = hed_processor(
image,
detect_resolution=self.detect_resolution, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution, image_resolution=self.image_resolution,
# safe not supported in controlnet_aux v0.0.3 # safe not supported in controlnet_aux v0.0.3
@ -306,9 +307,9 @@ class HedImageProcessorInvocation(
return processed_image return processed_image
class LineartImageProcessorInvocation( class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art processing to image""" """Applies line art processing to image"""
# fmt: off # fmt: off
type: Literal["lineart_image_processor"] = "lineart_image_processor" type: Literal["lineart_image_processor"] = "lineart_image_processor"
# Inputs # Inputs
@ -319,24 +320,20 @@ class LineartImageProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
"title": "Lineart Processor",
"tags": ["controlnet", "lineart", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
lineart_processor = LineartDetector.from_pretrained( lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
"lllyasviel/Annotators")
processed_image = lineart_processor( processed_image = lineart_processor(
image, detect_resolution=self.detect_resolution, image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
image_resolution=self.image_resolution, coarse=self.coarse) )
return processed_image return processed_image
class LineartAnimeImageProcessorInvocation( class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art anime processing to image""" """Applies line art anime processing to image"""
# fmt: off # fmt: off
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor" type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
# Inputs # Inputs
@ -348,23 +345,23 @@ class LineartAnimeImageProcessorInvocation(
schema_extra = { schema_extra = {
"ui": { "ui": {
"title": "Lineart Anime Processor", "title": "Lineart Anime Processor",
"tags": ["controlnet", "lineart", "anime", "image", "processor"] "tags": ["controlnet", "lineart", "anime", "image", "processor"],
}, },
} }
def run_processor(self, image): def run_processor(self, image):
processor = LineartAnimeDetector.from_pretrained( processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
"lllyasviel/Annotators") processed_image = processor(
processed_image = processor(image, image,
detect_resolution=self.detect_resolution, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution, image_resolution=self.image_resolution,
) )
return processed_image return processed_image
class OpenposeImageProcessorInvocation( class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies Openpose processing to image""" """Applies Openpose processing to image"""
# fmt: off # fmt: off
type: Literal["openpose_image_processor"] = "openpose_image_processor" type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs # Inputs
@ -375,25 +372,23 @@ class OpenposeImageProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
"title": "Openpose Processor",
"tags": ["controlnet", "openpose", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
openpose_processor = OpenposeDetector.from_pretrained( openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
"lllyasviel/Annotators")
processed_image = openpose_processor( processed_image = openpose_processor(
image, detect_resolution=self.detect_resolution, image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution, image_resolution=self.image_resolution,
hand_and_face=self.hand_and_face,) hand_and_face=self.hand_and_face,
)
return processed_image return processed_image
class MidasDepthImageProcessorInvocation( class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies Midas depth processing to image""" """Applies Midas depth processing to image"""
# fmt: off # fmt: off
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor" type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs # Inputs
@ -405,15 +400,13 @@ class MidasDepthImageProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
"title": "Midas (Depth) Processor",
"tags": ["controlnet", "midas", "depth", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(image, processed_image = midas_processor(
image,
a=np.pi * self.a_mult, a=np.pi * self.a_mult,
bg_th=self.bg_th, bg_th=self.bg_th,
# dept_and_normal not supported in controlnet_aux v0.0.3 # dept_and_normal not supported in controlnet_aux v0.0.3
@ -422,9 +415,9 @@ class MidasDepthImageProcessorInvocation(
return processed_image return processed_image
class NormalbaeImageProcessorInvocation( class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies NormalBae processing to image""" """Applies NormalBae processing to image"""
# fmt: off # fmt: off
type: Literal["normalbae_image_processor"] = "normalbae_image_processor" type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs # Inputs
@ -434,24 +427,20 @@ class NormalbaeImageProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
"title": "Normal BAE Processor",
"tags": ["controlnet", "normal", "bae", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
normalbae_processor = NormalBaeDetector.from_pretrained( normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
"lllyasviel/Annotators")
processed_image = normalbae_processor( processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
image_resolution=self.image_resolution) )
return processed_image return processed_image
class MlsdImageProcessorInvocation( class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies MLSD processing to image""" """Applies MLSD processing to image"""
# fmt: off # fmt: off
type: Literal["mlsd_image_processor"] = "mlsd_image_processor" type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs # Inputs
@ -463,24 +452,24 @@ class MlsdImageProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
"title": "MLSD Processor",
"tags": ["controlnet", "mlsd", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor( processed_image = mlsd_processor(
image, detect_resolution=self.detect_resolution, image,
image_resolution=self.image_resolution, thr_v=self.thr_v, detect_resolution=self.detect_resolution,
thr_d=self.thr_d) image_resolution=self.image_resolution,
thr_v=self.thr_v,
thr_d=self.thr_d,
)
return processed_image return processed_image
class PidiImageProcessorInvocation( class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies PIDI processing to image""" """Applies PIDI processing to image"""
# fmt: off # fmt: off
type: Literal["pidi_image_processor"] = "pidi_image_processor" type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs # Inputs
@ -492,25 +481,24 @@ class PidiImageProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
"title": "PIDI Processor",
"tags": ["controlnet", "pidi", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
pidi_processor = PidiNetDetector.from_pretrained( pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
"lllyasviel/Annotators")
processed_image = pidi_processor( processed_image = pidi_processor(
image, detect_resolution=self.detect_resolution, image,
image_resolution=self.image_resolution, safe=self.safe, detect_resolution=self.detect_resolution,
scribble=self.scribble) image_resolution=self.image_resolution,
safe=self.safe,
scribble=self.scribble,
)
return processed_image return processed_image
class ContentShuffleImageProcessorInvocation( class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies content shuffle processing to image""" """Applies content shuffle processing to image"""
# fmt: off # fmt: off
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor" type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs # Inputs
@ -525,48 +513,45 @@ class ContentShuffleImageProcessorInvocation(
schema_extra = { schema_extra = {
"ui": { "ui": {
"title": "Content Shuffle Processor", "title": "Content Shuffle Processor",
"tags": ["controlnet", "contentshuffle", "image", "processor"] "tags": ["controlnet", "contentshuffle", "image", "processor"],
}, },
} }
def run_processor(self, image): def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector() content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(image, processed_image = content_shuffle_processor(
image,
detect_resolution=self.detect_resolution, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution, image_resolution=self.image_resolution,
h=self.h, h=self.h,
w=self.w, w=self.w,
f=self.f f=self.f,
) )
return processed_image return processed_image
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13 # should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation( class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies Zoe depth processing to image""" """Applies Zoe depth processing to image"""
# fmt: off # fmt: off
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor" type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
# fmt: on # fmt: on
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
"title": "Zoe (Depth) Processor",
"tags": ["controlnet", "zoe", "depth", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained( zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
"lllyasviel/Annotators")
processed_image = zoe_depth_processor(image) processed_image = zoe_depth_processor(image)
return processed_image return processed_image
class MediapipeFaceProcessorInvocation( class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies mediapipe face processing to image""" """Applies mediapipe face processing to image"""
# fmt: off # fmt: off
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor" type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs # Inputs
@ -576,26 +561,22 @@ class MediapipeFaceProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
"title": "Mediapipe Processor",
"tags": ["controlnet", "mediapipe", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel # MediaPipeFaceDetector throws an error if image has alpha channel
# so convert to RGB if needed # so convert to RGB if needed
if image.mode == 'RGBA': if image.mode == "RGBA":
image = image.convert('RGB') image = image.convert("RGB")
mediapipe_face_processor = MediapipeFaceDetector() mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor( processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image return processed_image
class LeresImageProcessorInvocation( class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies leres processing to image""" """Applies leres processing to image"""
# fmt: off # fmt: off
type: Literal["leres_image_processor"] = "leres_image_processor" type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs # Inputs
@ -608,24 +589,23 @@ class LeresImageProcessorInvocation(
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
"title": "Leres (Depth) Processor",
"tags": ["controlnet", "leres", "depth", "image", "processor"]
},
} }
def run_processor(self, image): def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor( processed_image = leres_processor(
image, thr_a=self.thr_a, thr_b=self.thr_b, boost=self.boost, image,
thr_a=self.thr_a,
thr_b=self.thr_b,
boost=self.boost,
detect_resolution=self.detect_resolution, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution) image_resolution=self.image_resolution,
)
return processed_image return processed_image
class TileResamplerProcessorInvocation( class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
# fmt: off # fmt: off
type: Literal["tile_image_processor"] = "tile_image_processor" type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs # Inputs
@ -637,12 +617,13 @@ class TileResamplerProcessorInvocation(
schema_extra = { schema_extra = {
"ui": { "ui": {
"title": "Tile Resample Processor", "title": "Tile Resample Processor",
"tags": ["controlnet", "tile", "resample", "image", "processor"] "tags": ["controlnet", "tile", "resample", "image", "processor"],
}, },
} }
# tile_resample copied from sd-webui-controlnet/scripts/processor.py # tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample(self, def tile_resample(
self,
np_img: np.ndarray, np_img: np.ndarray,
res=512, # never used? res=512, # never used?
down_sampling_rate=1.0, down_sampling_rate=1.0,
@ -658,36 +639,41 @@ class TileResamplerProcessorInvocation(
def run_processor(self, img): def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8) np_img = np.array(img, dtype=np.uint8)
processed_np_image = self.tile_resample(np_img, processed_np_image = self.tile_resample(
np_img,
# res=self.tile_size, # res=self.tile_size,
down_sampling_rate=self.down_sampling_rate down_sampling_rate=self.down_sampling_rate,
) )
processed_image = Image.fromarray(processed_np_image) processed_image = Image.fromarray(processed_np_image)
return processed_image return processed_image
class SegmentAnythingProcessorInvocation( class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
ImageProcessorInvocation, PILInvocationConfig):
"""Applies segment anything processing to image""" """Applies segment anything processing to image"""
# fmt: off # fmt: off
type: Literal["segment_anything_processor"] = "segment_anything_processor" type: Literal["segment_anything_processor"] = "segment_anything_processor"
# fmt: on # fmt: on
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = {"ui": {"title": "Segment Anything Processor", "tags": [ schema_extra = {
"controlnet", "segment", "anything", "sam", "image", "processor"]}, } "ui": {
"title": "Segment Anything Processor",
"tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
},
}
def run_processor(self, image): def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints") "ybelkada/segment-anything", subfolder="checkpoints"
)
np_img = np.array(image, dtype=np.uint8) np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(np_img) processed_image = segment_anything_processor(np_img)
return processed_image return processed_image
class SamDetectorReproducibleColors(SamDetector): class SamDetectorReproducibleColors(SamDetector):
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
# base class show_anns() method randomizes colors, # base class show_anns() method randomizes colors,
# which seems to also lead to non-reproducible image generation # which seems to also lead to non-reproducible image generation
@ -695,19 +681,15 @@ class SamDetectorReproducibleColors(SamDetector):
def show_anns(self, anns: List[Dict]): def show_anns(self, anns: List[Dict]):
if len(anns) == 0: if len(anns) == 0:
return return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
h, w = anns[0]['segmentation'].shape h, w = anns[0]["segmentation"].shape
final_img = Image.fromarray( final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
palette = ade_palette() palette = ade_palette()
for i, ann in enumerate(sorted_anns): for i, ann in enumerate(sorted_anns):
m = ann['segmentation'] m = ann["segmentation"]
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8) img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
# doing modulo just in case number of annotated regions exceeds number of colors in palette # doing modulo just in case number of annotated regions exceeds number of colors in palette
ann_color = palette[i % len(palette)] ann_color = palette[i % len(palette)]
img[:, :] = ann_color img[:, :] = ann_color
final_img.paste( final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
Image.fromarray(img, mode="RGB"),
(0, 0),
Image.fromarray(np.uint8(m * 255)))
return np.array(final_img, dtype=np.uint8) return np.array(final_img, dtype=np.uint8)

View File

@ -37,10 +37,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
"title": "OpenCV Inpaint",
"tags": ["opencv", "inpaint"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@ -6,8 +6,7 @@ from typing import Literal, Optional, get_args
import torch import torch
from pydantic import Field from pydantic import Field
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField, from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
ResourceOrigin)
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods from invokeai.backend.generator.inpaint import infill_methods
@ -25,13 +24,12 @@ from contextlib import contextmanager, ExitStack, ContextDecorator
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())] INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = ( DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
from .latent import get_scheduler from .latent import get_scheduler
class OldModelContext(ContextDecorator): class OldModelContext(ContextDecorator):
model: StableDiffusionGeneratorPipeline model: StableDiffusionGeneratorPipeline
@ -44,6 +42,7 @@ class OldModelContext(ContextDecorator):
def __exit__(self, *exc): def __exit__(self, *exc):
return False return False
class OldModelInfo: class OldModelInfo:
name: str name: str
hash: str hash: str
@ -64,20 +63,34 @@ class InpaintInvocation(BaseInvocation):
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed) seed: int = Field(
ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed
)
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image") steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", ) width: int = Field(
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", ) default=512,
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) multiple_of=8,
gt=0,
description="The width of the resulting image",
)
height: int = Field(
default=512,
multiple_of=8,
gt=0,
description="The height of the resulting image",
)
cfg_scale: float = Field(
default=7.5,
ge=1,
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
)
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use") scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
unet: UNetField = Field(default=None, description="UNet model") unet: UNetField = Field(default=None, description="UNet model")
vae: VaeField = Field(default=None, description="Vae model") vae: VaeField = Field(default=None, description="Vae model")
# Inputs # Inputs
image: Optional[ImageField] = Field(description="The input image") image: Optional[ImageField] = Field(description="The input image")
strength: float = Field( strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
default=0.75, gt=0, le=1, description="The strength of the original image"
)
fit: bool = Field( fit: bool = Field(
default=True, default=True,
description="Whether or not the result should be fit to the aspect ratio of the input image", description="Whether or not the result should be fit to the aspect ratio of the input image",
@ -86,18 +99,10 @@ class InpaintInvocation(BaseInvocation):
# Inputs # Inputs
mask: Optional[ImageField] = Field(description="The mask") mask: Optional[ImageField] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)") seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field( seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
default=16, ge=0, description="The seam inpaint blur radius (px)" seam_strength: float = Field(default=0.75, gt=0, le=1, description="The seam inpaint strength")
) seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
seam_strength: float = Field( tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
default=0.75, gt=0, le=1, description="The seam inpaint strength"
)
seam_steps: int = Field(
default=30, ge=1, description="The number of steps to use for seam inpaint"
)
tile_size: int = Field(
default=32, ge=1, description="The tile infill method size (px)"
)
infill_method: INFILL_METHODS = Field( infill_method: INFILL_METHODS = Field(
default=DEFAULT_INFILL_METHOD, default=DEFAULT_INFILL_METHOD,
description="The method used to infill empty regions (px)", description="The method used to infill empty regions (px)",
@ -128,10 +133,7 @@ class InpaintInvocation(BaseInvocation):
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"},
"tags": ["stable-diffusion", "image"],
"title": "Inpaint"
},
} }
def dispatch_progress( def dispatch_progress(
@ -162,18 +164,23 @@ class InpaintInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), context=context,) **lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,) unet_info = context.services.model_manager.get_model(
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,) **self.unet.unet.dict(),
context=context,
with vae_info as vae,\ )
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ vae_info = context.services.model_manager.get_model(
unet_info as unet: **self.vae.vae.dict(),
context=context,
)
with vae_info as vae, ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
device = context.services.model_manager.mgr.cache.execution_device device = context.services.model_manager.mgr.cache.execution_device
dtype = context.services.model_manager.mgr.cache.precision dtype = context.services.model_manager.mgr.cache.precision
@ -197,21 +204,11 @@ class InpaintInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = ( image = None if self.image is None else context.services.images.get_pil_image(self.image.image_name)
None mask = None if self.mask is None else context.services.images.get_pil_image(self.mask.image_name)
if self.image is None
else context.services.images.get_pil_image(self.image.image_name)
)
mask = (
None
if self.mask is None
else context.services.images.get_pil_image(self.mask.image_name)
)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
scheduler = get_scheduler( scheduler = get_scheduler(

View File

@ -4,60 +4,25 @@ from typing import Literal, Optional
import numpy import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field from pydantic import Field
from pathlib import Path
from typing import Union from typing import Union
from invokeai.app.invocations.metadata import CoreMetadata
from ..models.image import ImageCategory, ImageField, ResourceOrigin from ..models.image import (
ImageCategory,
ImageField,
ResourceOrigin,
PILInvocationConfig,
ImageOutput,
MaskOutput,
)
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput,
InvocationContext, InvocationContext,
InvocationConfig, InvocationConfig,
) )
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
class LoadImageInvocation(BaseInvocation): class LoadImageInvocation(BaseInvocation):
@ -74,10 +39,7 @@ class LoadImageInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Load Image", "tags": ["image", "load"]},
"title": "Load Image",
"tags": ["image", "load"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -96,16 +58,11 @@ class ShowImageInvocation(BaseInvocation):
type: Literal["show_image"] = "show_image" type: Literal["show_image"] = "show_image"
# Inputs # Inputs
image: Optional[ImageField] = Field( image: Optional[ImageField] = Field(default=None, description="The image to show")
default=None, description="The image to show"
)
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Show Image", "tags": ["image", "show"]},
"title": "Show Image",
"tags": ["image", "show"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -138,18 +95,13 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Crop Image", "tags": ["image", "crop"]},
"title": "Crop Image",
"tags": ["image", "crop"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
image_crop = Image.new( image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
)
image_crop.paste(image, (-self.x, -self.y)) image_crop.paste(image, (-self.x, -self.y))
image_dto = context.services.images.create( image_dto = context.services.images.create(
@ -184,21 +136,14 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Paste Image", "tags": ["image", "paste"]},
"title": "Paste Image",
"tags": ["image", "paste"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(self.base_image.image_name) base_image = context.services.images.get_pil_image(self.base_image.image_name)
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
mask = ( mask = (
None None if self.mask is None else ImageOps.invert(context.services.images.get_pil_image(self.mask.image_name))
if self.mask is None
else ImageOps.invert(
context.services.images.get_pil_image(self.mask.image_name)
)
) )
# TODO: probably shouldn't invert mask here... should user be required to do it? # TODO: probably shouldn't invert mask here... should user be required to do it?
@ -207,9 +152,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
max_x = max(base_image.width, image.width + self.x) max_x = max(base_image.width, image.width + self.x)
max_y = max(base_image.height, image.height + self.y) max_y = max(base_image.height, image.height + self.y)
new_image = Image.new( new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
)
new_image.paste(base_image, (abs(min_x), abs(min_y))) new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask) new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
@ -242,10 +185,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
"title": "Mask From Alpha",
"tags": ["image", "mask", "alpha"]
},
} }
def invoke(self, context: InvocationContext) -> MaskOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
@ -284,10 +224,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
"title": "Multiply Images",
"tags": ["image", "multiply"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -328,10 +265,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Image Channel", "tags": ["image", "channel"]},
"title": "Image Channel",
"tags": ["image", "channel"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -371,10 +305,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Convert Image", "tags": ["image", "convert"]},
"title": "Convert Image",
"tags": ["image", "convert"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -412,19 +343,14 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Blur Image", "tags": ["image", "blur"]},
"title": "Blur Image",
"tags": ["image", "blur"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
blur = ( blur = (
ImageFilter.GaussianBlur(self.radius) ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
if self.blur_type == "gaussian"
else ImageFilter.BoxBlur(self.radius)
) )
blur_image = image.filter(blur) blur_image = image.filter(blur)
@ -479,10 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Resize Image", "tags": ["image", "resize"]},
"title": "Resize Image",
"tags": ["image", "resize"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -525,10 +448,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Scale Image", "tags": ["image", "scale"]},
"title": "Scale Image",
"tags": ["image", "scale"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -573,10 +493,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
"title": "Image Linear Interpolation",
"tags": ["image", "linear", "interpolation", "lerp"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -619,7 +536,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
"title": "Image Inverse Linear Interpolation", "title": "Image Inverse Linear Interpolation",
"tags": ["image", "linear", "interpolation", "inverse"] "tags": ["image", "linear", "interpolation", "inverse"],
}, },
} }
@ -627,12 +544,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = ( image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
numpy.minimum(
numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
)
* 255
)
ilerp_image = Image.fromarray(numpy.uint8(image_arr)) ilerp_image = Image.fromarray(numpy.uint8(image_arr))
@ -650,3 +562,91 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Add blur to NSFW-flagged images"""
# fmt: off
type: Literal["img_nsfw"] = "img_nsfw"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
logger = context.services.logger
logger.debug("Running NSFW checker")
if SafetyChecker.has_nsfw_concept(image):
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution, (0, 0), caution)
image = blurry_image
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
def _get_caution_img(self) -> Image:
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
"""Add an invisible watermark to an image"""
# fmt: off
type: Literal["img_watermark"] = "img_watermark"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
text: str = Field(default='InvokeAI', description="Watermark text")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
new_image = InvisibleWatermark.add_watermark(image, self.text)
image_dto = context.services.images.create(
image=new_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -30,9 +30,7 @@ def infill_methods() -> list[str]:
INFILL_METHODS = Literal[tuple(infill_methods())] INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = ( DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
def infill_patchmatch(im: Image.Image) -> Image.Image: def infill_patchmatch(im: Image.Image) -> Image.Image:
@ -44,9 +42,7 @@ def infill_patchmatch(im: Image.Image) -> Image.Image:
return im return im
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though) # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = PatchMatch.inpaint( im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
)
im_patched = Image.fromarray(im_patched_np, mode="RGB") im_patched = Image.fromarray(im_patched_np, mode="RGB")
return im_patched return im_patched
@ -68,9 +64,7 @@ def get_tile_images(image: np.ndarray, width=8, height=8):
) )
def tile_fill_missing( def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
) -> Image.Image:
# Only fill if there's an alpha layer # Only fill if there's an alpha layer
if im.mode != "RGBA": if im.mode != "RGBA":
return im return im
@ -103,9 +97,7 @@ def tile_fill_missing(
# Find all invalid tiles and replace with a random valid tile # Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum() replace_count = (tiles_mask == False).sum()
rng = np.random.default_rng(seed=seed) rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[ tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
]
# Convert back to an image # Convert back to an image
tiles_all = tiles_all.reshape(tshape) tiles_all = tiles_all.reshape(tshape)
@ -126,9 +118,7 @@ class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color""" """Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba" type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field( image: Optional[ImageField] = Field(default=None, description="The image to infill")
default=None, description="The image to infill"
)
color: ColorField = Field( color: ColorField = Field(
default=ColorField(r=127, g=127, b=127, a=255), default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill", description="The color to use to infill",
@ -136,10 +126,7 @@ class InfillColorInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
"title": "Color Infill",
"tags": ["image", "inpaint", "color", "infill"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -171,9 +158,7 @@ class InfillTileInvocation(BaseInvocation):
type: Literal["infill_tile"] = "infill_tile" type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field( image: Optional[ImageField] = Field(default=None, description="The image to infill")
default=None, description="The image to infill"
)
tile_size: int = Field(default=32, ge=1, description="The tile size (px)") tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field( seed: int = Field(
ge=0, ge=0,
@ -184,18 +169,13 @@ class InfillTileInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
"title": "Tile Infill",
"tags": ["image", "inpaint", "tile", "infill"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
infilled = tile_fill_missing( infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
image.copy(), seed=self.seed, tile_size=self.tile_size
)
infilled.paste(image, (0, 0), image.split()[-1]) infilled.paste(image, (0, 0), image.split()[-1])
image_dto = context.services.images.create( image_dto = context.services.images.create(
@ -219,16 +199,11 @@ class InfillPatchMatchInvocation(BaseInvocation):
type: Literal["infill_patchmatch"] = "infill_patchmatch" type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field( image: Optional[ImageField] = Field(default=None, description="The image to infill")
default=None, description="The image to infill"
)
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
"title": "Patch Match Infill",
"tags": ["image", "inpaint", "patchmatch", "infill"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@ -12,20 +12,21 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, ConditioningData,
image_resized_to_grid_as_tensor) ControlNetData,
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ StableDiffusionGeneratorPipeline,
PostprocessingSettings image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
from ..models.image import ImageCategory, ImageField, ResourceOrigin from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
InvocationConfig, InvocationContext)
from .compel import ConditioningField from .compel import ConditioningField
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .image import ImageOutput from .image import ImageOutput
@ -46,8 +47,7 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
class LatentsField(BaseModel): class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations""" """A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field( latents_name: Optional[str] = Field(default=None, description="The name of the latents")
default=None, description="The name of the latents")
class Config: class Config:
schema_extra = {"required": ["latents_name"]} schema_extra = {"required": ["latents_name"]}
@ -55,6 +55,7 @@ class LatentsField(BaseModel):
class LatentsOutput(BaseInvocationOutput): class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents""" """Base class for invocations that output latents"""
# fmt: off # fmt: off
type: Literal["latents_output"] = "latents_output" type: Literal["latents_output"] = "latents_output"
@ -73,9 +74,7 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
) )
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
tuple(list(SCHEDULER_MAP.keys()))
]
def get_scheduler( def get_scheduler(
@ -83,11 +82,10 @@ def get_scheduler(
scheduler_info: ModelInfo, scheduler_info: ModelInfo,
scheduler_name: str, scheduler_name: str,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model( orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict(), context=context, **scheduler_info.dict(),
context=context,
) )
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
@ -102,7 +100,7 @@ def get_scheduler(
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'): if not hasattr(scheduler, "uses_inpainting_model"):
scheduler.uses_inpainting_model = lambda: False scheduler.uses_inpainting_model = lambda: False
return scheduler return scheduler
@ -133,10 +131,10 @@ class TextToLatentsInvocation(BaseInvocation):
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
if i < 1: if i < 1:
raise ValueError('cfg_scale must be greater than 1') raise ValueError("cfg_scale must be greater than 1")
else: else:
if v < 1: if v < 1:
raise ValueError('cfg_scale must be greater than 1') raise ValueError("cfg_scale must be greater than 1")
return v return v
# Schema customisation # Schema customisation
@ -149,8 +147,8 @@ class TextToLatentsInvocation(BaseInvocation):
"model": "model", "model": "model",
"control": "control", "control": "control",
# "cfg_scale": "float", # "cfg_scale": "float",
"cfg_scale": "number" "cfg_scale": "number",
} },
}, },
} }
@ -190,16 +188,14 @@ class TextToLatentsInvocation(BaseInvocation):
threshold=0.0, # threshold, threshold=0.0, # threshold,
warmup=0.2, # warmup, warmup=0.2, # warmup,
h_symmetry_time_pct=None, # h_symmetry_time_pct, h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None # v_symmetry_time_pct, v_symmetry_time_pct=None, # v_symmetry_time_pct,
), ),
) )
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
scheduler, scheduler,
# for ddim scheduler # for ddim scheduler
eta=0.0, # ddim_eta eta=0.0, # ddim_eta
# for ancestral and sde schedulers # for ancestral and sde schedulers
generator=torch.Generator(device=unet.device).manual_seed(0), generator=torch.Generator(device=unet.device).manual_seed(0),
) )
@ -247,7 +243,6 @@ class TextToLatentsInvocation(BaseInvocation):
exit_stack: ExitStack, exit_stack: ExitStack,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]: ) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents # assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8 control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8 control_width_resize = latents_shape[3] * 8
@ -261,7 +256,7 @@ class TextToLatentsInvocation(BaseInvocation):
control_list = control_input control_list = control_input
else: else:
control_list = None control_list = None
if (control_list is None): if control_list is None:
control_data = None control_data = None
# from above handling, any control that is not None should now be of type list[ControlField] # from above handling, any control that is not None should now be of type list[ControlField]
else: else:
@ -281,9 +276,7 @@ class TextToLatentsInvocation(BaseInvocation):
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image( input_image = context.services.images.get_pil_image(control_image_field.image_name)
control_image_field.image_name
)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
@ -318,12 +311,11 @@ class TextToLatentsInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings():
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
@ -332,19 +324,20 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), context=context, **lora.dict(exclude={"weight"}),
context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(), context=context, **self.unet.unet.dict(),
context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info.context.model, _lora_loader()
unet_info as unet: ), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype) noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler( scheduler = get_scheduler(
@ -357,7 +350,9 @@ class TextToLatentsInvocation(BaseInvocation):
conditioning_data = self.get_conditioning_data(context, scheduler, unet) conditioning_data = self.get_conditioning_data(context, scheduler, unet)
control_data = self.prep_control_data( control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control, model=pipeline,
context=context,
control_input=self.control,
latents_shape=noise.shape, latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
@ -378,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
result_latents = result_latents.to("cpu") result_latents = result_latents.to("cpu")
torch.cuda.empty_cache() torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}' name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents) context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents) return build_latents_output(latents_name=name, latents=result_latents)
@ -389,11 +384,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
type: Literal["l2l"] = "l2l" type: Literal["l2l"] = "l2l"
# Inputs # Inputs
latents: Optional[LatentsField] = Field( latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
description="The latents to use as a base image") strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
strength: float = Field(
default=0.7, ge=0, le=1,
description="The strength of the latents to use")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -405,19 +397,18 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
"model": "model", "model": "model",
"control": "control", "control": "control",
"cfg_scale": "number", "cfg_scale": "number",
} },
}, },
} }
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings(): # this quenches NSFW nag from diffusers
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name) latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
@ -426,19 +417,20 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), context=context, **lora.dict(exclude={"weight"}),
context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(), context=context, **self.unet.unet.dict(),
context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info.context.model, _lora_loader()
unet_info as unet: ), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype) noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype) latent = latent.to(device=unet.device, dtype=unet.dtype)
@ -452,7 +444,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
conditioning_data = self.get_conditioning_data(context, scheduler, unet) conditioning_data = self.get_conditioning_data(context, scheduler, unet)
control_data = self.prep_control_data( control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control, model=pipeline,
context=context,
control_input=self.control,
latents_shape=noise.shape, latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
@ -460,8 +454,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
) )
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like( initial_latents = (
latent, device=unet.device, dtype=latent.dtype latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
) )
timesteps, _ = pipeline.get_img2img_timesteps( timesteps, _ = pipeline.get_img2img_timesteps(
@ -477,14 +471,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
num_inference_steps=self.steps, num_inference_steps=self.steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData] control_data=control_data, # list[ControlNetData]
callback=step_callback callback=step_callback,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu") result_latents = result_latents.to("cpu")
torch.cuda.empty_cache() torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}' name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents) context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents) return build_latents_output(latents_name=name, latents=result_latents)
@ -496,14 +490,13 @@ class LatentsToImageInvocation(BaseInvocation):
type: Literal["l2i"] = "l2i" type: Literal["l2i"] = "l2i"
# Inputs # Inputs
latents: Optional[LatentsField] = Field( latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field( tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
default=False, fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
description="Decode latents by overlaping tiles(less memory consumption)") metadata: Optional[CoreMetadata] = Field(
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision") default=None, description="Optional core metadata to be written to the image"
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") )
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -519,7 +512,8 @@ class LatentsToImageInvocation(BaseInvocation):
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), context=context, **self.vae.vae.dict(),
context=context,
) )
with vae_info as vae: with vae_info as vae:
@ -586,8 +580,7 @@ class LatentsToImageInvocation(BaseInvocation):
) )
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
"bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
class ResizeLatentsInvocation(BaseInvocation): class ResizeLatentsInvocation(BaseInvocation):
@ -596,24 +589,17 @@ class ResizeLatentsInvocation(BaseInvocation):
type: Literal["lresize"] = "lresize" type: Literal["lresize"] = "lresize"
# Inputs # Inputs
latents: Optional[LatentsField] = Field( latents: Optional[LatentsField] = Field(description="The latents to resize")
description="The latents to resize") width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
width: Union[int, None] = Field(default=512, height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
ge=64, multiple_of=8, description="The width to resize to (px)") mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
height: Union[int, None] = Field(default=512,
ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field( antialias: bool = Field(
default=False, default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
description="Whether or not to antialias (applied in bilinear and bicubic modes only)") )
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
"title": "Resize Latents",
"tags": ["latents", "resize"]
},
} }
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -623,9 +609,10 @@ class ResizeLatentsInvocation(BaseInvocation):
device = choose_torch_device() device = choose_torch_device()
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents.to(device), size=(self.height // 8, self.width // 8), latents.to(device),
mode=self.mode, antialias=self.antialias size=(self.height // 8, self.width // 8),
if self.mode in ["bilinear", "bicubic"] else False, mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@ -644,22 +631,16 @@ class ScaleLatentsInvocation(BaseInvocation):
type: Literal["lscale"] = "lscale" type: Literal["lscale"] = "lscale"
# Inputs # Inputs
latents: Optional[LatentsField] = Field( latents: Optional[LatentsField] = Field(description="The latents to scale")
description="The latents to scale") scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
scale_factor: float = Field( mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field( antialias: bool = Field(
default=False, default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
description="Whether or not to antialias (applied in bilinear and bicubic modes only)") )
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
"title": "Scale Latents",
"tags": ["latents", "scale"]
},
} }
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -670,9 +651,10 @@ class ScaleLatentsInvocation(BaseInvocation):
# resizing # resizing
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents.to(device), scale_factor=self.scale_factor, mode=self.mode, latents.to(device),
antialias=self.antialias scale_factor=self.scale_factor,
if self.mode in ["bilinear", "bicubic"] else False, mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@ -693,19 +675,13 @@ class ImageToLatentsInvocation(BaseInvocation):
# Inputs # Inputs
image: Optional[ImageField] = Field(description="The image to encode") image: Optional[ImageField] = Field(description="The image to encode")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field( tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
default=False, fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
description="Encode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
"title": "Image To Latents",
"tags": ["latents", "image"]
},
} }
@torch.no_grad() @torch.no_grad()
@ -717,7 +693,8 @@ class ImageToLatentsInvocation(BaseInvocation):
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), context=context, **self.vae.vae.dict(),
context=context,
) )
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
@ -760,9 +737,7 @@ class ImageToLatentsInvocation(BaseInvocation):
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode(): with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to( latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
latents = vae.config.scaling_factor * latents latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype) latents = latents.to(dtype=orig_dtype)

View File

@ -54,10 +54,7 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Add", "tags": ["math", "add"]},
"title": "Add",
"tags": ["math", "add"]
},
} }
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntOutput:
@ -75,10 +72,7 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Subtract", "tags": ["math", "subtract"]},
"title": "Subtract",
"tags": ["math", "subtract"]
},
} }
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntOutput:
@ -96,10 +90,7 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Multiply", "tags": ["math", "multiply"]},
"title": "Multiply",
"tags": ["math", "multiply"]
},
} }
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntOutput:
@ -117,10 +108,7 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Divide", "tags": ["math", "divide"]},
"title": "Divide",
"tags": ["math", "divide"]
},
} }
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntOutput:
@ -140,10 +128,7 @@ class RandomIntInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
"title": "Random Integer",
"tags": ["math", "random", "integer"]
},
} }
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntOutput:

View File

@ -2,16 +2,19 @@ from typing import Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (BaseInvocation, from invokeai.app.invocations.baseinvocation import (
BaseInvocationOutput, InvocationConfig, BaseInvocation,
InvocationContext) BaseInvocationOutput,
InvocationConfig,
InvocationContext,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import (LoRAModelField, MainModelField, from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
VAEModelField)
class LoRAMetadataField(BaseModel): class LoRAMetadataField(BaseModel):
"""LoRA metadata for an image generated in InvokeAI.""" """LoRA metadata for an image generated in InvokeAI."""
lora: LoRAModelField = Field(description="The LoRA model") lora: LoRAModelField = Field(description="The LoRA model")
weight: float = Field(description="The weight of the LoRA model") weight: float = Field(description="The weight of the LoRA model")
@ -19,7 +22,9 @@ class LoRAMetadataField(BaseModel):
class CoreMetadata(BaseModel): class CoreMetadata(BaseModel):
"""Core generation metadata for an image generated in InvokeAI.""" """Core generation metadata for an image generated in InvokeAI."""
generation_mode: str = Field(description="The generation mode that output this image",) generation_mode: str = Field(
description="The generation mode that output this image",
)
positive_prompt: str = Field(description="The positive prompt parameter") positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter") negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter") width: int = Field(description="The width parameter")
@ -29,22 +34,41 @@ class CoreMetadata(BaseModel):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter") cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference") steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference") scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",) clip_skip: int = Field(
description="The number of skipped CLIP layers",
)
model: MainModelField = Field(description="The main model used for inference") model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference") controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field( vae: Union[VAEModelField, None] = Field(
default=None, default=None,
description="The VAE used for decoding, if the main model's default was not used", description="The VAE used for decoding, if the main model's default was not used",
) )
# Latents-to-Latents
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
# SDXL
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Union[float, None] = Field(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
refiner_aesthetic_store: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
class ImageMetadata(BaseModel): class ImageMetadata(BaseModel):
"""An image's generation metadata""" """An image's generation metadata"""
@ -53,9 +77,7 @@ class ImageMetadata(BaseModel):
default=None, default=None,
description="The image's core metadata, if it was created in the Linear or Canvas UI", description="The image's core metadata, if it was created in the Linear or Canvas UI",
) )
graph: Optional[dict] = Field( graph: Optional[dict] = Field(default=None, description="The graph that created the image")
default=None, description="The graph that created the image"
)
class MetadataAccumulatorOutput(BaseInvocationOutput): class MetadataAccumulatorOutput(BaseInvocationOutput):
@ -71,7 +93,9 @@ class MetadataAccumulatorInvocation(BaseInvocation):
type: Literal["metadata_accumulator"] = "metadata_accumulator" type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = Field(description="The generation mode that output this image",) generation_mode: str = Field(
description="The generation mode that output this image",
)
positive_prompt: str = Field(description="The positive prompt parameter") positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter") negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter") width: int = Field(description="The width parameter")
@ -81,7 +105,9 @@ class MetadataAccumulatorInvocation(BaseInvocation):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter") cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference") steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference") scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",) clip_skip: int = Field(
description="The number of skipped CLIP layers",
)
model: MainModelField = Field(description="The main model used for inference") model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference") controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
@ -89,44 +115,38 @@ class MetadataAccumulatorInvocation(BaseInvocation):
default=None, default=None,
description="The strength used for latents-to-latents", description="The strength used for latents-to-latents",
) )
init_image: Union[str, None] = Field( init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field( vae: Union[VAEModelField, None] = Field(
default=None, default=None,
description="The VAE used for decoding, if the main model's default was not used", description="The VAE used for decoding, if the main model's default was not used",
) )
# SDXL
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Union[float, None] = Field(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
refiner_aesthetic_store: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
"title": "Metadata Accumulator", "title": "Metadata Accumulator",
"tags": ["image", "metadata", "generation"] "tags": ["image", "metadata", "generation"],
}, },
} }
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput: def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object""" """Collects and outputs a CoreMetadata object"""
return MetadataAccumulatorOutput( return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
metadata=CoreMetadata(
generation_mode=self.generation_mode,
positive_prompt=self.positive_prompt,
negative_prompt=self.negative_prompt,
width=self.width,
height=self.height,
seed=self.seed,
rand_device=self.rand_device,
cfg_scale=self.cfg_scale,
steps=self.steps,
scheduler=self.scheduler,
model=self.model,
strength=self.strength,
init_image=self.init_image,
vae=self.vae,
controlnets=self.controlnets,
loras=self.loras,
clip_skip=self.clip_skip,
)
)

View File

@ -4,17 +4,14 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
InvocationConfig, InvocationContext)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel") model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel") model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field( submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
default=None, description="Info to load submodel"
)
class LoraInfo(ModelInfo): class LoraInfo(ModelInfo):
@ -33,6 +30,7 @@ class ClipField(BaseModel):
skipped_layers: int = Field(description="Number of skipped layers in text_encoder") skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel): class VaeField(BaseModel):
# TODO: better naming? # TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel") vae: ModelInfo = Field(description="Info to load vae submodel")
@ -49,6 +47,7 @@ class ModelLoaderOutput(BaseInvocationOutput):
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on # fmt: on
class MainModelField(BaseModel): class MainModelField(BaseModel):
"""Main model field""" """Main model field"""
@ -62,6 +61,7 @@ class LoRAModelField(BaseModel):
model_name: str = Field(description="Name of the LoRA model") model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
@ -197,9 +197,7 @@ class LoraLoaderInvocation(BaseInvocation):
type: Literal["lora_loader"] = "lora_loader" type: Literal["lora_loader"] = "lora_loader"
lora: Union[LoRAModelField, None] = Field( lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
default=None, description="Lora model name"
)
weight: float = Field(default=0.75, description="With what weight to apply lora") weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora")
@ -228,14 +226,10 @@ class LoraLoaderInvocation(BaseInvocation):
): ):
raise Exception(f"Unkown lora name: {lora_name}!") raise Exception(f"Unkown lora name: {lora_name}!")
if self.unet is not None and any( if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
lora.model_name == lora_name for lora in self.unet.loras
):
raise Exception(f'Lora "{lora_name}" already applied to unet') raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any( if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
lora.model_name == lora_name for lora in self.clip.loras
):
raise Exception(f'Lora "{lora_name}" already applied to clip') raise Exception(f'Lora "{lora_name}" already applied to clip')
output = LoraLoaderOutput() output = LoraLoaderOutput()

View File

@ -12,16 +12,37 @@ import matplotlib.pyplot as plt
from easing_functions import ( from easing_functions import (
LinearInOut, LinearInOut,
QuadEaseInOut, QuadEaseIn, QuadEaseOut, QuadEaseInOut,
CubicEaseInOut, CubicEaseIn, CubicEaseOut, QuadEaseIn,
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut, QuadEaseOut,
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut, CubicEaseInOut,
SineEaseInOut, SineEaseIn, SineEaseOut, CubicEaseIn,
CircularEaseIn, CircularEaseInOut, CircularEaseOut, CubicEaseOut,
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut, QuarticEaseInOut,
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut, QuarticEaseIn,
BackEaseIn, BackEaseInOut, BackEaseOut, QuarticEaseOut,
BounceEaseIn, BounceEaseInOut, BounceEaseOut) QuinticEaseInOut,
QuinticEaseIn,
QuinticEaseOut,
SineEaseInOut,
SineEaseIn,
SineEaseOut,
CircularEaseIn,
CircularEaseInOut,
CircularEaseOut,
ExponentialEaseInOut,
ExponentialEaseIn,
ExponentialEaseOut,
ElasticEaseIn,
ElasticEaseInOut,
ElasticEaseOut,
BackEaseIn,
BackEaseInOut,
BackEaseOut,
BounceEaseIn,
BounceEaseInOut,
BounceEaseOut,
)
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
@ -45,17 +66,12 @@ class FloatLinearRangeInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
"title": "Linear Range (Float)",
"tags": ["math", "float", "linear", "range"]
},
} }
def invoke(self, context: InvocationContext) -> FloatCollectionOutput: def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps)) param_list = list(np.linspace(self.start, self.stop, self.steps))
return FloatCollectionOutput( return FloatCollectionOutput(collection=param_list)
collection=param_list
)
EASING_FUNCTIONS_MAP = { EASING_FUNCTIONS_MAP = {
@ -92,9 +108,7 @@ EASING_FUNCTIONS_MAP = {
"BounceInOut": BounceEaseInOut, "BounceInOut": BounceEaseInOut,
} }
EASING_FUNCTION_KEYS: Any = Literal[ EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
tuple(list(EASING_FUNCTIONS_MAP.keys()))
]
# actually I think for now could just use CollectionOutput (which is list[Any] # actually I think for now could just use CollectionOutput (which is list[Any]
@ -123,13 +137,9 @@ class StepParamEasingInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
"title": "Param Easing By Step",
"tags": ["param", "step", "easing"]
},
} }
def invoke(self, context: InvocationContext) -> FloatCollectionOutput: def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False log_diagnostics = False
# convert from start_step_percent to nearest step <= (steps * start_step_percent) # convert from start_step_percent to nearest step <= (steps * start_step_percent)
@ -171,11 +181,12 @@ class StepParamEasingInvocation(BaseInvocation):
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always # but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps / 2.0)) base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration)) if log_diagnostics:
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps context.services.logger.debug("base easing duration: " + str(base_easing_duration))
easing_function = easing_class(start=self.start_value, even_num_steps = num_easing_steps % 2 == 0 # even number of steps
end=self.end_value, easing_function = easing_class(
duration=base_easing_duration - 1) start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
)
base_easing_vals = list() base_easing_vals = list()
for step_index in range(base_easing_duration): for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index) easing_val = easing_function.ease(step_index)
@ -214,9 +225,7 @@ class StepParamEasingInvocation(BaseInvocation):
# #
else: # no mirroring (default) else: # no mirroring (default)
easing_function = easing_class(start=self.start_value, easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
end=self.end_value,
duration=num_easing_steps - 1)
for step_index in range(num_easing_steps): for step_index in range(num_easing_steps):
step_val = easing_function.ease(step_index) step_val = easing_function.ease(step_index)
easing_list.append(step_val) easing_list.append(step_val)
@ -240,13 +249,11 @@ class StepParamEasingInvocation(BaseInvocation):
ax = plt.gca() ax = plt.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True)) ax.xaxis.set_major_locator(MaxNLocator(integer=True))
buf = io.BytesIO() buf = io.BytesIO()
plt.savefig(buf, format='png') plt.savefig(buf, format="png")
buf.seek(0) buf.seek(0)
im = PIL.Image.open(buf) im = PIL.Image.open(buf)
im.show() im.show()
buf.close() buf.close()
# output array of size steps, each entry list[i] is param value for step i # output array of size steps, each entry list[i] is param value for step i
return FloatCollectionOutput( return FloatCollectionOutput(collection=param_list)
collection=param_list
)

View File

@ -4,14 +4,17 @@ from typing import Literal
from pydantic import Field from pydantic import Field
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from invokeai.app.invocations.prompt import PromptOutput
InvocationConfig, InvocationContext)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .math import FloatOutput, IntOutput from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs # Pass-through parameter nodes - used by subgraphs
class ParamIntInvocation(BaseInvocation): class ParamIntInvocation(BaseInvocation):
"""An integer parameter""" """An integer parameter"""
# fmt: off # fmt: off
type: Literal["param_int"] = "param_int" type: Literal["param_int"] = "param_int"
a: int = Field(default=0, description="The integer value") a: int = Field(default=0, description="The integer value")
@ -19,17 +22,16 @@ class ParamIntInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
"tags": ["param", "integer"],
"title": "Integer Parameter"
},
} }
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a) return IntOutput(a=self.a)
class ParamFloatInvocation(BaseInvocation): class ParamFloatInvocation(BaseInvocation):
"""A float parameter""" """A float parameter"""
# fmt: off # fmt: off
type: Literal["param_float"] = "param_float" type: Literal["param_float"] = "param_float"
param: float = Field(default=0.0, description="The float value") param: float = Field(default=0.0, description="The float value")
@ -37,34 +39,45 @@ class ParamFloatInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"tags": ["param", "float"], "title": "Float Parameter"},
"tags": ["param", "float"],
"title": "Float Parameter"
},
} }
def invoke(self, context: InvocationContext) -> FloatOutput: def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(param=self.param) return FloatOutput(param=self.param)
class StringOutput(BaseInvocationOutput): class StringOutput(BaseInvocationOutput):
"""A string output""" """A string output"""
type: Literal["string_output"] = "string_output" type: Literal["string_output"] = "string_output"
text: str = Field(default=None, description="The output string") text: str = Field(default=None, description="The output string")
class ParamStringInvocation(BaseInvocation): class ParamStringInvocation(BaseInvocation):
"""A string parameter""" """A string parameter"""
type: Literal['param_string'] = 'param_string'
text: str = Field(default='', description='The string value') type: Literal["param_string"] = "param_string"
text: str = Field(default="", description="The string value")
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"tags": ["param", "string"], "title": "String Parameter"},
"tags": ["param", "string"],
"title": "String Parameter"
},
} }
def invoke(self, context: InvocationContext) -> StringOutput: def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text) return StringOutput(text=self.text)
class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter"""
type: Literal["param_prompt"] = "param_prompt"
prompt: str = Field(default="", description="The prompt value")
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "prompt"], "title": "Prompt"},
}
def invoke(self, context: InvocationContext) -> PromptOutput:
return PromptOutput(prompt=self.prompt)

View File

@ -7,8 +7,10 @@ from pydantic import Field, validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput): class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt""" """Base class for invocations that output a prompt"""
# fmt: off # fmt: off
type: Literal["prompt"] = "prompt" type: Literal["prompt"] = "prompt"
@ -17,9 +19,9 @@ class PromptOutput(BaseInvocationOutput):
class Config: class Config:
schema_extra = { schema_extra = {
'required': [ "required": [
'type', "type",
'prompt', "prompt",
] ]
} }
@ -44,16 +46,11 @@ class DynamicPromptInvocation(BaseInvocation):
type: Literal["dynamic_prompt"] = "dynamic_prompt" type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts") prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate") max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field( combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
default=False, description="Whether to use the combinatorial generator"
)
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
"title": "Dynamic Prompt",
"tags": ["prompt", "dynamic"]
},
} }
def invoke(self, context: InvocationContext) -> PromptCollectionOutput: def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
@ -68,7 +65,8 @@ class DynamicPromptInvocation(BaseInvocation):
class PromptsFromFileInvocation(BaseInvocation): class PromptsFromFileInvocation(BaseInvocation):
'''Loads prompts from a text file''' """Loads prompts from a text file"""
# fmt: off # fmt: off
type: Literal['prompt_from_file'] = 'prompt_from_file' type: Literal['prompt_from_file'] = 'prompt_from_file'
@ -82,10 +80,7 @@ class PromptsFromFileInvocation(BaseInvocation):
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
"title": "Prompts From File",
"tags": ["prompt", "file"]
},
} }
@validator("file_path") @validator("file_path")
@ -103,11 +98,13 @@ class PromptsFromFileInvocation(BaseInvocation):
with open(file_path) as f: with open(file_path) as f:
for i, line in enumerate(f): for i, line in enumerate(f):
if i >= start_line and i < end_line: if i >= start_line and i < end_line:
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or '')) prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))
if i >= end_line: if i >= end_line:
break break
return prompts return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput: def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts) prompts = self.promptsFromFile(
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@ -7,13 +7,13 @@ from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType from ...backend.model_management import ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
InvocationConfig, InvocationContext)
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
class SDXLModelLoaderOutput(BaseInvocationOutput): class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output""" """SDXL base model loader output"""
@ -26,8 +26,10 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on # fmt: on
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output""" """SDXL refiner model loader output"""
# fmt: off # fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output" type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
@ -36,6 +38,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
# fmt: on # fmt: on
# fmt: on # fmt: on
class SDXLModelLoaderInvocation(BaseInvocation): class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels.""" """Loads an sdxl base model, outputting its submodels."""
@ -125,8 +128,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
), ),
) )
class SDXLRefinerModelLoaderInvocation(BaseInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels.""" """Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader" type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
model: MainModelField = Field(description="The model to load") model: MainModelField = Field(description="The model to load")
@ -138,7 +143,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"title": "SDXL Refiner Model Loader", "title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"], "tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "model"}, "type_hints": {"model": "refiner_model"},
}, },
} }
@ -197,6 +202,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
), ),
) )
# Text to image # Text to image
class SDXLTextToLatentsInvocation(BaseInvocation): class SDXLTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings.""" """Generates latents from conditionings."""
@ -224,10 +230,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
if i < 1: if i < 1:
raise ValueError('cfg_scale must be greater than 1') raise ValueError("cfg_scale must be greater than 1")
else: else:
if v < 1: if v < 1:
raise ValueError('cfg_scale must be greater than 1') raise ValueError("cfg_scale must be greater than 1")
return v return v
# Schema customisation # Schema customisation
@ -239,8 +245,8 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
"type_hints": { "type_hints": {
"model": "model", "model": "model",
# "cfg_scale": "float", # "cfg_scale": "float",
"cfg_scale": "number" "cfg_scale": "number",
} },
}, },
} }
@ -265,9 +271,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.noise.latents_name) latents = context.services.latents.get(self.noise.latents_name)
@ -288,18 +292,15 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
) )
num_inference_steps = self.steps num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
latents = latents * scheduler.init_noise_sigma unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
do_classifier_free_guidance = True do_classifier_free_guidance = True
cross_attention_kwargs = None cross_attention_kwargs = None
with unet_info as unet: with unet_info as unet:
scheduler.set_timesteps(num_inference_steps, device=unet.device)
timesteps = scheduler.timesteps
latents = latents.to(device=unet.device, dtype=unet.dtype) * scheduler.init_noise_sigma
extra_step_kwargs = dict() extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
@ -436,17 +437,16 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
# if callback is not None and i % callback_steps == 0: # if callback is not None and i % callback_steps == 0:
# callback(i, t, latents) # callback(i, t, latents)
################# #################
latents = latents.to("cpu") latents = latents.to("cpu")
torch.cuda.empty_cache() torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}' name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, latents) context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents) return build_latents_output(latents_name=name, latents=latents)
class SDXLLatentsToLatentsInvocation(BaseInvocation): class SDXLLatentsToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings.""" """Generates latents from conditionings."""
@ -463,8 +463,8 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
latents: Optional[LatentsField] = Field(description="Initial latents") latents: Optional[LatentsField] = Field(description="Initial latents")
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="") denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="") denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") # control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@ -477,10 +477,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
if i < 1: if i < 1:
raise ValueError('cfg_scale must be greater than 1') raise ValueError("cfg_scale must be greater than 1")
else: else:
if v < 1: if v < 1:
raise ValueError('cfg_scale must be greater than 1') raise ValueError("cfg_scale must be greater than 1")
return v return v
# Schema customisation # Schema customisation
@ -492,8 +492,8 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
"type_hints": { "type_hints": {
"model": "model", "model": "model",
# "cfg_scale": "float", # "cfg_scale": "float",
"cfg_scale": "number" "cfg_scale": "number",
} },
}, },
} }
@ -518,9 +518,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
@ -540,27 +538,28 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
scheduler_name=self.scheduler, scheduler_name=self.scheduler,
) )
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
# apply denoising_start # apply denoising_start
num_inference_steps = self.steps num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps) scheduler.set_timesteps(num_inference_steps, device=unet.device)
t_start = int(round(self.denoising_start * num_inference_steps)) t_start = int(round(self.denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order :] timesteps = scheduler.timesteps[t_start * scheduler.order :]
num_inference_steps = num_inference_steps - t_start num_inference_steps = num_inference_steps - t_start
# apply noise(if provided) # apply noise(if provided)
if self.noise is not None: if self.noise is not None and timesteps.shape[0] > 0:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1]) latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise del noise
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
# apply scheduler extra args # apply scheduler extra args
extra_step_kwargs = dict() extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
@ -697,13 +696,11 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
# if callback is not None and i % callback_steps == 0: # if callback is not None and i % callback_steps == 0:
# callback(i, t, latents) # callback(i, t, latents)
################# #################
latents = latents.to("cpu") latents = latents.to("cpu")
torch.cuda.empty_cache() torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}' name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, latents) context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents) return build_latents_output(latents_name=name, latents=latents)

View File

@ -29,16 +29,11 @@ class ESRGANInvocation(BaseInvocation):
type: Literal["esrgan"] = "esrgan" type: Literal["esrgan"] = "esrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image") image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: ESRGAN_MODELS = Field( model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
)
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
"title": "Upscale (RealESRGAN)",
"tags": ["image", "upscale", "realesrgan"]
},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -108,9 +103,7 @@ class ESRGANInvocation(BaseInvocation):
upscaled_image, img_mode = upsampler.enhance(cv_image) upscaled_image, img_mode = upsampler.enhance(cv_image)
# back to PIL # back to PIL
pil_image = Image.fromarray( pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
).convert("RGBA")
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=pil_image, image=pil_image,

View File

@ -1,3 +1,4 @@
class CanceledException(Exception): class CanceledException(Exception):
"""Execution canceled by user.""" """Execution canceled by user."""
pass pass

View File

@ -1,8 +1,83 @@
from enum import Enum from enum import Enum
from typing import Optional, Tuple from typing import Optional, Tuple, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationConfig,
)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
class ResourceOrigin(str, Enum, metaclass=MetaEnum): class ResourceOrigin(str, Enum, metaclass=MetaEnum):
@ -61,30 +136,3 @@ class InvalidImageCategoryException(ValueError):
def __init__(self, message="Invalid image category."): def __init__(self, message="Invalid image category."):
super().__init__(message) super().__init__(message)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@ -207,9 +207,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
raise e raise e
finally: finally:
self._lock.release() self._lock.release()
return OffsetPaginatedResults( return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
items=images, offset=offset, limit=limit, total=count
)
def get_all_board_image_names_for_board(self, board_id: str) -> list[str]: def get_all_board_image_names_for_board(self, board_id: str) -> list[str]:
try: try:

View File

@ -102,9 +102,7 @@ class BoardImagesService(BoardImagesServiceABC):
self, self,
board_id: str, board_id: str,
) -> list[str]: ) -> list[str]:
return self._services.board_image_records.get_all_board_image_names_for_board( return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
board_id
)
def get_board_for_image( def get_board_for_image(
self, self,
@ -114,9 +112,7 @@ class BoardImagesService(BoardImagesServiceABC):
return board_id return board_id
def board_record_to_dto( def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO.""" """Converts a board record to a board DTO."""
return BoardDTO( return BoardDTO(
**board_record.dict(exclude={"cover_image_name"}), **board_record.dict(exclude={"cover_image_name"}),

View File

@ -15,9 +15,7 @@ from pydantic import BaseModel, Field, Extra
class BoardChanges(BaseModel, extra=Extra.forbid): class BoardChanges(BaseModel, extra=Extra.forbid):
board_name: Optional[str] = Field(description="The board's new name.") board_name: Optional[str] = Field(description="The board's new name.")
cover_image_name: Optional[str] = Field( cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
description="The name of the board's new cover image."
)
class BoardRecordNotFoundException(Exception): class BoardRecordNotFoundException(Exception):
@ -292,9 +290,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
count = cast(int, self._cursor.fetchone()[0]) count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord]( return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
items=boards, offset=offset, limit=limit, total=count
)
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()

View File

@ -108,16 +108,12 @@ class BoardService(BoardServiceABC):
def get_dto(self, board_id: str) -> BoardDTO: def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id) board_record = self._services.board_records.get(board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board( cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
board_record.board_id
)
if cover_image: if cover_image:
cover_image_name = cover_image.image_name cover_image_name = cover_image.image_name
else: else:
cover_image_name = None cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board( image_count = self._services.board_image_records.get_image_count_for_board(board_id)
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count) return board_record_to_dto(board_record, cover_image_name, image_count)
def update( def update(
@ -126,60 +122,44 @@ class BoardService(BoardServiceABC):
changes: BoardChanges, changes: BoardChanges,
) -> BoardDTO: ) -> BoardDTO:
board_record = self._services.board_records.update(board_id, changes) board_record = self._services.board_records.update(board_id, changes)
cover_image = self._services.image_records.get_most_recent_image_for_board( cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
board_record.board_id
)
if cover_image: if cover_image:
cover_image_name = cover_image.image_name cover_image_name = cover_image.image_name
else: else:
cover_image_name = None cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board( image_count = self._services.board_image_records.get_image_count_for_board(board_id)
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count) return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None: def delete(self, board_id: str) -> None:
self._services.board_records.delete(board_id) self._services.board_records.delete(board_id)
def get_many( def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
self, offset: int = 0, limit: int = 10
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_records.get_many(offset, limit) board_records = self._services.board_records.get_many(offset, limit)
board_dtos = [] board_dtos = []
for r in board_records.items: for r in board_records.items:
cover_image = self._services.image_records.get_most_recent_image_for_board( cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
r.board_id
)
if cover_image: if cover_image:
cover_image_name = cover_image.image_name cover_image_name = cover_image.image_name
else: else:
cover_image_name = None cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board( image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO]( return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
)
def get_all(self) -> list[BoardDTO]: def get_all(self) -> list[BoardDTO]:
board_records = self._services.board_records.get_all() board_records = self._services.board_records.get_all()
board_dtos = [] board_dtos = []
for r in board_records: for r in board_records:
cover_image = self._services.image_records.get_most_recent_image_for_board( cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
r.board_id
)
if cover_image: if cover_image:
cover_image_name = cover_image.image_name cover_image_name = cover_image.image_name
else: else:
cover_image_name = None cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board( image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos return board_dtos

View File

@ -1,6 +1,6 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team # Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
'''Invokeai configuration system. """Invokeai configuration system.
Arguments and fields are taken from the pydantic definition of the Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that model. Defaults can be set by creating a yaml configuration file that
@ -28,7 +28,6 @@ InvokeAI:
always_use_cpu: false always_use_cpu: false
free_gpu_mem: false free_gpu_mem: false
Features: Features:
nsfw_checker: true
restore: true restore: true
esrgan: true esrgan: true
patchmatch: true patchmatch: true
@ -92,18 +91,18 @@ Typical usage at the top level file:
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value # get global configuration and print its cache size
conf = InvokeAIAppConfig.get_config() conf = InvokeAIAppConfig.get_config()
conf.parse_args() conf.parse_args()
print(conf.nsfw_checker) print(conf.max_cache_size)
Typical usage in a backend module: Typical usage in a backend module:
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value # get global configuration and print its cache size value
conf = InvokeAIAppConfig.get_config() conf = InvokeAIAppConfig.get_config()
print(conf.nsfw_checker) print(conf.max_cache_size)
Computed properties: Computed properties:
@ -159,7 +158,7 @@ two configs are kept in separate sections of the config file:
outdir: outputs outdir: outputs
... ...
''' """
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import pydoc import pydoc
@ -171,16 +170,17 @@ from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml') INIT_FILE = Path("invokeai.yaml")
MODEL_CORE = Path('models/core') DB_FILE = Path("invokeai.db")
DB_FILE = Path('invokeai.db') LEGACY_INIT_FILE = Path("invokeai.init")
LEGACY_INIT_FILE = Path('invokeai.init')
class InvokeAISettings(BaseSettings): class InvokeAISettings(BaseSettings):
''' """
Runtime configuration settings in which default values are Runtime configuration settings in which default values are
read from an omegaconf .yaml file. read from an omegaconf .yaml file.
''' """
initconf: ClassVar[DictConfig] = None initconf: ClassVar[DictConfig] = None
argparse_groups: ClassVar[Dict] = {} argparse_groups: ClassVar[Dict] = {}
@ -197,7 +197,7 @@ class InvokeAISettings(BaseSettings):
as the contents of `invokeai.yaml` to restore settings later. as the contents of `invokeai.yaml` to restore settings later.
""" """
cls = self.__class__ cls = self.__class__
type = get_args(get_type_hints(cls)['type'])[0] type = get_args(get_type_hints(cls)["type"])[0]
field_dict = dict({type: dict()}) field_dict = dict({type: dict()})
for name, field in self.__fields__.items(): for name, field in self.__fields__.items():
if name in cls._excluded_from_yaml(): if name in cls._excluded_from_yaml():
@ -213,16 +213,18 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def add_parser_arguments(cls, parser): def add_parser_arguments(cls, parser):
if 'type' in get_type_hints(cls): if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)['type'])[0] settings_stanza = get_args(get_type_hints(cls)["type"])[0]
else: else:
settings_stanza = "Uncategorized" settings_stanza = "Uncategorized"
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper() env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
initconf = cls.initconf.get(settings_stanza) \ initconf = (
if cls.initconf and settings_stanza in cls.initconf \ cls.initconf.get(settings_stanza)
if cls.initconf and settings_stanza in cls.initconf
else OmegaConf.create() else OmegaConf.create()
)
# create an upcase version of the environment in # create an upcase version of the environment in
# order to achieve case-insensitive environment # order to achieve case-insensitive environment
@ -239,7 +241,7 @@ class InvokeAISettings(BaseSettings):
current_default = field.default current_default = field.default
category = field.field_info.extra.get("category", "Uncategorized") category = field.field_info.extra.get("category", "Uncategorized")
env_name = env_prefix + '_' + name env_name = env_prefix + "_" + name
if category in initconf and name in initconf.get(category): if category in initconf and name in initconf.get(category):
field.default = initconf.get(category).get(name) field.default = initconf.get(category).get(name)
if env_name.upper() in upcase_environ: if env_name.upper() in upcase_environ:
@ -249,12 +251,12 @@ class InvokeAISettings(BaseSettings):
field.default = current_default field.default = current_default
@classmethod @classmethod
def cmd_name(self, command_field: str='type')->str: def cmd_name(self, command_field: str = "type") -> str:
hints = get_type_hints(self) hints = get_type_hints(self)
if command_field in hints: if command_field in hints:
return get_args(hints[command_field])[0] return get_args(hints[command_field])[0]
else: else:
return 'Uncategorized' return "Uncategorized"
@classmethod @classmethod
def get_parser(cls) -> ArgumentParser: def get_parser(cls) -> ArgumentParser:
@ -272,22 +274,40 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def _excluded(self) -> List[str]: def _excluded(self) -> List[str]:
# internal fields that shouldn't be exposed as command line options # internal fields that shouldn't be exposed as command line options
return ['type','initconf'] return ["type", "initconf", "cached_root"]
@classmethod @classmethod
def _excluded_from_yaml(self) -> List[str]: def _excluded_from_yaml(self) -> List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options # combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore', 'root'] return [
"type",
"initconf",
"gpu_mem_reserved",
"max_loaded_models",
"version",
"from_file",
"model",
"restore",
"root",
"nsfw_checker",
"cached_root",
]
class Config: class Config:
env_file_encoding = 'utf-8' env_file_encoding = "utf-8"
arbitrary_types_allowed = True arbitrary_types_allowed = True
case_sensitive = True case_sensitive = True
@classmethod @classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None): def add_field_argument(cls, command_parser, name: str, field, default_override=None):
field_type = get_type_hints(cls).get(name) field_type = get_type_hints(cls).get(name)
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory() default = (
default_override
if default_override is not None
else field.default
if field.default_factory is None
else field.default_factory()
)
if category := field.field_info.extra.get("category"): if category := field.field_info.extra.get("category"):
if category not in cls.argparse_groups: if category not in cls.argparse_groups:
cls.argparse_groups[category] = command_parser.add_argument_group(category) cls.argparse_groups[category] = command_parser.add_argument_group(category)
@ -316,10 +336,10 @@ class InvokeAISettings(BaseSettings):
argparse_group.add_argument( argparse_group.add_argument(
f"--{name}", f"--{name}",
dest=name, dest=name,
nargs='*', nargs="*",
type=field.type_, type=field.type_,
default=default, default=default,
action=argparse.BooleanOptionalAction if field.type_==bool else 'store', action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
help=field.field_info.description, help=field.field_info.description,
) )
else: else:
@ -328,27 +348,31 @@ class InvokeAISettings(BaseSettings):
dest=name, dest=name,
type=field.type_, type=field.type_,
default=default, default=default,
action=argparse.BooleanOptionalAction if field.type_==bool else 'store', action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
help=field.field_info.description, help=field.field_info.description,
) )
def _find_root() -> Path: def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".") venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"): if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve() root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
elif any([(venv.parent/x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]): elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
root = (venv.parent).resolve() root = (venv.parent).resolve()
else: else:
root = Path("~/invokeai").expanduser().resolve() root = Path("~/invokeai").expanduser().resolve()
return root return root
class InvokeAIAppConfig(InvokeAISettings): class InvokeAIAppConfig(InvokeAISettings):
''' """
Generate images using Stable Diffusion. Use "invokeai" to launch Generate images using Stable Diffusion. Use "invokeai" to launch
the command-line client (recommended for experts only), or the command-line client (recommended for experts only), or
"invokeai-web" to launch the web server. Global options "invokeai-web" to launch the web server. Global options
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>. setting environment variables INVOKEAI_<setting>.
''' """
singleton_config: ClassVar[InvokeAIAppConfig] = None singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None singleton_init: ClassVar[Dict] = None
@ -364,7 +388,6 @@ setting environment variables INVOKEAI_<setting>.
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features') esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features') internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features') log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features') patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED') restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
@ -374,6 +397,7 @@ setting environment variables INVOKEAI_<setting>.
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance') max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance') max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED') gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance') precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
@ -400,16 +424,17 @@ setting environment variables INVOKEAI_<setting>.
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging") log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
# fmt: on # fmt: on
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False): def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
''' """
Update settings with contents of init file, environment, and Update settings with contents of init file, environment, and
command-line settings. command-line settings.
:param conf: alternate Omegaconf dictionary object :param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list :param argv: aternate sys.argv list
:param clobber: ovewrite any initialization parameters passed during initialization :param clobber: ovewrite any initialization parameters passed during initialization
''' """
# Set the runtime root directory. We parse command-line switches here # Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option. # in order to pick up the --root_dir option.
super().parse_args(argv) super().parse_args(argv)
@ -430,31 +455,38 @@ setting environment variables INVOKEAI_<setting>.
@classmethod @classmethod
def get_config(cls, **kwargs) -> InvokeAIAppConfig: def get_config(cls, **kwargs) -> InvokeAIAppConfig:
''' """
This returns a singleton InvokeAIAppConfig configuration object. This returns a singleton InvokeAIAppConfig configuration object.
''' """
if cls.singleton_config is None \ if (
or type(cls.singleton_config)!=cls \ cls.singleton_config is None
or (kwargs and cls.singleton_init != kwargs): or type(cls.singleton_config) != cls
or (kwargs and cls.singleton_init != kwargs)
):
cls.singleton_config = cls(**kwargs) cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs cls.singleton_init = kwargs
return cls.singleton_config return cls.singleton_config
@property @property
def root_path(self) -> Path: def root_path(self) -> Path:
''' """
Path to the runtime root directory Path to the runtime root directory
''' """
if self.root: # we cache value of root to protect against it being '.' and the cwd changing
return Path(self.root).expanduser().absolute() if self.cached_root:
root = self.cached_root
elif self.root:
root = Path(self.root).expanduser().absolute()
else: else:
return self.find_root() root = self.find_root()
self.cached_root = root
return self.cached_root
@property @property
def root_dir(self) -> Path: def root_dir(self) -> Path:
''' """
Alias for above. Alias for above.
''' """
return self.root_path return self.root_path
def _resolve(self, partial_path: Path) -> Path: def _resolve(self, partial_path: Path) -> Path:
@ -462,58 +494,58 @@ setting environment variables INVOKEAI_<setting>.
@property @property
def init_file_path(self) -> Path: def init_file_path(self) -> Path:
''' """
Path to invokeai.yaml Path to invokeai.yaml
''' """
return self._resolve(INIT_FILE) return self._resolve(INIT_FILE)
@property @property
def output_path(self) -> Path: def output_path(self) -> Path:
''' """
Path to defaults outputs directory. Path to defaults outputs directory.
''' """
return self._resolve(self.outdir) return self._resolve(self.outdir)
@property @property
def db_path(self) -> Path: def db_path(self) -> Path:
''' """
Path to the invokeai.db file. Path to the invokeai.db file.
''' """
return self._resolve(self.db_dir) / DB_FILE return self._resolve(self.db_dir) / DB_FILE
@property @property
def model_conf_path(self) -> Path: def model_conf_path(self) -> Path:
''' """
Path to models configuration file. Path to models configuration file.
''' """
return self._resolve(self.conf_path) return self._resolve(self.conf_path)
@property @property
def legacy_conf_path(self) -> Path: def legacy_conf_path(self) -> Path:
''' """
Path to directory of legacy configuration files (e.g. v1-inference.yaml) Path to directory of legacy configuration files (e.g. v1-inference.yaml)
''' """
return self._resolve(self.legacy_conf_dir) return self._resolve(self.legacy_conf_dir)
@property @property
def models_path(self) -> Path: def models_path(self) -> Path:
''' """
Path to the models directory Path to the models directory
''' """
return self._resolve(self.models_dir) return self._resolve(self.models_dir)
@property @property
def autoconvert_path(self) -> Path: def autoconvert_path(self) -> Path:
''' """
Path to the directory containing models to be imported automatically at startup. Path to the directory containing models to be imported automatically at startup.
''' """
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
# the following methods support legacy calls leftover from the Globals era # the following methods support legacy calls leftover from the Globals era
@property @property
def full_precision(self) -> bool: def full_precision(self) -> bool:
"""Return true if precision set to float32""" """Return true if precision set to float32"""
return self.precision=='float32' return self.precision == "float32"
@property @property
def disable_xformers(self) -> bool: def disable_xformers(self) -> bool:
@ -525,26 +557,38 @@ setting environment variables INVOKEAI_<setting>.
"""Return true if patchmatch true""" """Return true if patchmatch true"""
return self.patchmatch return self.patchmatch
@property
def nsfw_checker(self) -> bool:
"""NSFW node is always active and disabled from Web UIe"""
return True
@property
def invisible_watermark(self) -> bool:
"""invisible watermark node is always active and disabled from Web UIe"""
return True
@staticmethod @staticmethod
def find_root() -> Path: def find_root() -> Path:
''' """
Choose the runtime root directory when not specified on command line or Choose the runtime root directory when not specified on command line or
init file. init file.
''' """
return _find_root() return _find_root()
class PagingArgumentParser(argparse.ArgumentParser): class PagingArgumentParser(argparse.ArgumentParser):
''' """
A custom ArgumentParser that uses pydoc to page its output. A custom ArgumentParser that uses pydoc to page its output.
It also supports reading defaults from an init file. It also supports reading defaults from an init file.
''' """
def print_help(self, file=None): def print_help(self, file=None):
text = self.format_help() text = self.format_help()
pydoc.pager(text) pydoc.pager(text)
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig: def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
''' """
Legacy function which returns InvokeAIAppConfig.get_config() Legacy function which returns InvokeAIAppConfig.get_config()
''' """
return InvokeAIAppConfig.get_config(**kwargs) return InvokeAIAppConfig.get_config(**kwargs)

View File

@ -1,4 +1,5 @@
from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
from ..invocations.image import ImageNSFWBlurInvocation
from ..invocations.noise import NoiseInvocation from ..invocations.noise import NoiseInvocation
from ..invocations.compel import CompelInvocation from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation from ..invocations.params import ParamIntInvocation
@ -6,45 +7,70 @@ from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Gr
from .item_storage import ItemStorageABC from .item_storage import ItemStorageABC
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74' default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
def create_text_to_image() -> LibraryGraph: def create_text_to_image() -> LibraryGraph:
return LibraryGraph( return LibraryGraph(
id=default_text_to_image_graph_id, id=default_text_to_image_graph_id,
name='t2i', name="t2i",
description='Converts text to an image', description="Converts text to an image",
graph=Graph( graph=Graph(
nodes={ nodes={
'width': ParamIntInvocation(id='width', a=512), "width": ParamIntInvocation(id="width", a=512),
'height': ParamIntInvocation(id='height', a=512), "height": ParamIntInvocation(id="height", a=512),
'seed': ParamIntInvocation(id='seed', a=-1), "seed": ParamIntInvocation(id="seed", a=-1),
'3': NoiseInvocation(id='3'), "3": NoiseInvocation(id="3"),
'4': CompelInvocation(id='4'), "4": CompelInvocation(id="4"),
'5': CompelInvocation(id='5'), "5": CompelInvocation(id="5"),
'6': TextToLatentsInvocation(id='6'), "6": TextToLatentsInvocation(id="6"),
'7': LatentsToImageInvocation(id='7'), "7": LatentsToImageInvocation(id="7"),
"8": ImageNSFWBlurInvocation(id="8"),
}, },
edges=[ edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')), Edge(
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')), source=EdgeConnection(node_id="width", field="a"),
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')), destination=EdgeConnection(node_id="3", field="width"),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')), ),
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')), Edge(
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')), source=EdgeConnection(node_id="height", field="a"),
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')), destination=EdgeConnection(node_id="3", field="height"),
] ),
Edge(
source=EdgeConnection(node_id="seed", field="a"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
source=EdgeConnection(node_id="3", field="noise"),
destination=EdgeConnection(node_id="6", field="noise"),
),
Edge(
source=EdgeConnection(node_id="6", field="latents"),
destination=EdgeConnection(node_id="7", field="latents"),
),
Edge(
source=EdgeConnection(node_id="4", field="conditioning"),
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
),
Edge(
source=EdgeConnection(node_id="5", field="conditioning"),
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
),
Edge(
source=EdgeConnection(node_id="7", field="image"),
destination=EdgeConnection(node_id="8", field="image"),
),
],
), ),
exposed_inputs=[ exposed_inputs=[
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'), ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'), ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
ExposedNodeInput(node_path='width', field='a', alias='width'), ExposedNodeInput(node_path="width", field="a", alias="width"),
ExposedNodeInput(node_path='height', field='a', alias='height'), ExposedNodeInput(node_path="height", field="a", alias="height"),
ExposedNodeInput(node_path='seed', field='a', alias='seed'), ExposedNodeInput(node_path="seed", field="a", alias="seed"),
], ],
exposed_outputs=[ exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
ExposedNodeOutput(node_path='7', field='image', alias='image') )
])
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]: def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:

View File

@ -44,9 +44,7 @@ class EventServiceBase:
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
progress_image=progress_image.dict() progress_image=progress_image.dict() if progress_image is not None else None,
if progress_image is not None
else None,
step=step, step=step,
total_steps=total_steps, total_steps=total_steps,
), ),
@ -90,9 +88,7 @@ class EventServiceBase:
), ),
) )
def emit_invocation_started( def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
self, graph_execution_state_id: str, node: dict, source_node_id: str
) -> None:
"""Emitted when an invocation has started""" """Emitted when an invocation has started"""
self.__emit_session_event( self.__emit_session_event(
event_name="invocation_started", event_name="invocation_started",

View File

@ -28,6 +28,7 @@ from ..invocations.baseinvocation import (
# in 3.10 this would be "from types import NoneType" # in 3.10 this would be "from types import NoneType"
NoneType = type(None) NoneType = type(None)
class EdgeConnection(BaseModel): class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection") node_id: str = Field(description="The id of the node for this edge connection")
field: str = Field(description="The field for this connection") field: str = Field(description="The field for this connection")
@ -61,6 +62,7 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None node_input_field = node_inputs.get(field) or None
return node_input_field return node_input_field
def is_union_subtype(t1, t2): def is_union_subtype(t1, t2):
t1_args = get_args(t1) t1_args = get_args(t1)
t2_args = get_args(t2) t2_args = get_args(t2)
@ -71,6 +73,7 @@ def is_union_subtype(t1, t2):
# t1 is a Union, check that all of its types are in t2_args # t1 is a Union, check that all of its types are in t2_args
return all(arg in t2_args for arg in t1_args) return all(arg in t2_args for arg in t1_args)
def is_list_or_contains_list(t): def is_list_or_contains_list(t):
t_args = get_args(t) t_args = get_args(t)
@ -154,15 +157,17 @@ class GraphInvocationOutput(BaseInvocationOutput):
class Config: class Config:
schema_extra = { schema_extra = {
'required': [ "required": [
'type', "type",
'image', "image",
] ]
} }
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation): class GraphInvocation(BaseInvocation):
"""Execute a graph""" """Execute a graph"""
type: Literal["graph"] = "graph" type: Literal["graph"] = "graph"
# TODO: figure out how to create a default here # TODO: figure out how to create a default here
@ -182,23 +187,21 @@ class IterateInvocationOutput(BaseInvocationOutput):
class Config: class Config:
schema_extra = { schema_extra = {
'required': [ "required": [
'type', "type",
'item', "item",
] ]
} }
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation): class IterateInvocation(BaseInvocation):
"""Iterates over a list of items""" """Iterates over a list of items"""
type: Literal["iterate"] = "iterate" type: Literal["iterate"] = "iterate"
collection: list[Any] = Field( collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
description="The list of items to iterate over", default_factory=list index: int = Field(description="The index, will be provided on executed iterators", default=0)
)
index: int = Field(
description="The index, will be provided on executed iterators", default=0
)
def invoke(self, context: InvocationContext) -> IterateInvocationOutput: def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values""" """Produces the outputs as values"""
@ -212,12 +215,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
class Config: class Config:
schema_extra = { schema_extra = {
'required': [ "required": [
'type', "type",
'collection', "collection",
] ]
} }
class CollectInvocation(BaseInvocation): class CollectInvocation(BaseInvocation):
"""Collects values into a collection""" """Collects values into a collection"""
@ -279,9 +283,7 @@ class Graph(BaseModel):
if node_path in self.nodes: if node_path in self.nodes:
return (self, node_path) return (self, node_path)
node_id = ( node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node_path if "." not in node_path else node_path[: node_path.index(".")]
)
if node_id not in self.nodes: if node_id not in self.nodes:
raise NodeNotFoundError(f"Node {node_path} not found in graph") raise NodeNotFoundError(f"Node {node_path} not found in graph")
@ -343,9 +345,7 @@ class Graph(BaseModel):
return False return False
# Validate all edges reference nodes in the graph # Validate all edges reference nodes in the graph
node_ids = set( node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
)
if not all((self.has_node(node_id) for node_id in node_ids)): if not all((self.has_node(node_id) for node_id in node_ids)):
return False return False
@ -371,22 +371,14 @@ class Graph(BaseModel):
# Validate all iterators # Validate all iterators
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available # TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
if not all( if not all(
( (self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
self._is_iterator_connection_valid(n.id)
for n in self.nodes.values()
if isinstance(n, IterateInvocation)
)
): ):
return False return False
# Validate all collectors # Validate all collectors
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available # TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
if not all( if not all(
( (self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
self._is_collector_connection_valid(n.id)
for n in self.nodes.values()
if isinstance(n, CollectInvocation)
)
): ):
return False return False
@ -405,48 +397,51 @@ class Graph(BaseModel):
# Validate that an edge to this node+field doesn't already exist # Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists') raise InvalidEdgeError(
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
)
# Validate that no cycles would be created # Validate that no cycles would be created
g = self.nx_graph_flat() g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id) g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g): if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}') raise InvalidEdgeError(
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
)
# Validate that the field types are compatible # Validate that the field types are compatible
if not are_connections_compatible( if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
from_node, edge.source.field, to_node, edge.destination.field raise InvalidEdgeError(
): f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') )
# Validate if iterator output type matches iterator input type (if this edge results in both being set) # Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
edge.destination.node_id, new_input=edge.source raise InvalidEdgeError(
): f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') )
# Validate if iterator input type matches output type (if this edge results in both being set) # Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item": if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
edge.source.node_id, new_output=edge.destination raise InvalidEdgeError(
): f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') )
# Validate if collector input type matches output type (if this edge results in both being set) # Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
edge.destination.node_id, new_input=edge.source raise InvalidEdgeError(
): f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') )
# Validate if collector output type matches input type (if this edge results in both being set) # Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection": if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
edge.source.node_id, new_output=edge.destination raise InvalidEdgeError(
): f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') )
def has_node(self, node_path: str) -> bool: def has_node(self, node_path: str) -> bool:
"""Determines whether or not a node exists in the graph.""" """Determines whether or not a node exists in the graph."""
@ -475,17 +470,13 @@ class Graph(BaseModel):
# Ensure the node type matches the new node # Ensure the node type matches the new node
if type(node) != type(new_node): if type(node) != type(new_node):
raise TypeError( raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}"
)
# Ensure the new id is either the same or is not in the graph # Ensure the new id is either the same or is not in the graph
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")] prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
new_path = self._get_node_path(new_node.id, prefix=prefix) new_path = self._get_node_path(new_node.id, prefix=prefix)
if new_node.id != node.id and self.has_node(new_path): if new_node.id != node.id and self.has_node(new_path):
raise NodeAlreadyInGraphError( raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
"Node with id {new_node.id} already exists in graph"
)
# Set the new node in the graph # Set the new node in the graph
graph.nodes[new_node.id] = new_node graph.nodes[new_node.id] = new_node
@ -507,9 +498,7 @@ class Graph(BaseModel):
graph.add_edge( graph.add_edge(
Edge( Edge(
source=edge.source, source=edge.source,
destination=EdgeConnection( destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
node_id=new_graph_node_path, field=edge.destination.field
)
) )
) )
@ -522,16 +511,12 @@ class Graph(BaseModel):
) )
graph.add_edge( graph.add_edge(
Edge( Edge(
source=EdgeConnection( source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
node_id=new_graph_node_path, field=edge.source.field destination=edge.destination,
),
destination=edge.destination
) )
) )
def _get_input_edges( def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
self, node_path: str, field: Optional[str] = None
) -> list[Edge]:
"""Gets all input edges for a node""" """Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path) edges = self._get_input_edges_and_graphs(node_path)
@ -548,7 +533,7 @@ class Graph(BaseModel):
destination=EdgeConnection( destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix), node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field, field=e.destination.field,
) ),
) )
for _, prefix, e in filtered_edges for _, prefix, e in filtered_edges
] ]
@ -560,32 +545,20 @@ class Graph(BaseModel):
edges = list() edges = list()
# Return any input edges that appear in this graph # Return any input edges that appear in this graph
edges.extend( edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
)
node_id = ( node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node_path if "." not in node_path else node_path[: node_path.index(".")]
)
node = self.nodes[node_id] node = self.nodes[node_id]
if isinstance(node, GraphInvocation): if isinstance(node, GraphInvocation):
graph = node.graph graph = node.graph
graph_path = ( graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
node.id graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
if prefix is None or prefix == ""
else self._get_node_path(node.id, prefix=prefix)
)
graph_edges = graph._get_input_edges_and_graphs(
node_path[(len(node_id) + 1) :], prefix=graph_path
)
edges.extend(graph_edges) edges.extend(graph_edges)
return edges return edges
def _get_output_edges( def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
self, node_path: str, field: str
) -> list[Edge]:
"""Gets all output edges for a node""" """Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path) edges = self._get_output_edges_and_graphs(node_path)
@ -602,7 +575,7 @@ class Graph(BaseModel):
destination=EdgeConnection( destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix), node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field, field=e.destination.field,
) ),
) )
for _, prefix, e in filtered_edges for _, prefix, e in filtered_edges
] ]
@ -614,25 +587,15 @@ class Graph(BaseModel):
edges = list() edges = list()
# Return any input edges that appear in this graph # Return any input edges that appear in this graph
edges.extend( edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
)
node_id = ( node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node_path if "." not in node_path else node_path[: node_path.index(".")]
)
node = self.nodes[node_id] node = self.nodes[node_id]
if isinstance(node, GraphInvocation): if isinstance(node, GraphInvocation):
graph = node.graph graph = node.graph
graph_path = ( graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
node.id graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
if prefix is None or prefix == ""
else self._get_node_path(node.id, prefix=prefix)
)
graph_edges = graph._get_output_edges_and_graphs(
node_path[(len(node_id) + 1) :], prefix=graph_path
)
edges.extend(graph_edges) edges.extend(graph_edges)
return edges return edges
@ -656,12 +619,8 @@ class Graph(BaseModel):
return False return False
# Get input and output fields (the fields linked to the iterator's input/output) # Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field( input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
self.get_node(inputs[0].node_id), inputs[0].field output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
)
output_fields = list(
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
)
# Input type must be a list # Input type must be a list
if get_origin(input_field) != list: if get_origin(input_field) != list:
@ -669,12 +628,7 @@ class Graph(BaseModel):
# Validate that all outputs match the input type # Validate that all outputs match the input type
input_field_item_type = get_args(input_field)[0] input_field_item_type = get_args(input_field)[0]
if not all( if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
(
are_connection_types_compatible(input_field_item_type, f)
for f in output_fields
)
):
return False return False
return True return True
@ -694,35 +648,21 @@ class Graph(BaseModel):
outputs.append(new_output) outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output) # Get input and output fields (the fields linked to the iterator's input/output)
input_fields = list( input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
[get_output_field(self.get_node(e.node_id), e.field) for e in inputs] output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
)
output_fields = list(
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
)
# Validate that all inputs are derived from or match a single type # Validate that all inputs are derived from or match a single type
input_field_types = set( input_field_types = set(
[ [
t t
for input_field in input_fields for input_field in input_fields
for t in ( for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
[input_field]
if get_origin(input_field) == None
else get_args(input_field)
)
if t != NoneType if t != NoneType
] ]
) # Get unique types ) # Get unique types
type_tree = nx.DiGraph() type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types) type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from( type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
[
e
for e in itertools.permutations(input_field_types, 2)
if issubclass(e[1], e[0])
]
)
type_degrees = type_tree.in_degree(type_tree.nodes) type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
return False # There is more than one root type return False # There is more than one root type
@ -739,9 +679,7 @@ class Graph(BaseModel):
return False return False
# Verify that all outputs match the input type (are a base class or the same class) # Verify that all outputs match the input type (are a base class or the same class)
if not all( if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
(issubclass(input_root_type, get_args(f)[0]) for f in output_fields)
):
return False return False
return True return True
@ -761,9 +699,7 @@ class Graph(BaseModel):
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges])) g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g return g
def nx_graph_flat( def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
) -> nx.DiGraph:
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)""" """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
g = nx_graph or nx.DiGraph() g = nx_graph or nx.DiGraph()
@ -772,26 +708,18 @@ class Graph(BaseModel):
[ [
self._get_node_path(n.id, prefix) self._get_node_path(n.id, prefix)
for n in self.nodes.values() for n in self.nodes.values()
if not isinstance(n, GraphInvocation) if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
and not isinstance(n, IterateInvocation)
] ]
) )
# Expand graph nodes # Expand graph nodes
for sgn in ( for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
):
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded # TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges]) unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
g.add_edges_from( g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
[
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
for e in unique_edges
]
)
return g return g
@ -810,9 +738,7 @@ class GraphExecutionState(BaseModel):
) )
# Nodes that have been executed # Nodes that have been executed
executed: set[str] = Field( executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
description="The set of node ids that have been executed", default_factory=set
)
executed_history: list[str] = Field( executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution", description="The list of node ids that have been executed, in order of execution",
default_factory=list, default_factory=list,
@ -821,14 +747,12 @@ class GraphExecutionState(BaseModel):
batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list) batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
# The results of executed nodes # The results of executed nodes
results: dict[ results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
str, Annotated[InvocationOutputsUnion, Field(discriminator="type")] description="The results of node executions", default_factory=dict
] = Field(description="The results of node executions", default_factory=dict) )
# Errors raised when executing nodes # Errors raised when executing nodes
errors: dict[str, str] = Field( errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
description="Errors raised when executing nodes", default_factory=dict
)
# Map of prepared/executed nodes to their original nodes # Map of prepared/executed nodes to their original nodes
prepared_source_mapping: dict[str, str] = Field( prepared_source_mapping: dict[str, str] = Field(
@ -844,16 +768,16 @@ class GraphExecutionState(BaseModel):
class Config: class Config:
schema_extra = { schema_extra = {
'required': [ "required": [
'id', "id",
'graph', "graph",
'execution_graph', "execution_graph",
'executed', "executed",
'executed_history', "executed_history",
'results', "results",
'errors', "errors",
'prepared_source_mapping', "prepared_source_mapping",
'source_prepared_mapping', "source_prepared_mapping",
] ]
} }
@ -912,9 +836,7 @@ class GraphExecutionState(BaseModel):
"""Returns true if the graph has any errors""" """Returns true if the graph has any errors"""
return len(self.errors) > 0 return len(self.errors) > 0
def _create_execution_node( def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
self, node_path: str, iteration_node_map: list[tuple[str, str]]
) -> list[str]:
"""Prepares an iteration node and connects all edges, returning the new node id""" """Prepares an iteration node and connects all edges, returning the new node id"""
node = self.graph.get_node(node_path) node = self.graph.get_node(node_path)
@ -924,20 +846,12 @@ class GraphExecutionState(BaseModel):
# If this is an iterator node, we must create a copy for each iteration # If this is an iterator node, we must create a copy for each iteration
if isinstance(node, IterateInvocation): if isinstance(node, IterateInvocation):
# Get input collection edge (should error if there are no inputs) # Get input collection edge (should error if there are no inputs)
input_collection_edge = next( input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
iter(self.graph._get_input_edges(node_path, "collection"))
)
input_collection_prepared_node_id = next( input_collection_prepared_node_id = next(
n[1] n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
for n in iteration_node_map
if n[0] == input_collection_edge.source.node_id
)
input_collection_prepared_node_output = self.results[
input_collection_prepared_node_id
]
input_collection = getattr(
input_collection_prepared_node_output, input_collection_edge.source.field
) )
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
self_iteration_count = len(input_collection) self_iteration_count = len(input_collection)
new_nodes = list() new_nodes = list()
@ -952,9 +866,7 @@ class GraphExecutionState(BaseModel):
# For collect nodes, this may contain multiple inputs to the same field # For collect nodes, this may contain multiple inputs to the same field
new_edges = list() new_edges = list()
for edge in input_edges: for edge in input_edges:
for input_node_id in ( for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
):
new_edge = Edge( new_edge = Edge(
source=EdgeConnection(node_id=input_node_id, field=edge.source.field), source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
destination=EdgeConnection(node_id="", field=edge.destination.field), destination=EdgeConnection(node_id="", field=edge.destination.field),
@ -995,11 +907,7 @@ class GraphExecutionState(BaseModel):
def _iterator_graph(self) -> nx.DiGraph: def _iterator_graph(self) -> nx.DiGraph:
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
g = self.graph.nx_graph_flat() g = self.graph.nx_graph_flat()
collectors = ( collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
n
for n in self.graph.nodes
if isinstance(self.graph.get_node(n), CollectInvocation)
)
for c in collectors: for c in collectors:
g.remove_edges_from(list(g.in_edges(c))) g.remove_edges_from(list(g.in_edges(c)))
return g return g
@ -1007,11 +915,7 @@ class GraphExecutionState(BaseModel):
def _get_node_iterators(self, node_id: str) -> list[str]: def _get_node_iterators(self, node_id: str) -> list[str]:
"""Gets iterators for a node""" """Gets iterators for a node"""
g = self._iterator_graph() g = self._iterator_graph()
iterators = [ iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
n
for n in nx.ancestors(g, node_id)
if isinstance(self.graph.get_node(n), IterateInvocation)
]
return iterators return iterators
def _apply_batch_config(self): def _apply_batch_config(self):
@ -1071,29 +975,18 @@ class GraphExecutionState(BaseModel):
if isinstance(next_node, CollectInvocation): if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation # Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list( all_iteration_mappings = list(
itertools.chain( itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
*(
((s, p) for p in self.source_prepared_mapping[s])
for s in next_node_parents
)
)
) )
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings))) # all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
create_results = self._create_execution_node( create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
next_node_id, all_iteration_mappings
)
if create_results is not None: if create_results is not None:
new_node_ids.extend(create_results) new_node_ids.extend(create_results)
else: # Iterators or normal nodes else: # Iterators or normal nodes
# Get all iterator combinations for this node # Get all iterator combinations for this node
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated # Will produce a list of lists of prepared iterator nodes, from which results can be iterated
iterator_nodes = self._get_node_iterators(next_node_id) iterator_nodes = self._get_node_iterators(next_node_id)
iterator_nodes_prepared = [ iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
list(self.source_prepared_mapping[n]) for n in iterator_nodes iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
]
iterator_node_prepared_combinations = list(
itertools.product(*iterator_nodes_prepared)
)
# Select the correct prepared parents for each iteration # Select the correct prepared parents for each iteration
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
@ -1122,31 +1015,16 @@ class GraphExecutionState(BaseModel):
return next(iter(prepared_nodes)) return next(iter(prepared_nodes))
# Check if the requested node is an iterator # Check if the requested node is an iterator
prepared_iterator = next( prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
(n for n in prepared_nodes if n in prepared_iterator_nodes), None
)
if prepared_iterator is not None: if prepared_iterator is not None:
return prepared_iterator return prepared_iterator
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [ iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
]
parent_iterators = [
itn
for itn in iterator_source_node_mapping
if nx.has_path(graph, itn[1], source_node_path)
]
return next( return next(
( (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
n
for n in prepared_nodes
if all(
nx.has_path(execution_graph, pit[0], n)
for pit in parent_iterators
)
),
None, None,
) )
@ -1247,15 +1125,18 @@ class ExposedNodeOutput(BaseModel):
field: str = Field(description="The field name of the output") field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output") alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel): class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4) id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph") graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph") name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph") description: str = Field(description="The description of the graph")
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list) exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list) exposed_outputs: list[ExposedNodeOutput] = Field(
description="The outputs exposed by this graph", default_factory=list
)
@validator('exposed_inputs', 'exposed_outputs') @validator("exposed_inputs", "exposed_outputs")
def validate_exposed_aliases(cls, v): def validate_exposed_aliases(cls, v):
if len(v) != len(set(i.alias for i in v)): if len(v) != len(set(i.alias for i in v)):
raise ValueError("Duplicate exposed alias") raise ValueError("Duplicate exposed alias")
@ -1263,23 +1144,27 @@ class LibraryGraph(BaseModel):
@root_validator @root_validator
def validate_exposed_nodes(cls, values): def validate_exposed_nodes(cls, values):
graph = values['graph'] graph = values["graph"]
# Validate exposed inputs # Validate exposed inputs
for exposed_input in values['exposed_inputs']: for exposed_input in values["exposed_inputs"]:
if not graph.has_node(exposed_input.node_path): if not graph.has_node(exposed_input.node_path):
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist") raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
node = graph.get_node(exposed_input.node_path) node = graph.get_node(exposed_input.node_path)
if get_input_field(node, exposed_input.field) is None: if get_input_field(node, exposed_input.field) is None:
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}") raise ValueError(
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
)
# Validate exposed outputs # Validate exposed outputs
for exposed_output in values['exposed_outputs']: for exposed_output in values["exposed_outputs"]:
if not graph.has_node(exposed_output.node_path): if not graph.has_node(exposed_output.node_path):
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist") raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
node = graph.get_node(exposed_output.node_path) node = graph.get_node(exposed_output.node_path)
if get_output_field(node, exposed_output.field) is None: if get_output_field(node, exposed_output.field) is None:
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}") raise ValueError(
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
)
return values return values

View File

@ -85,9 +85,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
self.__cache_ids = Queue() self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config self.__max_cache_size = 10 # TODO: get this from config
self.__output_folder: Path = ( self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
output_folder if isinstance(output_folder, Path) else Path(output_folder)
)
self.__thumbnails_folder = self.__output_folder / "thumbnails" self.__thumbnails_folder = self.__output_folder / "thumbnails"
# Validate required output folders at launch # Validate required output folders at launch
@ -183,9 +181,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
def __set_cache(self, image_name: Path, image: PILImageType): def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache: if not image_name in self.__cache:
self.__cache[image_name] = image self.__cache[image_name] = image
self.__cache_ids.put( self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size: if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get() cache_id = self.__cache_ids.get()
if cache_id in self.__cache: if cache_id in self.__cache:

View File

@ -426,9 +426,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally: finally:
self._lock.release() self._lock.release()
return OffsetPaginatedResults( return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
items=images, offset=offset, limit=limit, total=count
)
def delete(self, image_name: str) -> None: def delete(self, image_name: str) -> None:
try: try:
@ -466,7 +464,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally: finally:
self._lock.release() self._lock.release()
def delete_intermediates(self) -> list[str]: def delete_intermediates(self) -> list[str]:
try: try:
self._lock.acquire() self._lock.acquire()
@ -505,9 +502,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
is_intermediate: bool = False, is_intermediate: bool = False,
) -> datetime: ) -> datetime:
try: try:
metadata_json = ( metadata_json = None if metadata is None else json.dumps(metadata)
None if metadata is None else json.dumps(metadata)
)
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql """--sql

View File

@ -216,16 +216,9 @@ class ImageService(ImageServiceABC):
metadata=metadata, metadata=metadata,
session_id=session_id, session_id=session_id,
) )
if board_id is not None: if board_id is not None:
self._services.board_image_records.add_image_to_board( self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
board_id=board_id, image_name=image_name self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
)
self._services.image_files.save(
image_name=image_name, image=image, metadata=metadata, graph=graph
)
image_dto = self.get_dto(image_name) image_dto = self.get_dto(image_name)
return image_dto return image_dto
@ -236,7 +229,7 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Failed to save image file") self._services.logger.error("Failed to save image file")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem saving image record and file") self._services.logger.error(f"Problem saving image record and file: {str(e)}")
raise e raise e
def update( def update(
@ -300,9 +293,7 @@ class ImageService(ImageServiceABC):
if not image_record.session_id: if not image_record.session_id:
return ImageMetadata() return ImageMetadata()
session_raw = self._services.graph_execution_manager.get_raw( session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
image_record.session_id
)
graph = None graph = None
if session_raw: if session_raw:
@ -367,9 +358,7 @@ class ImageService(ImageServiceABC):
r, r,
self._services.urls.get_image_url(r.image_name), self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True), self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image( self._services.board_image_records.get_board_for_image(r.image_name),
r.image_name
),
), ),
results.items, results.items,
) )
@ -401,11 +390,7 @@ class ImageService(ImageServiceABC):
def delete_images_on_board(self, board_id: str): def delete_images_on_board(self, board_id: str):
try: try:
image_names = ( image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
self._services.board_image_records.get_all_board_image_names_for_board(
board_id
)
)
for image_name in image_names: for image_name in image_names:
self._services.image_files.delete(image_name) self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names) self._services.image_records.delete_many(image_names)

View File

@ -7,6 +7,7 @@ from queue import Queue
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
class InvocationQueueItem(BaseModel): class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state") graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked") invocation_id: str = Field(description="The ID of the node being invoked")
@ -45,9 +46,11 @@ class MemoryInvocationQueue(InvocationQueueABC):
def get(self) -> InvocationQueueItem: def get(self) -> InvocationQueueItem:
item = self.__queue.get() item = self.__queue.get()
while isinstance(item, InvocationQueueItem) \ while (
and item.graph_execution_state_id in self.__cancellations \ isinstance(item, InvocationQueueItem)
and self.__cancellations[item.graph_execution_state_id] > item.timestamp: and item.graph_execution_state_id in self.__cancellations
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
):
item = self.__queue.get() item = self.__queue.get()
# Clear old items # Clear old items

View File

@ -7,6 +7,7 @@ from .graph import Graph, GraphExecutionState
from .invocation_queue import InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invocation_services import InvocationServices from .invocation_services import InvocationServices
class Invoker: class Invoker:
"""The invoker, used to execute invocations""" """The invoker, used to execute invocations"""
@ -16,9 +17,7 @@ class Invoker:
self.services = services self.services = services
self._start() self._start()
def invoke( def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed. """Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
# Get the next invocation # Get the next invocation

View File

@ -9,6 +9,7 @@ T = TypeVar("T", bound=BaseModel)
class PaginatedResults(GenericModel, Generic[T]): class PaginatedResults(GenericModel, Generic[T]):
"""Paginated results""" """Paginated results"""
# fmt: off # fmt: off
items: list[T] = Field(description="Items") items: list[T] = Field(description="Items")
page: int = Field(description="Current Page") page: int = Field(description="Current Page")
@ -17,6 +18,7 @@ class PaginatedResults(GenericModel, Generic[T]):
total: int = Field(description="Total number of items in result") total: int = Field(description="Total number of items in result")
# fmt: on # fmt: on
class ItemStorageABC(ABC, Generic[T]): class ItemStorageABC(ABC, Generic[T]):
_on_changed_callbacks: list[Callable[[T], None]] _on_changed_callbacks: list[Callable[[T], None]]
_on_deleted_callbacks: list[Callable[[str], None]] _on_deleted_callbacks: list[Callable[[str], None]]
@ -48,9 +50,7 @@ class ItemStorageABC(ABC, Generic[T]):
pass pass
@abstractmethod @abstractmethod
def search( def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
pass pass
def on_changed(self, on_changed: Callable[[T], None]) -> None: def on_changed(self, on_changed: Callable[[T], None]) -> None:

View File

@ -7,6 +7,7 @@ from typing import Dict, Union, Optional
import torch import torch
class LatentsStorageBase(ABC): class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents.""" """Responsible for storing and retrieving latents."""
@ -88,7 +89,5 @@ class DiskLatentsStorage(LatentsStorageBase):
latent_path = self.get_path(name) latent_path = self.get_path(name)
latent_path.unlink() latent_path.unlink()
def get_path(self, name: str) -> Path: def get_path(self, name: str) -> Path:
return self.__output_folder / name return self.__output_folder / name

View File

@ -125,7 +125,7 @@ class ModelManagerServiceBase(ABC):
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False clobber: bool = False,
) -> AddModelResult: ) -> AddModelResult:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
@ -169,7 +169,8 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def rename_model(self, def rename_model(
self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -181,9 +182,7 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def list_checkpoint_configs( def list_checkpoint_configs(self) -> List[Path]:
self
)->List[Path]:
""" """
List the checkpoint config paths from ROOT/configs/stable-diffusion. List the checkpoint config paths from ROOT/configs/stable-diffusion.
""" """
@ -211,11 +210,12 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def heuristic_import(self, def heuristic_import(
self,
items_to_import: set[str], items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]: ) -> dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of """Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported. :param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
@ -230,19 +230,23 @@ class ModelManagerServiceBase(ABC):
The result is a set of successfully installed models. Each element The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model. that model.
''' """
pass pass
@abstractmethod @abstractmethod
def merge_models( def merge_models(
self, self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"), model_names: List[str] = Field(
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"), default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"), merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5, alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None, interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False, force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None merge_dest_directory: Optional[Path] = None,
) -> AddModelResult: ) -> AddModelResult:
""" """
Merge two to three diffusrs pipeline models and save as a new model. Merge two to three diffusrs pipeline models and save as a new model.
@ -280,9 +284,11 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
# simple implementation # simple implementation
class ModelManagerService(ModelManagerServiceBase): class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory""" """Responsible for managing models on disk and in memory"""
def __init__( def __init__(
self, self,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
@ -299,16 +305,16 @@ class ModelManagerService(ModelManagerServiceBase):
else: else:
config_file = config.root_dir / "configs/models.yaml" config_file = config.root_dir / "configs/models.yaml"
logger.debug(f'Config file={config_file}') logger.debug(f"Config file={config_file}")
device = torch.device(choose_torch_device()) device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device==torch.device('cuda') else '' device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
logger.info(f'GPU device = {device} {device_name}') logger.info(f"GPU device = {device} {device_name}")
precision = config.precision precision = config.precision
if precision == "auto": if precision == "auto":
precision = choose_precision(device) precision = choose_precision(device)
dtype = torch.float32 if precision == 'float32' else torch.float16 dtype = torch.float32 if precision == "float32" else torch.float16
# this is transitional backward compatibility # this is transitional backward compatibility
# support for the deprecated `max_loaded_models` # support for the deprecated `max_loaded_models`
@ -316,9 +322,7 @@ class ModelManagerService(ModelManagerServiceBase):
# cache size is set to 2.5 GB times # cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise # the number of max_loaded_models. Otherwise
# use new `max_cache_size` config setting # use new `max_cache_size` config setting
max_cache_size = config.max_cache_size \ max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
if hasattr(config,'max_cache_size') \
else config.max_loaded_models * 2.5
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
@ -332,7 +336,7 @@ class ModelManagerService(ModelManagerServiceBase):
sequential_offload=sequential_offload, sequential_offload=sequential_offload,
logger=logger, logger=logger,
) )
logger.info('Model manager service initialized') logger.info("Model manager service initialized")
def get_model( def get_model(
self, self,
@ -371,7 +375,7 @@ class ModelManagerService(ModelManagerServiceBase):
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info model_info=model_info,
) )
return model_info return model_info
@ -405,9 +409,7 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.model_names() return self.mgr.model_names()
def list_models( def list_models(
self, self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> list[dict]: ) -> list[dict]:
""" """
Return a list of models. Return a list of models.
@ -418,9 +420,7 @@ class ModelManagerService(ModelManagerServiceBase):
""" """
Return information about the model using the same format as list_models() Return information about the model using the same format as list_models()
""" """
return self.mgr.list_model(model_name=model_name, return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
base_model=base_model,
model_type=model_type)
def add_model( def add_model(
self, self,
@ -437,7 +437,7 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
self.logger.debug(f'add/update model {model_name}') self.logger.debug(f"add/update model {model_name}")
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def update_model( def update_model(
@ -454,7 +454,7 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
self.logger.debug(f'update model {model_name}') self.logger.debug(f"update model {model_name}")
if not self.model_exists(model_name, base_model, model_type): if not self.model_exists(model_name, base_model, model_type):
raise ModelNotFoundException(f"Unknown model {model_name}") raise ModelNotFoundException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
@ -470,7 +470,7 @@ class ModelManagerService(ModelManagerServiceBase):
then the underlying weight file or diffusers directory will be deleted then the underlying weight file or diffusers directory will be deleted
as well. as well.
""" """
self.logger.debug(f'delete model {model_name}') self.logger.debug(f"delete model {model_name}")
self.mgr.del_model(model_name, base_model, model_type) self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit() self.mgr.commit()
@ -479,7 +479,9 @@ class ModelManagerService(ModelManagerServiceBase):
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae], model_type: Union[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"), convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult: ) -> AddModelResult:
""" """
Convert a checkpoint file into a diffusers folder, deleting the cached Convert a checkpoint file into a diffusers folder, deleting the cached
@ -494,7 +496,7 @@ class ModelManagerService(ModelManagerServiceBase):
also raise a ValueError in the event that there is a similarly-named diffusers also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place. directory already in place.
""" """
self.logger.debug(f'convert model {model_name}') self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def commit(self, conf_file: Optional[Path] = None): def commit(self, conf_file: Optional[Path] = None):
@ -524,7 +526,7 @@ class ModelManagerService(ModelManagerServiceBase):
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info model_info=model_info,
) )
else: else:
context.services.events.emit_model_load_started( context.services.events.emit_model_load_started(
@ -535,16 +537,16 @@ class ModelManagerService(ModelManagerServiceBase):
submodel=submodel, submodel=submodel,
) )
@property @property
def logger(self): def logger(self):
return self.mgr.logger return self.mgr.logger
def heuristic_import(self, def heuristic_import(
self,
items_to_import: set[str], items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]: ) -> dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of """Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported. :param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
@ -559,18 +561,24 @@ class ModelManagerService(ModelManagerServiceBase):
The result is a set of successfully installed models. Each element The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model. that model.
''' """
return self.mgr.heuristic_import(items_to_import, prediction_type_helper) return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
def merge_models( def merge_models(
self, self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"), model_names: List[str] = Field(
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"), default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"), merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5, alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None, interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False, force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"), merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult: ) -> AddModelResult:
""" """
Merge two to three diffusrs pipeline models and save as a new model. Merge two to three diffusrs pipeline models and save as a new model.
@ -618,9 +626,10 @@ class ModelManagerService(ModelManagerServiceBase):
config = self.mgr.app_config config = self.mgr.app_config
conf_path = config.legacy_conf_path conf_path = config.legacy_conf_path
root_path = config.root_path root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')] return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
def rename_model(self, def rename_model(
self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -635,10 +644,10 @@ class ModelManagerService(ModelManagerServiceBase):
:param new_name: New name for the model :param new_name: New name for the model
:param new_base: New base for the model :param new_base: New base for the model
""" """
self.mgr.rename_model(base_model = base_model, self.mgr.rename_model(
base_model=base_model,
model_type=model_type, model_type=model_type,
model_name=model_name, model_name=model_name,
new_name=new_name, new_name=new_name,
new_base=new_base, new_base=new_base,
) )

View File

@ -11,30 +11,20 @@ class BoardRecord(BaseModel):
"""The unique ID of the board.""" """The unique ID of the board."""
board_name: str = Field(description="The name of the board.") board_name: str = Field(description="The name of the board.")
"""The name of the board.""" """The name of the board."""
created_at: Union[datetime, str] = Field( created_at: Union[datetime, str] = Field(description="The created timestamp of the board.")
description="The created timestamp of the board."
)
"""The created timestamp of the image.""" """The created timestamp of the image."""
updated_at: Union[datetime, str] = Field( updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
description="The updated timestamp of the board."
)
"""The updated timestamp of the image.""" """The updated timestamp of the image."""
deleted_at: Union[datetime, str, None] = Field( deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
description="The deleted timestamp of the board."
)
"""The updated timestamp of the image.""" """The updated timestamp of the image."""
cover_image_name: Optional[str] = Field( cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
description="The name of the cover image of the board."
)
"""The name of the cover image of the board.""" """The name of the cover image of the board."""
class BoardDTO(BoardRecord): class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count.""" """Deserialized board record with cover image URL and image count."""
cover_image_name: Optional[str] = Field( cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
description="The name of the board's cover image."
)
"""The URL of the thumbnail of the most recent image in the board.""" """The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.") image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board.""" """The number of images in the board."""

View File

@ -20,17 +20,11 @@ class ImageRecord(BaseModel):
"""The actual width of the image in px. This may be different from the width in metadata.""" """The actual width of the image in px. This may be different from the width in metadata."""
height: int = Field(description="The height of the image in px.") height: int = Field(description="The height of the image in px.")
"""The actual height of the image in px. This may be different from the height in metadata.""" """The actual height of the image in px. This may be different from the height in metadata."""
created_at: Union[datetime.datetime, str] = Field( created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the image.")
description="The created timestamp of the image."
)
"""The created timestamp of the image.""" """The created timestamp of the image."""
updated_at: Union[datetime.datetime, str] = Field( updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
description="The updated timestamp of the image."
)
"""The updated timestamp of the image.""" """The updated timestamp of the image."""
deleted_at: Union[datetime.datetime, str, None] = Field( deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
description="The deleted timestamp of the image."
)
"""The deleted timestamp of the image.""" """The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.") is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image.""" """Whether this is an intermediate image."""
@ -55,18 +49,14 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
- `is_intermediate`: change the image's `is_intermediate` flag - `is_intermediate`: change the image's `is_intermediate` flag
""" """
image_category: Optional[ImageCategory] = Field( image_category: Optional[ImageCategory] = Field(description="The image's new category.")
description="The image's new category."
)
"""The image's new category.""" """The image's new category."""
session_id: Optional[StrictStr] = Field( session_id: Optional[StrictStr] = Field(
default=None, default=None,
description="The image's new session ID.", description="The image's new session ID.",
) )
"""The image's new session ID.""" """The image's new session ID."""
is_intermediate: Optional[StrictBool] = Field( is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
default=None, description="The image's new `is_intermediate` flag."
)
"""The image's new `is_intermediate` flag.""" """The image's new `is_intermediate` flag."""
@ -84,9 +74,7 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO): class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend.""" """Deserialized image record, enriched for the frontend."""
board_id: Optional[str] = Field( board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists.""" """The id of the board the image belongs to, if one exists."""
pass pass
@ -110,12 +98,8 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# TODO: do we really need to handle default values here? ideally the data is the correct shape... # TODO: do we really need to handle default values here? ideally the data is the correct shape...
image_name = image_dict.get("image_name", "unknown") image_name = image_dict.get("image_name", "unknown")
image_origin = ResourceOrigin( image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value) image_category = ImageCategory(image_dict.get("image_category", ImageCategory.GENERAL.value))
)
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)
width = image_dict.get("width", 0) width = image_dict.get("width", 0)
height = image_dict.get("height", 0) height = image_dict.get("height", 0)
session_id = image_dict.get("session_id", None) session_id = image_dict.get("session_id", None)

View File

@ -9,6 +9,8 @@ from ..models.exceptions import CanceledException
from .graph import GraphExecutionState from .graph import GraphExecutionState
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
class DefaultInvocationProcessor(InvocationProcessorABC): class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread __invoker_thread: Thread
__stop_event: Event __stop_event: Event
@ -25,9 +27,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
target=self.__process, target=self.__process,
kwargs=dict(stop_event=self.__stop_event), kwargs=dict(stop_event=self.__stop_event),
) )
self.__invoker_thread.daemon = ( self.__invoker_thread.daemon = True # TODO: make async and do not use threads
True # TODO: make async and do not use threads
)
self.__invoker_thread.start() self.__invoker_thread.start()
def stop(self, *args, **kwargs) -> None: def stop(self, *args, **kwargs) -> None:
@ -48,11 +48,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
continue continue
try: try:
graph_execution_state = ( graph_execution_state = self.__invoker.services.graph_execution_manager.get(
self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id queue_item.graph_execution_state_id
) )
)
except Exception as e: except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error( self.__invoker.services.events.emit_session_retrieval_error(
@ -63,9 +61,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
continue continue
try: try:
invocation = graph_execution_state.execution_graph.get_node( invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
queue_item.invocation_id
)
except Exception as e: except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error( self.__invoker.services.events.emit_invocation_retrieval_error(
@ -82,7 +78,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker.services.events.emit_invocation_started( self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(), node=invocation.dict(),
source_node_id=source_node_id source_node_id=source_node_id,
) )
# Invoke # Invoke
@ -95,18 +91,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
) )
# Check queue to see if this is canceled, and skip if so # Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled( if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
graph_execution_state.id
):
continue continue
# Save outputs and history # Save outputs and history
graph_execution_state.complete(invocation.id, outputs) graph_execution_state.complete(invocation.id, outputs)
# Save the state changes # Save the state changes
self.__invoker.services.graph_execution_manager.set( self.__invoker.services.graph_execution_manager.set(graph_execution_state)
graph_execution_state
)
# Send complete event # Send complete event
self.__invoker.services.events.emit_invocation_complete( self.__invoker.services.events.emit_invocation_complete(
@ -130,9 +122,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_execution_state.set_node_error(invocation.id, error) graph_execution_state.set_node_error(invocation.id, error)
# Save the state changes # Save the state changes
self.__invoker.services.graph_execution_manager.set( self.__invoker.services.graph_execution_manager.set(graph_execution_state)
graph_execution_state
)
self.__invoker.services.logger.error("Error while invoking:\n%s" % e) self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event # Send error event
@ -147,9 +137,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
pass pass
# Check queue to see if this is canceled, and skip if so # Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled( if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
graph_execution_state.id
):
continue continue
# Queue any further commands if invoking all # Queue any further commands if invoking all
@ -164,7 +152,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
node=invocation.dict(), node=invocation.dict(),
source_node_id=source_node_id, source_node_id=source_node_id,
error_type=e.__class__.__name__, error_type=e.__class__.__name__,
error=traceback.format_exc() error=traceback.format_exc(),
) )
elif queue_item.invoke_all and sum(graph_execution_state.batch_indices) > 0: elif queue_item.invoke_all and sum(graph_execution_state.batch_indices) > 0:
batch_indicies = graph_execution_state.batch_indices.copy() batch_indicies = graph_execution_state.batch_indices.copy()
@ -176,9 +164,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker.services.graph_execution_manager.set(new_ges) self.__invoker.services.graph_execution_manager.set(new_ges)
self.__invoker.invoke(new_ges, invoke_all=True) self.__invoker.invoke(new_ges, invoke_all=True)
elif is_complete: elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete( self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
graph_execution_state.id
)
except KeyboardInterrupt: except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor

View File

@ -66,9 +66,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def get(self, id: str) -> Optional[T]: def get(self, id: str) -> Optional[T]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone() result = self._cursor.fetchone()
finally: finally:
self._lock.release() self._lock.release()
@ -81,9 +79,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def get_raw(self, id: str) -> Optional[str]: def get_raw(self, id: str) -> Optional[str]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone() result = self._cursor.fetchone()
finally: finally:
self._lock.release() self._lock.release()
@ -96,9 +92,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def delete(self, id: str): def delete(self, id: str):
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit() self._conn.commit()
finally: finally:
self._lock.release() self._lock.release()
@ -122,13 +116,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1 pageCount = int(count / per_page) + 1
return PaginatedResults[T]( return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
def search( def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -149,6 +139,4 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1 pageCount = int(count / per_page) + 1
return PaginatedResults[T]( return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)

View File

@ -17,16 +17,8 @@ from controlnet_aux.util import HWC3, resize_image
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet. # If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
lvmin_kernels_raw = [ lvmin_kernels_raw = [
np.array([ np.array([[-1, -1, -1], [0, 1, 0], [1, 1, 1]], dtype=np.int32),
[-1, -1, -1], np.array([[0, -1, -1], [1, 1, -1], [0, 1, 0]], dtype=np.int32),
[0, 1, 0],
[1, 1, 1]
], dtype=np.int32),
np.array([
[0, -1, -1],
[1, 1, -1],
[0, 1, 0]
], dtype=np.int32)
] ]
lvmin_kernels = [] lvmin_kernels = []
@ -36,16 +28,8 @@ lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw] lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_prunings_raw = [ lvmin_prunings_raw = [
np.array([ np.array([[-1, -1, -1], [-1, 1, -1], [0, 0, -1]], dtype=np.int32),
[-1, -1, -1], np.array([[-1, -1, -1], [-1, 1, -1], [-1, 0, 0]], dtype=np.int32),
[-1, 1, -1],
[0, 0, -1]
], dtype=np.int32),
np.array([
[-1, -1, -1],
[-1, 1, -1],
[-1, 0, 0]
], dtype=np.int32)
] ]
lvmin_prunings = [] lvmin_prunings = []
@ -154,13 +138,7 @@ def pixel_perfect_resolution(
# modified for InvokeAI # modified for InvokeAI
########################################################################### ###########################################################################
# def detectmap_proc(detected_map, module, resize_mode, h, w): # def detectmap_proc(detected_map, module, resize_mode, h, w):
def np_img_resize( def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device: torch.device = torch.device("cpu")):
np_img: np.ndarray,
resize_mode: str,
h: int,
w: int,
device: torch.device = torch.device('cpu')
):
# if 'inpaint' in module: # if 'inpaint' in module:
# np_img = np_img.astype(np.float32) # np_img = np_img.astype(np.float32)
# else: # else:
@ -184,15 +162,14 @@ def np_img_resize(
# below is very boring but do not change these. If you change these Apple or Mac may fail. # below is very boring but do not change these. If you change these Apple or Mac may fail.
y = torch.from_numpy(y) y = torch.from_numpy(y)
y = y.float() / 255.0 y = y.float() / 255.0
y = rearrange(y, 'h w c -> 1 c h w') y = rearrange(y, "h w c -> 1 c h w")
y = y.clone() y = y.clone()
# y = y.to(devices.get_device_for("controlnet")) # y = y.to(devices.get_device_for("controlnet"))
y = y.to(device) y = y.to(device)
y = y.clone() y = y.clone()
return y return y
def high_quality_resize(x: np.ndarray, def high_quality_resize(x: np.ndarray, size):
size):
# Written by lvmin # Written by lvmin
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
inpaint_mask = None inpaint_mask = None
@ -284,6 +261,7 @@ def np_img_resize(
np_img = safe_numpy(np_img) np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img return get_pytorch_control(np_img), np_img
def prepare_control_image( def prepare_control_image(
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]] # image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
# but now should be able to assume that image is a single PIL.Image, which simplifies things # but now should be able to assume that image is a single PIL.Image, which simplifies things
@ -301,15 +279,17 @@ def prepare_control_image(
resize_mode="just_resize_simple", resize_mode="just_resize_simple",
): ):
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out # FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
if (resize_mode == "just_resize_simple" or if (
resize_mode == "crop_resize_simple" or resize_mode == "just_resize_simple"
resize_mode == "fill_resize_simple"): or resize_mode == "crop_resize_simple"
or resize_mode == "fill_resize_simple"
):
image = image.convert("RGB") image = image.convert("RGB")
if (resize_mode == "just_resize_simple"): if resize_mode == "just_resize_simple":
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif (resize_mode == "crop_resize_simple"): # not yet implemented elif resize_mode == "crop_resize_simple": # not yet implemented
pass pass
elif (resize_mode == "fill_resize_simple"): # not yet implemented elif resize_mode == "fill_resize_simple": # not yet implemented
pass pass
nimage = np.array(image) nimage = np.array(image)
nimage = nimage[None, :] nimage = nimage[None, :]
@ -320,7 +300,7 @@ def prepare_control_image(
timage = torch.from_numpy(nimage) timage = torch.from_numpy(nimage)
# use fancy lvmin controlnet resizing # use fancy lvmin controlnet resizing
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"): elif resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize":
nimage = np.array(image) nimage = np.array(image)
timage, nimage = np_img_resize( timage, nimage = np_img_resize(
np_img=nimage, np_img=nimage,
@ -336,7 +316,7 @@ def prepare_control_image(
exit(1) exit(1)
timage = timage.to(device=device, dtype=dtype) timage = timage.to(device=device, dtype=dtype)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced") cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
if do_classifier_free_guidance and not cfg_injection: if do_classifier_free_guidance and not cfg_injection:
timage = torch.cat([timage] * 2) timage = torch.cat([timage] * 2)
return timage return timage

View File

@ -18,10 +18,7 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0) latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
latents_ubyte = ( latents_ubyte = (
((latent_image + 1) / 2) ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()
).cpu() ).cpu()
return Image.fromarray(latents_ubyte.numpy()) return Image.fromarray(latents_ubyte.numpy())
@ -92,6 +89,7 @@ def stable_diffusion_step_callback(
total_steps=node["steps"], total_steps=node["steps"],
) )
def stable_diffusion_xl_step_callback( def stable_diffusion_xl_step_callback(
context: InvocationContext, context: InvocationContext,
node: dict, node: dict,

View File

@ -1,15 +1,6 @@
""" """
Initialization file for invokeai.backend Initialization file for invokeai.backend
""" """
from .generator import ( from .generator import InvokeAIGeneratorBasicParams, InvokeAIGenerator, InvokeAIGeneratorOutput, Img2Img, Inpaint
InvokeAIGeneratorBasicParams, from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
InvokeAIGenerator, from .model_management.models import SilenceWarnings
InvokeAIGeneratorOutput,
Img2Img,
Inpaint
)
from .model_management import (
ModelManager, ModelCache, BaseModelType,
ModelType, SubModelType, ModelInfo
)
from .safety_checker import SafetyChecker

View File

@ -28,12 +28,12 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from ..image_util import configure_model_padding from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP from ..stable_diffusion.schedulers import SCHEDULER_MAP
downsampling = 8 downsampling = 8
@dataclass @dataclass
class InvokeAIGeneratorBasicParams: class InvokeAIGeneratorBasicParams:
seed: Optional[int] = None seed: Optional[int] = None
@ -42,36 +42,39 @@ class InvokeAIGeneratorBasicParams:
cfg_scale: float = 7.5 cfg_scale: float = 7.5
steps: int = 20 steps: int = 20
ddim_eta: float = 0.0 ddim_eta: float = 0.0
scheduler: str='ddim' scheduler: str = "ddim"
precision: str='float16' precision: str = "float16"
perlin: float = 0.0 perlin: float = 0.0
threshold: float = 0.0 threshold: float = 0.0
seamless: bool = False seamless: bool = False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y']) seamless_axes: List[str] = field(default_factory=lambda: ["x", "y"])
h_symmetry_time_pct: Optional[float] = None h_symmetry_time_pct: Optional[float] = None
v_symmetry_time_pct: Optional[float] = None v_symmetry_time_pct: Optional[float] = None
variation_amount: float = 0.0 variation_amount: float = 0.0
with_variations: list = field(default_factory=list) with_variations: list = field(default_factory=list)
safety_checker: Optional[SafetyChecker]=None
@dataclass @dataclass
class InvokeAIGeneratorOutput: class InvokeAIGeneratorOutput:
''' """
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
operation, including the image, its seed, the model name used to generate the image operation, including the image, its seed, the model name used to generate the image
and the model hash, as well as all the generate() parameters that went into and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes) generating the image (in .params, also available as attributes)
''' """
image: Image.Image image: Image.Image
seed: int seed: int
model_hash: str model_hash: str
attention_maps_images: List[Image.Image] attention_maps_images: List[Image.Image]
params: Namespace params: Namespace
# we are interposing a wrapper around the original Generator classes so that # we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work. # old code that calls Generate will continue to work.
class InvokeAIGenerator(metaclass=ABCMeta): class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self, def __init__(
self,
model_info: dict, model_info: dict,
params: InvokeAIGeneratorBasicParams = InvokeAIGeneratorBasicParams(), params: InvokeAIGeneratorBasicParams = InvokeAIGeneratorBasicParams(),
**kwargs, **kwargs,
@ -89,7 +92,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
iterations: int = 1, iterations: int = 1,
**keyword_args, **keyword_args,
) -> Iterator[InvokeAIGeneratorOutput]: ) -> Iterator[InvokeAIGeneratorOutput]:
''' """
Return an iterator across the indicated number of generations. Return an iterator across the indicated number of generations.
Each time the iterator is called it will return an InvokeAIGeneratorOutput Each time the iterator is called it will return an InvokeAIGeneratorOutput
object. Use like this: object. Use like this:
@ -109,7 +112,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
for o in outputs: for o in outputs:
print(o.image, o.seed) print(o.image, o.seed)
''' """
generator_args = dataclasses.asdict(self.params) generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args) generator_args.update(keyword_args)
@ -120,21 +123,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
gen_class = self._generator_class() gen_class = self._generator_class()
generator = gen_class(model, self.params.precision, **self.kwargs) generator = gen_class(model, self.params.precision, **self.kwargs)
if self.params.variation_amount > 0: if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'), generator.set_variation(
generator_args.get('variation_amount'), generator_args.get("seed"),
generator_args.get('with_variations') generator_args.get("variation_amount"),
generator_args.get("with_variations"),
) )
if isinstance(model, DiffusionPipeline): if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]: for component in [model.unet, model.vae]:
configure_model_padding(component, configure_model_padding(
generator_args.get('seamless',False), component, generator_args.get("seamless", False), generator_args.get("seamless_axes")
generator_args.get('seamless_axes')
) )
else: else:
configure_model_padding(model, configure_model_padding(
generator_args.get('seamless',False), model, generator_args.get("seamless", False), generator_args.get("seamless_axes")
generator_args.get('seamless_axes')
) )
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1) iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
@ -158,9 +160,9 @@ class InvokeAIGenerator(metaclass=ABCMeta):
@classmethod @classmethod
def schedulers(self) -> List[str]: def schedulers(self) -> List[str]:
''' """
Return list of all the schedulers that we currently handle. Return list of all the schedulers that we currently handle.
''' """
return list(SCHEDULER_MAP.keys()) return list(SCHEDULER_MAP.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]): def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
@ -168,33 +170,33 @@ class InvokeAIGenerator(metaclass=ABCMeta):
@classmethod @classmethod
def _generator_class(cls) -> Type[Generator]: def _generator_class(cls) -> Type[Generator]:
''' """
In derived classes return the name of the generator to apply. In derived classes return the name of the generator to apply.
If you don't override will return the name of the derived If you don't override will return the name of the derived
class, which nicely parallels the generator class names. class, which nicely parallels the generator class names.
''' """
return Generator return Generator
# ------------------------------------ # ------------------------------------
class Img2Img(InvokeAIGenerator): class Img2Img(InvokeAIGenerator):
def generate(self, def generate(
init_image: Union[Image.Image, torch.FloatTensor], self, init_image: Union[Image.Image, torch.FloatTensor], strength: float = 0.75, **keyword_args
strength: float=0.75,
**keyword_args
) -> Iterator[InvokeAIGeneratorOutput]: ) -> Iterator[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image, return super().generate(init_image=init_image, strength=strength, **keyword_args)
strength=strength,
**keyword_args
)
@classmethod @classmethod
def _generator_class(cls): def _generator_class(cls):
from .img2img import Img2Img from .img2img import Img2Img
return Img2Img return Img2Img
# ------------------------------------ # ------------------------------------
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img): class Inpaint(Img2Img):
def generate(self, def generate(
self,
mask_image: Union[Image.Image, torch.FloatTensor], mask_image: Union[Image.Image, torch.FloatTensor],
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam
seam_size: int = 96, seam_size: int = 96,
@ -207,7 +209,7 @@ class Inpaint(Img2Img):
inpaint_width=None, inpaint_width=None,
inpaint_height=None, inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args **keyword_args,
) -> Iterator[InvokeAIGeneratorOutput]: ) -> Iterator[InvokeAIGeneratorOutput]:
return super().generate( return super().generate(
mask_image=mask_image, mask_image=mask_image,
@ -221,13 +223,16 @@ class Inpaint(Img2Img):
inpaint_width=inpaint_width, inpaint_width=inpaint_width,
inpaint_height=inpaint_height, inpaint_height=inpaint_height,
inpaint_fill=inpaint_fill, inpaint_fill=inpaint_fill,
**keyword_args **keyword_args,
) )
@classmethod @classmethod
def _generator_class(cls): def _generator_class(cls):
from .inpaint import Inpaint from .inpaint import Inpaint
return Inpaint return Inpaint
class Generator: class Generator:
downsampling_factor: int downsampling_factor: int
latent_channels: int latent_channels: int
@ -240,7 +245,6 @@ class Generator:
self.seed = None self.seed = None
self.latent_channels = model.unet.config.in_channels self.latent_channels = model.unet.config.in_channels
self.downsampling_factor = downsampling # BUG: should come from model or config self.downsampling_factor = downsampling # BUG: should come from model or config
self.safety_checker = None
self.perlin = 0.0 self.perlin = 0.0
self.threshold = 0 self.threshold = 0
self.variation_amount = 0 self.variation_amount = 0
@ -254,9 +258,7 @@ class Generator:
Returns a function returning an image derived from the prompt and the initial image Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it Return value depends on the seed at the time you call it
""" """
raise NotImplementedError( raise NotImplementedError("image_iterator() must be implemented in a descendent class")
"image_iterator() must be implemented in a descendent class"
)
def set_variation(self, seed, variation_amount, with_variations): def set_variation(self, seed, variation_amount, with_variations):
self.seed = seed self.seed = seed
@ -277,17 +279,13 @@ class Generator:
perlin=0.0, perlin=0.0,
h_symmetry_time_pct=None, h_symmetry_time_pct=None,
v_symmetry_time_pct=None, v_symmetry_time_pct=None,
safety_checker: SafetyChecker=None,
free_gpu_mem: bool = False, free_gpu_mem: bool = False,
**kwargs, **kwargs,
): ):
scope = nullcontext scope = nullcontext
self.safety_checker = safety_checker
self.free_gpu_mem = free_gpu_mem self.free_gpu_mem = free_gpu_mem
attention_maps_images = [] attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append( attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
saver.get_stacked_maps_image()
)
make_image = self.get_make_image( make_image = self.get_make_image(
sampler=sampler, sampler=sampler,
init_image=init_image, init_image=init_image,
@ -329,17 +327,10 @@ class Generator:
# Pass on the seed in case a layer beneath us needs to generate noise on its own. # Pass on the seed in case a layer beneath us needs to generate noise on its own.
image = make_image(x_T, seed) image = make_image(x_T, seed)
if self.safety_checker is not None:
image = self.safety_checker.check(image)
results.append([image, seed, attention_maps_images]) results.append([image, seed, attention_maps_images])
if image_callback is not None: if image_callback is not None:
attention_maps_image = ( attention_maps_image = None if len(attention_maps_images) == 0 else attention_maps_images[-1]
None
if len(attention_maps_images) == 0
else attention_maps_images[-1]
)
image_callback( image_callback(
image, image,
seed, seed,
@ -350,9 +341,7 @@ class Generator:
seed = self.new_seed() seed = self.new_seed()
# Free up memory from the last generation. # Free up memory from the last generation.
clear_cuda_cache = ( clear_cuda_cache = kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
)
if clear_cuda_cache is not None: if clear_cuda_cache is not None:
clear_cuda_cache() clear_cuda_cache()
@ -379,14 +368,8 @@ class Generator:
# Get the original alpha channel of the mask if there is one. # Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB') # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = ( pil_init_mask = init_mask.getchannel("A") if init_mask.mode == "RGBA" else init_mask.convert("L")
init_mask.getchannel("A") pil_init_image = init_image.convert("RGBA") # Add an alpha channel if one doesn't exist
if init_mask.mode == "RGBA"
else init_mask.convert("L")
)
pil_init_image = init_image.convert(
"RGBA"
) # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching. # Build an image with only visible pixels from source to use as reference for color-matching.
init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8) init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
@ -412,10 +395,7 @@ class Generator:
np_matched_result[:, :, :] = ( np_matched_result[:, :, :] = (
( (
( (
( (np_matched_result[:, :, :].astype(np.float32) - gen_means[None, None, :])
np_matched_result[:, :, :].astype(np.float32)
- gen_means[None, None, :]
)
/ gen_std[None, None, :] / gen_std[None, None, :]
) )
* init_std[None, None, :] * init_std[None, None, :]
@ -441,9 +421,7 @@ class Generator:
else: else:
blurred_init_mask = pil_init_mask blurred_init_mask = pil_init_mask
multiplied_blurred_init_mask = ImageChops.multiply( multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
blurred_init_mask, self.pil_image.split()[-1]
)
# Paste original on color-corrected generation (using blurred mask) # Paste original on color-corrected generation (using blurred mask)
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
@ -469,10 +447,7 @@ class Generator:
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
latents_ubyte = ( latents_ubyte = (
((latent_image + 1) / 2) ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()
).cpu() ).cpu()
return Image.fromarray(latents_ubyte.numpy()) return Image.fromarray(latents_ubyte.numpy())
@ -502,9 +477,7 @@ class Generator:
temp_height = int((height + 7) / 8) * 8 temp_height = int((height + 7) / 8) * 8
noise = torch.stack( noise = torch.stack(
[ [
rand_perlin_2d( rand_perlin_2d((temp_height, temp_width), (8, 8), device=self.model.device).to(fixdevice)
(temp_height, temp_width), (8, 8), device=self.model.device
).to(fixdevice)
for _ in range(input_channels) for _ in range(input_channels)
], ],
dim=0, dim=0,
@ -581,8 +554,6 @@ class Generator:
device=device, device=device,
) )
if self.perlin > 0.0: if self.perlin > 0.0:
perlin_noise = self.get_perlin_noise( perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
width // self.downsampling_factor, height // self.downsampling_factor
)
x = (1 - self.perlin) * x + self.perlin * perlin_noise x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x return x

View File

@ -77,10 +77,7 @@ class Img2Img(Generator):
callback=step_callback, callback=step_callback,
seed=seed, seed=seed,
) )
if ( if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
attention_maps_callback(pipeline_output.attention_map_saver) attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0] return pipeline.numpy_to_pil(pipeline_output.images)[0]
@ -91,7 +88,5 @@ class Img2Img(Generator):
x = torch.randn_like(like, device=device) x = torch.randn_like(like, device=device)
if self.perlin > 0.0: if self.perlin > 0.0:
shape = like.shape shape = like.shape
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise( x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
shape[3], shape[2]
)
return x return x

View File

@ -68,15 +68,11 @@ class Inpaint(Img2Img):
return im return im
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though) # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = PatchMatch.inpaint( im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
)
im_patched = Image.fromarray(im_patched_np, mode="RGB") im_patched = Image.fromarray(im_patched_np, mode="RGB")
return im_patched return im_patched
def tile_fill_missing( def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
) -> Image.Image:
# Only fill if there's an alpha layer # Only fill if there's an alpha layer
if im.mode != "RGBA": if im.mode != "RGBA":
return im return im
@ -127,15 +123,11 @@ class Inpaint(Img2Img):
return si return si
def mask_edge( def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
self, mask: Image.Image, edge_size: int, edge_blur: int
) -> Image.Image:
npimg = np.asarray(mask, dtype=np.uint8) npimg = np.asarray(mask, dtype=np.uint8)
# Detect any partially transparent regions # Detect any partially transparent regions
npgradient = np.uint8( npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))
)
# Detect hard edges # Detect hard edges
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200) npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
@ -144,9 +136,7 @@ class Inpaint(Img2Img):
npmask = npgradient + npedge npmask = npgradient + npedge
# Expand # Expand
npmask = cv2.dilate( npmask = cv2.dilate(npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2))
npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2)
)
new_mask = Image.fromarray(npmask) new_mask = Image.fromarray(npmask)
@ -242,25 +232,19 @@ class Inpaint(Img2Img):
if infill_method == "patchmatch" and PatchMatch.patchmatch_available(): if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
init_filled = self.infill_patchmatch(self.pil_image.copy()) init_filled = self.infill_patchmatch(self.pil_image.copy())
elif infill_method == "tile": elif infill_method == "tile":
init_filled = self.tile_fill_missing( init_filled = self.tile_fill_missing(self.pil_image.copy(), seed=self.seed, tile_size=tile_size)
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
)
elif infill_method == "solid": elif infill_method == "solid":
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill) solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
init_filled = Image.alpha_composite(solid_bg, init_image) init_filled = Image.alpha_composite(solid_bg, init_image)
else: else:
raise ValueError( raise ValueError(f"Non-supported infill type {infill_method}", infill_method)
f"Non-supported infill type {infill_method}", infill_method
)
init_filled.paste(init_image, (0, 0), init_image.split()[-1]) init_filled.paste(init_image, (0, 0), init_image.split()[-1])
# Resize if requested for inpainting # Resize if requested for inpainting
if inpaint_width and inpaint_height: if inpaint_width and inpaint_height:
init_filled = init_filled.resize((inpaint_width, inpaint_height)) init_filled = init_filled.resize((inpaint_width, inpaint_height))
debug_image( debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
init_filled, "init_filled", debug_status=self.enable_image_debugging
)
# Create init tensor # Create init tensor
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB")) init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
@ -289,9 +273,7 @@ class Inpaint(Img2Img):
"mask_image AFTER multiply with pil_image", "mask_image AFTER multiply with pil_image",
debug_status=self.enable_image_debugging, debug_status=self.enable_image_debugging,
) )
mask: torch.FloatTensor = image_resized_to_grid_as_tensor( mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_image, normalize=False
)
else: else:
mask: torch.FloatTensor = mask_image mask: torch.FloatTensor = mask_image
@ -302,9 +284,9 @@ class Inpaint(Img2Img):
# todo: support cross-attention control # todo: support cross-attention control
uc, c, _ = conditioning uc, c, _ = conditioning
conditioning_data = ConditioningData( conditioning_data = ConditioningData(uc, c, cfg_scale).add_scheduler_args_if_applicable(
uc, c, cfg_scale pipeline.scheduler, eta=ddim_eta
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) )
def make_image(x_T: torch.Tensor, seed: int): def make_image(x_T: torch.Tensor, seed: int):
pipeline_output = pipeline.inpaint_from_embeddings( pipeline_output = pipeline.inpaint_from_embeddings(
@ -318,15 +300,10 @@ class Inpaint(Img2Img):
seed=seed, seed=seed,
) )
if ( if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
attention_maps_callback(pipeline_output.attention_map_saver) attention_maps_callback(pipeline_output.attention_map_saver)
result = self.postprocess_size_and_mask( result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])
pipeline.numpy_to_pil(pipeline_output.images)[0]
)
# Seam paint if this is our first pass (seam_size set to 0 during seam painting) # Seam paint if this is our first pass (seam_size set to 0 during seam painting)
if seam_size > 0: if seam_size > 0:

View File

@ -8,9 +8,7 @@ from .txt2mask import Txt2Mask
from .util import InitImageResizer, make_grid from .util import InitImageResizer, make_grid
def debug_image( def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
):
if not debug_status: if not debug_status:
return return

View File

@ -0,0 +1,34 @@
"""
This module defines a singleton object, "invisible_watermark" that
wraps the invisible watermark model. It respects the global "invisible_watermark"
configuration variable, that allows the watermarking to be supressed.
"""
import numpy as np
import cv2
from PIL import Image
from imwatermark import WatermarkEncoder
from invokeai.app.services.config import InvokeAIAppConfig
import invokeai.backend.util.logging as logger
config = InvokeAIAppConfig.get_config()
class InvisibleWatermark:
"""
Wrapper around InvisibleWatermark module.
"""
@classmethod
def invisible_watermark_available(self) -> bool:
return config.invisible_watermark
@classmethod
def add_watermark(self, image: Image, watermark_text: str) -> Image:
if not self.invisible_watermark_available():
return image
logger.debug(f'Applying invisible watermark "{watermark_text}"')
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
encoder = WatermarkEncoder()
encoder.set_watermark("bytes", watermark_text.encode("utf-8"))
bgr_encoded = encoder.encode(bgr, "dwtDct")
return Image.fromarray(cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)).convert("RGBA")

View File

@ -7,8 +7,10 @@ be suppressed or deferred
import numpy as np import numpy as np
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
class PatchMatch: class PatchMatch:
""" """
Thin class wrapper around the patchmatch function. Thin class wrapper around the patchmatch function.

View File

@ -34,9 +34,7 @@ class PngWriter:
# saves image named _image_ to outdir/name, writing metadata from prompt # saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output # returns full path of output
def save_image_and_prompt_to_png( def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
self, image, dream_prompt, name, metadata=None, compress_level=6
):
path = os.path.join(self.outdir, name) path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo() info = PngImagePlugin.PngInfo()
info.add_text("Dream", dream_prompt) info.add_text("Dream", dream_prompt)
@ -114,8 +112,6 @@ class PromptFormatter:
if opt.variation_amount > 0: if opt.variation_amount > 0:
switches.append(f"-v{opt.variation_amount}") switches.append(f"-v{opt.variation_amount}")
if opt.with_variations: if opt.with_variations:
formatted_variations = ",".join( formatted_variations = ",".join(f"{seed}:{weight}" for seed, weight in opt.with_variations)
f"{seed}:{weight}" for seed, weight in opt.with_variations
)
switches.append(f"-V{formatted_variations}") switches.append(f"-V{formatted_variations}")
return " ".join(switches) return " ".join(switches)

View File

@ -0,0 +1,64 @@
"""
This module defines a singleton object, "safety_checker" that
wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""
import numpy as np
from PIL import Image
from invokeai.backend import SilenceWarnings
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.devices import choose_torch_device
import invokeai.backend.util.logging as logger
config = InvokeAIAppConfig.get_config()
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
class SafetyChecker:
"""
Wrapper around SafetyChecker model.
"""
safety_checker = None
feature_extractor = None
tried_load: bool = False
@classmethod
def _load_safety_checker(self):
if self.tried_load:
return
if config.nsfw_checker:
try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
logger.info("NSFW checker initialized")
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
else:
logger.info("NSFW checker loading disabled")
self.tried_load = True
@classmethod
def safety_checker_available(self) -> bool:
self._load_safety_checker()
return self.safety_checker is not None
@classmethod
def has_nsfw_concept(self, image: Image) -> bool:
if not self.safety_checker_available():
return False
device = choose_torch_device()
features = self.feature_extractor([image], return_tensors="pt")
features.to(device)
self.safety_checker.to(device)
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings():
checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0]

View File

@ -5,12 +5,8 @@ def _conv_forward_asymmetric(self, input, weight, bias):
""" """
Patch for Conv2d._conv_forward that supports asymmetric padding Patch for Conv2d._conv_forward that supports asymmetric padding
""" """
working = nn.functional.pad( working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"] working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
)
working = nn.functional.pad(
working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]
)
return nn.functional.conv2d( return nn.functional.conv2d(
working, working,
weight, weight,
@ -32,18 +28,14 @@ def configure_model_padding(model, seamless, seamless_axes):
if seamless: if seamless:
m.asymmetric_padding_mode = {} m.asymmetric_padding_mode = {}
m.asymmetric_padding = {} m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = ( m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
"circular" if ("x" in seamless_axes) else "constant"
)
m.asymmetric_padding["x"] = ( m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1], m._reversed_padding_repeated_twice[1],
0, 0,
0, 0,
) )
m.asymmetric_padding_mode["y"] = ( m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
"circular" if ("y" in seamless_axes) else "constant"
)
m.asymmetric_padding["y"] = ( m.asymmetric_padding["y"] = (
0, 0,
0, 0,

View File

@ -39,23 +39,18 @@ CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352 CLIPSEG_SIZE = 352
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
class SegmentedGrayscale(object): class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor): def __init__(self, image: Image, heatmap: torch.Tensor):
self.heatmap = heatmap self.heatmap = heatmap
self.image = image self.image = image
def to_grayscale(self, invert: bool = False) -> Image: def to_grayscale(self, invert: bool = False) -> Image:
return self._rescale( return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
Image.fromarray(
np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)
)
)
def to_mask(self, threshold: float = 0.5) -> Image: def to_mask(self, threshold: float = 0.5) -> Image:
discrete_heatmap = self.heatmap.lt(threshold).int() discrete_heatmap = self.heatmap.lt(threshold).int()
return self._rescale( return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))
Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L")
)
def to_transparent(self, invert: bool = False) -> Image: def to_transparent(self, invert: bool = False) -> Image:
transparent_image = self.image.copy() transparent_image = self.image.copy()
@ -67,11 +62,7 @@ class SegmentedGrayscale(object):
# unscales and uncrops the 352x352 heatmap so that it matches the image again # unscales and uncrops the 352x352 heatmap so that it matches the image again
def _rescale(self, heatmap: Image) -> Image: def _rescale(self, heatmap: Image) -> Image:
size = ( size = self.image.width if (self.image.width > self.image.height) else self.image.height
self.image.width
if (self.image.width > self.image.height)
else self.image.height
)
resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS) resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
return resized_image.crop((0, 0, self.image.width, self.image.height)) return resized_image.crop((0, 0, self.image.width, self.image.height))
@ -87,12 +78,8 @@ class Txt2Mask(object):
# BUG: we are not doing anything with the device option at this time # BUG: we are not doing anything with the device option at this time
self.device = device self.device = device
self.processor = AutoProcessor.from_pretrained( self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
CLIPSEG_MODEL, cache_dir=config.cache_dir self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
)
self.model = CLIPSegForImageSegmentation.from_pretrained(
CLIPSEG_MODEL, cache_dir=config.cache_dir
)
@torch.no_grad() @torch.no_grad()
def segment(self, image, prompt: str) -> SegmentedGrayscale: def segment(self, image, prompt: str) -> SegmentedGrayscale:
@ -107,9 +94,7 @@ class Txt2Mask(object):
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
img = self._scale_and_crop(image) img = self._scale_and_crop(image)
inputs = self.processor( inputs = self.processor(text=[prompt], images=[img], padding=True, return_tensors="pt")
text=[prompt], images=[img], padding=True, return_tensors="pt"
)
outputs = self.model(**inputs) outputs = self.model(**inputs)
heatmap = torch.sigmoid(outputs.logits) heatmap = torch.sigmoid(outputs.logits)
return SegmentedGrayscale(image, heatmap) return SegmentedGrayscale(image, heatmap)

View File

@ -0,0 +1,36 @@
"""
Check that the invokeai_root is correctly configured and exit if not.
"""
import sys
from invokeai.app.services.config import (
InvokeAIAppConfig,
)
def check_invokeai_root(config: InvokeAIAppConfig):
try:
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
assert config.models_path.exists(), f"{config.models_path} not found"
for model in [
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
"bert-base-uncased",
"clip-vit-large-patch14",
"sd-vae-ft-mse",
"stable-diffusion-2-clip",
"stable-diffusion-safety-checker",
]:
path = config.models_path / f"core/convert/{model}"
assert path.exists(), f"{path} is missing"
except Exception as e:
print()
print(f"An exception has occurred: {str(e)}")
print("== STARTUP ABORTED ==")
print("** One or more necessary files is missing from your InvokeAI root directory **")
print("** Please rerun the configuration script to fix this problem. **")
print("** From the launcher, selection option [7]. **")
print(
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
)
input("Press any key to continue...")
sys.exit(0)

View File

@ -13,8 +13,8 @@ import os
import shutil import shutil
import textwrap import textwrap
import traceback import traceback
import warnings
import yaml import yaml
import warnings
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from shutil import get_terminal_size from shutil import get_terminal_size
@ -32,6 +32,7 @@ from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
from transformers import ( from transformers import (
CLIPTextModel, CLIPTextModel,
CLIPTextConfig,
CLIPTokenizer, CLIPTokenizer,
AutoFeatureExtractor, AutoFeatureExtractor,
BertTokenizerFast, BertTokenizerFast,
@ -44,6 +45,7 @@ from invokeai.app.services.config import (
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import ( from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress, CenteredButtonPress,
FileBox, FileBox,
IntTitleSlider, IntTitleSlider,
@ -58,9 +60,7 @@ from invokeai.backend.install.model_install_backend import (
InstallSelections, InstallSelections,
ModelInstall, ModelInstall,
) )
from invokeai.backend.model_management.model_probe import ( from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
ModelType, BaseModelType
)
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
@ -75,7 +75,7 @@ Model_dir = "models"
Default_config_file = config.model_conf_path Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path SD_Configs = config.legacy_conf_path
PRECISION_CHOICES = ['auto','float16','float32'] PRECISION_CHOICES = ["auto", "float16", "float32"]
INIT_FILE_PREAMBLE = """# InvokeAI initialization file INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values. # This is the InvokeAI initialization file, which contains command-line default values.
@ -85,6 +85,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
logger = InvokeAILogger.getLogger() logger = InvokeAILogger.getLogger()
# -------------------------------------------- # --------------------------------------------
def postscript(errors: None): def postscript(errors: None):
if not any(errors): if not any(errors):
@ -106,7 +107,9 @@ Add the '--help' argument to see all of the command-line switches available for
""" """
else: else:
message = "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n" message = (
"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
)
for err in errors: for err in errors:
message += f"\t - {err}\n" message += f"\t - {err}\n"
message += "Please check the logs above and correct any issues." message += "Please check the logs above and correct any issues."
@ -167,9 +170,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
logger.info(f"Installing {label} model file {model_url}...") logger.info(f"Installing {label} model file {model_url}...")
if not os.path.exists(model_dest): if not os.path.exists(model_dest):
os.makedirs(os.path.dirname(model_dest), exist_ok=True) os.makedirs(os.path.dirname(model_dest), exist_ok=True)
request.urlretrieve( request.urlretrieve(model_url, model_dest, ProgressBar(os.path.basename(model_dest)))
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
)
logger.info("...downloaded successfully") logger.info("...downloaded successfully")
else: else:
logger.info("...exists") logger.info("...exists")
@ -180,48 +181,58 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
def download_conversion_models(): def download_conversion_models():
target_dir = config.root_path / 'models/core/convert' target_dir = config.models_path / "core/convert"
kwargs = dict() # for future use kwargs = dict() # for future use
try: try:
logger.info('Downloading core tokenizers and text encoders') logger.info("Downloading core tokenizers and text encoders")
# bert # bert
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs) bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True) bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)
# sd-1 # sd-1
repo_id = 'openai/clip-vit-large-patch14' repo_id = "openai/clip-vit-large-patch14"
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14') hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14') hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")
# sd-2 # sd-2
repo_id = "stabilityai/stable-diffusion-2" repo_id = "stabilityai/stable-diffusion-2"
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs) pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True) pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs) pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True) pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)
# sd-xl - tokenizer_2
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
_, model_name = repo_id.split("/")
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
# VAE # VAE
logger.info('Downloading stable diffusion VAE') logger.info("Downloading stable diffusion VAE")
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs) vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True) vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)
# safety checking # safety checking
logger.info('Downloading safety checker') logger.info("Downloading safety checker")
repo_id = "CompVis/stable-diffusion-safety-checker" repo_id = "CompVis/stable-diffusion-safety-checker"
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs) pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True) pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs) pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True) pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
except Exception as e: except Exception as e:
logger.error(str(e)) logger.error(str(e))
# --------------------------------------------- # ---------------------------------------------
def download_realesrgan(): def download_realesrgan():
logger.info("Installing ESRGAN Upscaling models...") logger.info("Installing ESRGAN Upscaling models...")
@ -248,13 +259,15 @@ def download_realesrgan():
), ),
] ]
for model in URLs: for model in URLs:
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description']) download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
# --------------------------------------------- # ---------------------------------------------
def download_support_models(): def download_support_models():
download_realesrgan() download_realesrgan()
download_conversion_models() download_conversion_models()
# ------------------------------------- # -------------------------------------
def get_root(root: str = None) -> str: def get_root(root: str = None) -> str:
if root: if root:
@ -264,6 +277,7 @@ def get_root(root: str = None) -> str:
else: else:
return str(config.root_path) return str(config.root_path)
# ------------------------------------- # -------------------------------------
class editOptsForm(CyclingForm, npyscreen.FormMultiPage): class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
# for responsive resizing - disabled # for responsive resizing - disabled
@ -272,7 +286,7 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
def create(self): def create(self):
program_opts = self.parentApp.program_opts program_opts = self.parentApp.program_opts
old_opts = self.parentApp.invokeai_opts old_opts = self.parentApp.invokeai_opts
first_time = not (config.root_path / 'invokeai.yaml').exists() first_time = not (config.root_path / "invokeai.yaml").exists()
access_token = HfFolder.get_token() access_token = HfFolder.get_token()
window_width, window_height = get_terminal_size() window_width, window_height = get_terminal_size()
label = """Configure startup settings. You can come back and change these later. label = """Configure startup settings. You can come back and change these later.
@ -287,47 +301,6 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
color="CONTROL", color="CONTROL",
) )
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== BASIC OPTIONS ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Select an output directory for images:",
editable=False,
color="CONTROL",
)
self.outdir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):",
value=str(default_output_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=40,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Activate the NSFW checker to blur images showing potential sexual imagery:",
editable=False,
color="CONTROL",
)
self.nsfw_checker = self.add_widget_intelligent(
npyscreen.Checkbox,
name="NSFW checker",
value=old_opts.nsfw_checker,
relx=5,
scroll_exit=True,
)
self.nextrely += 1 self.nextrely += 1
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens.""" label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
for line in textwrap.wrap(label, width=window_width - 6): for line in textwrap.wrap(label, width=window_width - 6):
@ -347,15 +320,6 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
scroll_exit=True, scroll_exit=True,
) )
self.nextrely += 1 self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== ADVANCED OPTIONS ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.add_widget_intelligent( self.add_widget_intelligent(
npyscreen.TitleFixedText, npyscreen.TitleFixedText,
name="GPU Management", name="GPU Management",
@ -369,34 +333,47 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
npyscreen.Checkbox, npyscreen.Checkbox,
name="Free GPU memory after each generation", name="Free GPU memory after each generation",
value=old_opts.free_gpu_mem, value=old_opts.free_gpu_mem,
max_width=45,
relx=5, relx=5,
scroll_exit=True, scroll_exit=True,
) )
self.nextrely -= 1
self.xformers_enabled = self.add_widget_intelligent( self.xformers_enabled = self.add_widget_intelligent(
npyscreen.Checkbox, npyscreen.Checkbox,
name="Enable xformers support if available", name="Enable xformers support",
value=old_opts.xformers_enabled, value=old_opts.xformers_enabled,
relx=5, max_width=30,
relx=50,
scroll_exit=True, scroll_exit=True,
) )
self.nextrely -= 1
self.always_use_cpu = self.add_widget_intelligent( self.always_use_cpu = self.add_widget_intelligent(
npyscreen.Checkbox, npyscreen.Checkbox,
name="Force CPU to be used on GPU systems", name="Force CPU to be used on GPU systems",
value=old_opts.always_use_cpu, value=old_opts.always_use_cpu,
relx=5, relx=80,
scroll_exit=True, scroll_exit=True,
) )
precision = old_opts.precision or ( precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
"float32" if program_opts.full_precision else "auto" self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Floating Point Precision",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
) )
self.nextrely -= 1
self.precision = self.add_widget_intelligent( self.precision = self.add_widget_intelligent(
npyscreen.TitleSelectOne, SingleSelectColumns,
columns = 2, columns=3,
name="Precision", name="Precision",
values=PRECISION_CHOICES, values=PRECISION_CHOICES,
value=PRECISION_CHOICES.index(precision), value=PRECISION_CHOICES.index(precision),
begin_entry_at=3, begin_entry_at=3,
max_height=len(PRECISION_CHOICES) + 1, max_height=2,
max_width=80,
scroll_exit=True, scroll_exit=True,
) )
self.max_cache_size = self.add_widget_intelligent( self.max_cache_size = self.add_widget_intelligent(
@ -409,16 +386,22 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
scroll_exit=True, scroll_exit=True,
) )
self.nextrely += 1 self.nextrely += 1
self.add_widget_intelligent( self.outdir = self.add_widget_intelligent(
npyscreen.FixedText, FileBox,
value="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models (<tab> autocompletes, ctrl-N advances):", name="Output directory for images (<tab> autocompletes, ctrl-N advances):",
editable=False, value=str(default_output_dir()),
color="CONTROL", select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=40,
max_height=3,
scroll_exit=True,
) )
self.autoimport_dirs = {} self.autoimport_dirs = {}
self.autoimport_dirs['autoimport_dir'] = self.add_widget_intelligent( self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
FileBox, FileBox,
name=f'Autoimport Folder', name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
value=str(config.root_path / config.autoimport_dir), value=str(config.root_path / config.autoimport_dir),
select_dir=True, select_dir=True,
must_exist=False, must_exist=False,
@ -426,21 +409,13 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
labelColor="GOOD", labelColor="GOOD",
begin_entry_at=32, begin_entry_at=32,
max_height=3, max_height=3,
scroll_exit=True
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== LICENSE ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True, scroll_exit=True,
) )
self.nextrely -= 1 self.nextrely += 1
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
https://huggingface.co/spaces/CompVis/stable-diffusion-license https://huggingface.co/spaces/CompVis/stable-diffusion-license and
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
""" """
for i in textwrap.wrap(label, width=window_width - 6): for i in textwrap.wrap(label, width=window_width - 6):
self.add_widget_intelligent( self.add_widget_intelligent(
@ -451,22 +426,17 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
) )
self.license_acceptance = self.add_widget_intelligent( self.license_acceptance = self.add_widget_intelligent(
npyscreen.Checkbox, npyscreen.Checkbox,
name="I accept the CreativeML Responsible AI License", name="I accept the CreativeML Responsible AI Licenses",
value=not first_time, value=not first_time,
relx=2, relx=2,
scroll_exit=True, scroll_exit=True,
) )
self.nextrely += 1 self.nextrely += 1
label = ( label = "DONE" if program_opts.skip_sd_weights or program_opts.default_only else "NEXT"
"DONE"
if program_opts.skip_sd_weights or program_opts.default_only
else "NEXT"
)
self.ok_button = self.add_widget_intelligent( self.ok_button = self.add_widget_intelligent(
CenteredButtonPress, CenteredButtonPress,
name=label, name=label,
relx=(window_width - len(label)) // 2, relx=(window_width - len(label)) // 2,
rely=-3,
when_pressed_function=self.on_ok, when_pressed_function=self.on_ok,
) )
@ -485,9 +455,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
def validate_field_values(self, opt: Namespace) -> bool: def validate_field_values(self, opt: Namespace) -> bool:
bad_fields = [] bad_fields = []
if not opt.license_acceptance: if not opt.license_acceptance:
bad_fields.append( bad_fields.append("Please accept the license terms before proceeding to model downloads")
"Please accept the license terms before proceeding to model downloads"
)
if not Path(opt.outdir).parent.exists(): if not Path(opt.outdir).parent.exists():
bad_fields.append( bad_fields.append(
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory." f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
@ -506,7 +474,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
for attr in [ for attr in [
"outdir", "outdir",
"nsfw_checker",
"free_gpu_mem", "free_gpu_mem",
"max_cache_size", "max_cache_size",
"xformers_enabled", "xformers_enabled",
@ -542,7 +509,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
"MAIN", "MAIN",
editOptsForm, editOptsForm,
name="InvokeAI Startup Options", name="InvokeAI Startup Options",
cycle_widgets=True, cycle_widgets=False,
) )
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only): if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
self.model_select = self.addForm( self.model_select = self.addForm(
@ -550,7 +517,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
addModelsForm, addModelsForm,
name="Install Stable Diffusion Models", name="Install Stable Diffusion Models",
multipage=True, multipage=True,
cycle_widgets=True, cycle_widgets=False,
) )
def new_opts(self): def new_opts(self):
@ -562,18 +529,17 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
editApp.run() editApp.run()
return editApp.new_opts() return editApp.new_opts()
def default_startup_options(init_file: Path) -> Namespace: def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig.get_config() opts = InvokeAIAppConfig.get_config()
if not init_file.exists():
opts.nsfw_checker = True
return opts return opts
def default_user_selections(program_opts: Namespace) -> InstallSelections:
def default_user_selections(program_opts: Namespace) -> InstallSelections:
try: try:
installer = ModelInstall(config) installer = ModelInstall(config)
except omegaconf.errors.ConfigKeyError: except omegaconf.errors.ConfigKeyError:
logger.warning('Your models.yaml file is corrupt or out of date. Reinitializing') logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
initialize_rootdir(config.root_path, True) initialize_rootdir(config.root_path, True)
installer = ModelInstall(config) installer = ModelInstall(config)
@ -586,44 +552,46 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
else list(), else list(),
) )
# ------------------------------------- # -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False): def initialize_rootdir(root: Path, yes_to_all: bool = False):
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **") logger.info("Initializing InvokeAI runtime directory")
for name in ( for name in ("models", "databases", "text-inversion-output", "text-inversion-training-data", "configs"):
"models",
"databases",
"text-inversion-output",
"text-inversion-training-data",
"configs"
):
os.makedirs(os.path.join(root, name), exist_ok=True) os.makedirs(os.path.join(root, name), exist_ok=True)
for model_type in ModelType: for model_type in ModelType:
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True) Path(root, "autoimport", model_type.value).mkdir(parents=True, exist_ok=True)
configs_src = Path(configs.__path__[0]) configs_src = Path(configs.__path__[0])
configs_dest = root / "configs" configs_dest = root / "configs"
if not os.path.samefile(configs_src, configs_dest): if not os.path.samefile(configs_src, configs_dest):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True) shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
dest = root / 'models' dest = root / "models"
for model_base in BaseModelType: for model_base in BaseModelType:
for model_type in ModelType: for model_type in ModelType:
path = dest / model_base.value / model_type.value path = dest / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
path = dest / 'core' path = dest / "core"
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
with open(root / 'configs' / 'models.yaml','w') as yaml_file: maybe_create_models_yaml(root)
yaml_file.write(yaml.dump({'__metadata__':
{'version':'3.0.0'}
} def maybe_create_models_yaml(root: Path):
) models_yaml = root / "configs" / "models.yaml"
) if models_yaml.exists():
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
return
else:
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
with open(models_yaml, "w") as yaml_file:
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
# ------------------------------------- # -------------------------------------
def run_console_ui( def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
program_opts: Namespace, initfile: Path = None
) -> (Namespace, Namespace):
# parse_args() will read from init file if present # parse_args() will read from init file if present
invokeai_opts = default_startup_options(initfile) invokeai_opts = default_startup_options(initfile)
invokeai_opts.root = program_opts.root invokeai_opts.root = program_opts.root
@ -635,6 +603,7 @@ def run_console_ui(
# the install-models application spawns a subprocess to install # the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running. # models, and will crash unless this is set before running.
import torch import torch
torch.multiprocessing.set_start_method("spawn") torch.multiprocessing.set_start_method("spawn")
editApp = EditOptApplication(program_opts, invokeai_opts) editApp = EditOptApplication(program_opts, invokeai_opts)
@ -658,28 +627,31 @@ def write_opts(opts: Namespace, init_file: Path):
if hasattr(new_config, key): if hasattr(new_config, key):
setattr(new_config, key, value) setattr(new_config, key, value)
with open(init_file,'w', encoding='utf-8') as file: with open(init_file, "w", encoding="utf-8") as file:
file.write(new_config.to_yaml()) file.write(new_config.to_yaml())
if hasattr(opts,'hf_token') and opts.hf_token: if hasattr(opts, "hf_token") and opts.hf_token:
HfLogin(opts.hf_token) HfLogin(opts.hf_token)
# ------------------------------------- # -------------------------------------
def default_output_dir() -> Path: def default_output_dir() -> Path:
return config.root_path / "outputs" return config.root_path / "outputs"
# ------------------------------------- # -------------------------------------
def write_default_options(program_opts: Namespace, initfile: Path): def write_default_options(program_opts: Namespace, initfile: Path):
opt = default_startup_options(initfile) opt = default_startup_options(initfile)
write_opts(opt, initfile) write_opts(opt, initfile)
# ------------------------------------- # -------------------------------------
# Here we bring in # Here we bring in
# the legacy Args object in order to parse # the legacy Args object in order to parse
# the old init file and write out the new # the old init file and write out the new
# yaml format. # yaml format.
def migrate_init_file(legacy_format: Path): def migrate_init_file(legacy_format: Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}']) old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
new = InvokeAIAppConfig.get_config() new = InvokeAIAppConfig.get_config()
fields = list(get_type_hints(InvokeAIAppConfig).keys()) fields = list(get_type_hints(InvokeAIAppConfig).keys())
@ -689,41 +661,43 @@ def migrate_init_file(legacy_format:Path):
# a few places where the field names have changed and we have to # a few places where the field names have changed and we have to
# manually add in the new names/values # manually add in the new names/values
new.nsfw_checker = old.safety_checker
new.xformers_enabled = old.xformers new.xformers_enabled = old.xformers
new.conf_path = old.conf new.conf_path = old.conf
new.root = legacy_format.parent.resolve() new.root = legacy_format.parent.resolve()
invokeai_yaml = legacy_format.parent / 'invokeai.yaml' invokeai_yaml = legacy_format.parent / "invokeai.yaml"
with open(invokeai_yaml, "w", encoding="utf-8") as outfile: with open(invokeai_yaml, "w", encoding="utf-8") as outfile:
outfile.write(new.to_yaml()) outfile.write(new.to_yaml())
legacy_format.replace(legacy_format.parent / 'invokeai.init.orig') legacy_format.replace(legacy_format.parent / "invokeai.init.orig")
# ------------------------------------- # -------------------------------------
def migrate_models(root: Path): def migrate_models(root: Path):
from invokeai.backend.install.migrate_to_3 import do_migrate from invokeai.backend.install.migrate_to_3 import do_migrate
do_migrate(root, root) do_migrate(root, root)
def migrate_if_needed(opt: Namespace, root: Path) -> bool: def migrate_if_needed(opt: Namespace, root: Path) -> bool:
# We check for to see if the runtime directory is correctly initialized. # We check for to see if the runtime directory is correctly initialized.
old_init_file = root / 'invokeai.init' old_init_file = root / "invokeai.init"
new_init_file = root / 'invokeai.yaml' new_init_file = root / "invokeai.yaml"
old_hub = root / 'models/hub' old_hub = root / "models/hub"
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists() migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
if migration_needed: if migration_needed:
if opt.yes_to_all or \ if opt.yes_to_all or yes_or_no(
yes_or_no(f'{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?'): f"{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?"
):
logger.info('** Migrating invokeai.init to invokeai.yaml') logger.info("** Migrating invokeai.init to invokeai.yaml")
migrate_init_file(old_init_file) migrate_init_file(old_init_file)
config.parse_args(argv=[], conf=OmegaConf.load(new_init_file)) config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))
if old_hub.exists(): if old_hub.exists():
migrate_models(config.root_path) migrate_models(config.root_path)
else: else:
print('Cannot continue without conversion. Aborting.') print("Cannot continue without conversion. Aborting.")
return migration_needed return migration_needed
@ -784,9 +758,9 @@ def main():
invoke_args = [] invoke_args = []
if opt.root: if opt.root:
invoke_args.extend(['--root',opt.root]) invoke_args.extend(["--root", opt.root])
if opt.full_precision: if opt.full_precision:
invoke_args.extend(['--precision','float32']) invoke_args.extend(["--precision", "float32"])
config.parse_args(invoke_args) config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config) logger = InvokeAILogger().getLogger(config=config)
@ -798,41 +772,36 @@ def main():
if migrate_if_needed(opt, config.root_path): if migrate_if_needed(opt, config.root_path):
sys.exit(0) sys.exit(0)
if not config.model_conf_path.exists(): # run this unconditionally in case new directories need to be added
initialize_rootdir(config.root_path, opt.yes_to_all) initialize_rootdir(config.root_path, opt.yes_to_all)
models_to_download = default_user_selections(opt) models_to_download = default_user_selections(opt)
new_init_file = config.root_path / 'invokeai.yaml' new_init_file = config.root_path / "invokeai.yaml"
if opt.yes_to_all: if opt.yes_to_all:
write_default_options(opt, new_init_file) write_default_options(opt, new_init_file)
init_options = Namespace( init_options = Namespace(precision="float32" if opt.full_precision else "float16")
precision="float32" if opt.full_precision else "float16"
)
else: else:
init_options, models_to_download = run_console_ui(opt, new_init_file) init_options, models_to_download = run_console_ui(opt, new_init_file)
if init_options: if init_options:
write_opts(init_options, new_init_file) write_opts(init_options, new_init_file)
else: else:
logger.info( logger.info('\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n')
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
)
sys.exit(0) sys.exit(0)
if opt.skip_support_models: if opt.skip_support_models:
logger.info("SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST") logger.info("Skipping support models at user's request")
else: else:
logger.info("CHECKING/UPDATING SUPPORT MODELS") logger.info("Installing support models")
download_support_models() download_support_models()
if opt.skip_sd_weights: if opt.skip_sd_weights:
logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST") logger.warning("Skipping diffusion weights download per user request")
elif models_to_download: elif models_to_download:
logger.info("DOWNLOADING DIFFUSION WEIGHTS")
process_and_execute(opt, models_to_download) process_and_execute(opt, models_to_download)
postscript(errors=errors) postscript(errors=errors)
if not opt.yes_to_all: if not opt.yes_to_all:
input('Press any key to continue...') input("Press any key to continue...")
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nGoodbye! Come back soon.") print("\nGoodbye! Come back soon.")

View File

@ -47,17 +47,18 @@ PRECISION_CHOICES = [
"float16", "float16",
] ]
class FileArgumentParser(ArgumentParser): class FileArgumentParser(ArgumentParser):
""" """
Supports reading defaults from an init file. Supports reading defaults from an init file.
""" """
def convert_arg_line_to_args(self, arg_line): def convert_arg_line_to_args(self, arg_line):
return shlex.split(arg_line, comments=True) return shlex.split(arg_line, comments=True)
legacy_parser = FileArgumentParser( legacy_parser = FileArgumentParser(
description= description="""
"""
Generate images using Stable Diffusion. Generate images using Stable Diffusion.
Use --web to launch the web interface. Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-"). Use --from_file to load prompts from a file path or standard input ("-").
@ -65,304 +66,279 @@ Generate images using Stable Diffusion.
Other command-line arguments are defaults that can usually be overridden Other command-line arguments are defaults that can usually be overridden
prompt the command prompt. prompt the command prompt.
""", """,
fromfile_prefix_chars='@', fromfile_prefix_chars="@",
) )
general_group = legacy_parser.add_argument_group('General') general_group = legacy_parser.add_argument_group("General")
model_group = legacy_parser.add_argument_group('Model selection') model_group = legacy_parser.add_argument_group("Model selection")
file_group = legacy_parser.add_argument_group('Input/output') file_group = legacy_parser.add_argument_group("Input/output")
web_server_group = legacy_parser.add_argument_group('Web server') web_server_group = legacy_parser.add_argument_group("Web server")
render_group = legacy_parser.add_argument_group('Rendering') render_group = legacy_parser.add_argument_group("Rendering")
postprocessing_group = legacy_parser.add_argument_group('Postprocessing') postprocessing_group = legacy_parser.add_argument_group("Postprocessing")
deprecated_group = legacy_parser.add_argument_group('Deprecated options') deprecated_group = legacy_parser.add_argument_group("Deprecated options")
deprecated_group.add_argument('--laion400m') deprecated_group.add_argument("--laion400m")
deprecated_group.add_argument('--weights') # deprecated deprecated_group.add_argument("--weights") # deprecated
general_group.add_argument( general_group.add_argument("--version", "-V", action="store_true", help="Print InvokeAI version number")
'--version','-V',
action='store_true',
help='Print InvokeAI version number'
)
model_group.add_argument( model_group.add_argument(
'--root_dir', "--root_dir",
default=None, default=None,
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.', help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
) )
model_group.add_argument( model_group.add_argument(
'--config', "--config",
'-c', "-c",
'-config', "-config",
dest='conf', dest="conf",
default='./configs/models.yaml', default="./configs/models.yaml",
help='Path to configuration file for alternate models.', help="Path to configuration file for alternate models.",
) )
model_group.add_argument( model_group.add_argument(
'--model', "--model",
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)', help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
) )
model_group.add_argument( model_group.add_argument(
'--weight_dirs', "--weight_dirs",
nargs='+', nargs="+",
type=str, type=str,
help='List of one or more directories that will be auto-scanned for new model weights to import', help="List of one or more directories that will be auto-scanned for new model weights to import",
) )
model_group.add_argument( model_group.add_argument(
'--png_compression','-z', "--png_compression",
"-z",
type=int, type=int,
default=6, default=6,
choices=range(0, 9), choices=range(0, 9),
dest='png_compression', dest="png_compression",
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.' help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
) )
model_group.add_argument( model_group.add_argument(
'-F', "-F",
'--full_precision', "--full_precision",
dest='full_precision', dest="full_precision",
action='store_true', action="store_true",
help='Deprecated way to set --precision=float32', help="Deprecated way to set --precision=float32",
) )
model_group.add_argument( model_group.add_argument(
'--max_loaded_models', "--max_loaded_models",
dest='max_loaded_models', dest="max_loaded_models",
type=int, type=int,
default=2, default=2,
help='Maximum number of models to keep in memory for fast switching, including the one in GPU', help="Maximum number of models to keep in memory for fast switching, including the one in GPU",
) )
model_group.add_argument( model_group.add_argument(
'--free_gpu_mem', "--free_gpu_mem",
dest='free_gpu_mem', dest="free_gpu_mem",
action='store_true', action="store_true",
help='Force free gpu memory before final decoding', help="Force free gpu memory before final decoding",
) )
model_group.add_argument( model_group.add_argument(
'--sequential_guidance', "--sequential_guidance",
dest='sequential_guidance', dest="sequential_guidance",
action='store_true', action="store_true",
help="Calculate guidance in serial instead of in parallel, lowering memory requirement " help="Calculate guidance in serial instead of in parallel, lowering memory requirement " "at the expense of speed",
"at the expense of speed",
) )
model_group.add_argument( model_group.add_argument(
'--xformers', "--xformers",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
default=True, default=True,
help='Enable/disable xformers support (default enabled if installed)', help="Enable/disable xformers support (default enabled if installed)",
) )
model_group.add_argument( model_group.add_argument(
"--always_use_cpu", "--always_use_cpu", dest="always_use_cpu", action="store_true", help="Force use of CPU even if GPU is available"
dest="always_use_cpu",
action="store_true",
help="Force use of CPU even if GPU is available"
) )
model_group.add_argument( model_group.add_argument(
'--precision', "--precision",
dest='precision', dest="precision",
type=str, type=str,
choices=PRECISION_CHOICES, choices=PRECISION_CHOICES,
metavar='PRECISION', metavar="PRECISION",
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}', help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto', default="auto",
) )
model_group.add_argument( model_group.add_argument(
'--ckpt_convert', "--ckpt_convert",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
dest='ckpt_convert', dest="ckpt_convert",
default=True, default=True,
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.' help="Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.",
) )
model_group.add_argument( model_group.add_argument(
'--internet', "--internet",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
dest='internet_available', dest="internet_available",
default=True, default=True,
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).', help="Indicate whether internet is available for just-in-time model downloading (default: probe automatically).",
) )
model_group.add_argument( model_group.add_argument(
'--nsfw_checker', "--nsfw_checker",
'--safety_checker', "--safety_checker",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
dest='safety_checker', dest="safety_checker",
default=False, default=False,
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.', help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
) )
model_group.add_argument( model_group.add_argument(
'--autoimport', "--autoimport",
default=None, default=None,
type=str, type=str,
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly', help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
) )
model_group.add_argument( model_group.add_argument(
'--autoconvert', "--autoconvert",
default=None, default=None,
type=str, type=str,
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models', help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models",
) )
model_group.add_argument( model_group.add_argument(
'--patchmatch', "--patchmatch",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
default=True, default=True,
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.', help="Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.",
) )
file_group.add_argument( file_group.add_argument(
'--from_file', "--from_file",
dest='infile', dest="infile",
type=str, type=str,
help='If specified, load prompts from this file', help="If specified, load prompts from this file",
) )
file_group.add_argument( file_group.add_argument(
'--outdir', "--outdir",
'-o', "-o",
type=str, type=str,
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs', help="Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs",
default='outputs', default="outputs",
) )
file_group.add_argument( file_group.add_argument(
'--prompt_as_dir', "--prompt_as_dir",
'-p', "-p",
action='store_true', action="store_true",
help='Place images in subdirectories named after the prompt.', help="Place images in subdirectories named after the prompt.",
) )
render_group.add_argument( render_group.add_argument(
'--fnformat', "--fnformat",
default='{prefix}.{seed}.png', default="{prefix}.{seed}.png",
type=str, type=str,
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png', help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png",
) )
render_group.add_argument("-s", "--steps", type=int, default=50, help="Number of steps")
render_group.add_argument( render_group.add_argument(
'-s', "-W",
'--steps', "--width",
type=int, type=int,
default=50, help="Image width, multiple of 64",
help='Number of steps'
) )
render_group.add_argument( render_group.add_argument(
'-W', "-H",
'--width', "--height",
type=int, type=int,
help='Image width, multiple of 64', help="Image height, multiple of 64",
) )
render_group.add_argument( render_group.add_argument(
'-H', "-C",
'--height', "--cfg_scale",
type=int,
help='Image height, multiple of 64',
)
render_group.add_argument(
'-C',
'--cfg_scale',
default=7.5, default=7.5,
type=float, type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.', help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
) )
render_group.add_argument( render_group.add_argument(
'--sampler', "--sampler",
'-A', "-A",
'-m', "-m",
dest='sampler_name', dest="sampler_name",
type=str, type=str,
choices=SAMPLER_CHOICES, choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME', metavar="SAMPLER_NAME",
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}', help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
default='k_lms', default="k_lms",
) )
render_group.add_argument( render_group.add_argument(
'--log_tokenization', "--log_tokenization", "-t", action="store_true", help="shows how the prompt is split into tokens"
'-t',
action='store_true',
help='shows how the prompt is split into tokens'
) )
render_group.add_argument( render_group.add_argument(
'-f', "-f",
'--strength', "--strength",
type=float, type=float,
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely', help="img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely",
) )
render_group.add_argument( render_group.add_argument(
'-T', "-T",
'-fit', "-fit",
'--fit', "--fit",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)', help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)",
) )
render_group.add_argument("--grid", "-g", action=argparse.BooleanOptionalAction, help="generate a grid")
render_group.add_argument( render_group.add_argument(
'--grid', "--embedding_directory",
'-g', "--embedding_path",
action=argparse.BooleanOptionalAction, dest="embedding_path",
help='generate a grid' default="embeddings",
)
render_group.add_argument(
'--embedding_directory',
'--embedding_path',
dest='embedding_path',
default='embeddings',
type=str, type=str,
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)' help="Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)",
) )
render_group.add_argument( render_group.add_argument(
'--lora_directory', "--lora_directory",
dest='lora_path', dest="lora_path",
default='loras', default="loras",
type=str, type=str,
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)' help="Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)",
) )
render_group.add_argument( render_group.add_argument(
'--embeddings', "--embeddings",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
default=True, default=True,
help='Enable embedding directory (default). Use --no-embeddings to disable.', help="Enable embedding directory (default). Use --no-embeddings to disable.",
) )
render_group.add_argument("--enable_image_debugging", action="store_true", help="Generates debugging image to display")
render_group.add_argument( render_group.add_argument(
'--enable_image_debugging', "--karras_max",
action='store_true',
help='Generates debugging image to display'
)
render_group.add_argument(
'--karras_max',
type=int, type=int,
default=None, default=None,
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]." help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].",
) )
# Restoration related args # Restoration related args
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--no_restore', "--no_restore",
dest='restore', dest="restore",
action='store_false', action="store_false",
help='Disable face restoration with GFPGAN or codeformer', help="Disable face restoration with GFPGAN or codeformer",
) )
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--no_upscale', "--no_upscale",
dest='esrgan', dest="esrgan",
action='store_false', action="store_false",
help='Disable upscaling with ESRGAN', help="Disable upscaling with ESRGAN",
) )
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--esrgan_bg_tile', "--esrgan_bg_tile",
type=int, type=int,
default=400, default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.', help="Tile size for background sampler, 0 for no tile during testing. Default: 400.",
) )
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--esrgan_denoise_str', "--esrgan_denoise_str",
type=float, type=float,
default=0.75, default=0.75,
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75', help="esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75",
) )
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--gfpgan_model_path', "--gfpgan_model_path",
type=str, type=str,
default='./models/gfpgan/GFPGANv1.4.pth', default="./models/gfpgan/GFPGANv1.4.pth",
help='Indicates the path to the GFPGAN model', help="Indicates the path to the GFPGAN model",
) )
web_server_group.add_argument( web_server_group.add_argument(
'--web', "--web",
dest='web', dest="web",
action='store_true', action="store_true",
help='Start in web server mode.', help="Start in web server mode.",
) )
web_server_group.add_argument( web_server_group.add_argument(
'--web_develop', "--web_develop",
dest='web_develop', dest="web_develop",
action='store_true', action="store_true",
help='Start in web server development mode.', help="Start in web server development mode.",
) )
web_server_group.add_argument( web_server_group.add_argument(
"--web_verbose", "--web_verbose",
@ -376,32 +352,27 @@ web_server_group.add_argument(
help="Additional allowed origins, comma-separated", help="Additional allowed origins, comma-separated",
) )
web_server_group.add_argument( web_server_group.add_argument(
'--host', "--host",
type=str, type=str,
default='127.0.0.1', default="127.0.0.1",
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.' help="Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.",
) )
web_server_group.add_argument("--port", type=int, default="9090", help="Web server: Port to listen on")
web_server_group.add_argument( web_server_group.add_argument(
'--port', "--certfile",
type=int,
default='9090',
help='Web server: Port to listen on'
)
web_server_group.add_argument(
'--certfile',
type=str, type=str,
default=None, default=None,
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile' help="Web server: Path to certificate file to use for SSL. Use together with --keyfile",
) )
web_server_group.add_argument( web_server_group.add_argument(
'--keyfile', "--keyfile",
type=str, type=str,
default=None, default=None,
help='Web server: Path to private key file to use for SSL. Use together with --certfile' help="Web server: Path to private key file to use for SSL. Use together with --certfile",
) )
web_server_group.add_argument( web_server_group.add_argument(
'--gui', "--gui",
dest='gui', dest="gui",
action='store_true', action="store_true",
help='Start InvokeAI GUI', help="Start InvokeAI GUI",
) )

View File

@ -1,7 +1,7 @@
''' """
Migrate the models directory and models.yaml file from an existing Migrate the models directory and models.yaml file from an existing
InvokeAI 2.3 installation to 3.0.0. InvokeAI 2.3 installation to 3.0.0.
''' """
import os import os
import argparse import argparse
@ -29,14 +29,13 @@ from transformers import (
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager from invokeai.backend.model_management import ModelManager
from invokeai.backend.model_management.model_probe import ( from invokeai.backend.model_management.model_probe import ModelProbe, ModelType, BaseModelType, ModelProbeInfo
ModelProbe, ModelType, BaseModelType, ModelProbeInfo
)
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
diffusers.logging.set_verbosity_error() diffusers.logging.set_verbosity_error()
# holder for paths that we will migrate # holder for paths that we will migrate
@dataclass @dataclass
class ModelPaths: class ModelPaths:
@ -45,8 +44,10 @@ class ModelPaths:
loras: Path loras: Path
controlnets: Path controlnets: Path
class MigrateTo3(object): class MigrateTo3(object):
def __init__(self, def __init__(
self,
from_root: Path, from_root: Path,
to_models: Path, to_models: Path,
model_manager: ModelManager, model_manager: ModelManager,
@ -59,62 +60,61 @@ class MigrateTo3(object):
@classmethod @classmethod
def initialize_yaml(cls, yaml_file: Path): def initialize_yaml(cls, yaml_file: Path):
with open(yaml_file, 'w') as file: with open(yaml_file, "w") as file:
file.write( file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
yaml.dump(
{
'__metadata__': {'version':'3.0.0'}
}
)
)
def create_directory_structure(self): def create_directory_structure(self):
''' """
Create the basic directory structure for the models folder. Create the basic directory structure for the models folder.
''' """
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora, for model_type in [
ModelType.ControlNet,ModelType.TextualInversion]: ModelType.Main,
ModelType.Vae,
ModelType.Lora,
ModelType.ControlNet,
ModelType.TextualInversion,
]:
path = self.dest_models / model_base.value / model_type.value path = self.dest_models / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
path = self.dest_models / 'core' path = self.dest_models / "core"
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
@staticmethod @staticmethod
def copy_file(src: Path, dest: Path): def copy_file(src: Path, dest: Path):
''' """
copy a single file with logging copy a single file with logging
''' """
if dest.exists(): if dest.exists():
logger.info(f'Skipping existing {str(dest)}') logger.info(f"Skipping existing {str(dest)}")
return return
logger.info(f'Copying {str(src)} to {str(dest)}') logger.info(f"Copying {str(src)} to {str(dest)}")
try: try:
shutil.copy(src, dest) shutil.copy(src, dest)
except Exception as e: except Exception as e:
logger.error(f'COPY FAILED: {str(e)}') logger.error(f"COPY FAILED: {str(e)}")
@staticmethod @staticmethod
def copy_dir(src: Path, dest: Path): def copy_dir(src: Path, dest: Path):
''' """
Recursively copy a directory with logging Recursively copy a directory with logging
''' """
if dest.exists(): if dest.exists():
logger.info(f'Skipping existing {str(dest)}') logger.info(f"Skipping existing {str(dest)}")
return return
logger.info(f'Copying {str(src)} to {str(dest)}') logger.info(f"Copying {str(src)} to {str(dest)}")
try: try:
shutil.copytree(src, dest) shutil.copytree(src, dest)
except Exception as e: except Exception as e:
logger.error(f'COPY FAILED: {str(e)}') logger.error(f"COPY FAILED: {str(e)}")
def migrate_models(self, src_dir: Path): def migrate_models(self, src_dir: Path):
''' """
Recursively walk through src directory, probe anything Recursively walk through src directory, probe anything
that looks like a model, and copy the model into the that looks like a model, and copy the model into the
appropriate location within the destination models directory. appropriate location within the destination models directory.
''' """
directories_scanned = set() directories_scanned = set()
for root, dirs, files in os.walk(src_dir): for root, dirs, files in os.walk(src_dir):
for d in dirs: for d in dirs:
@ -136,7 +136,7 @@ class MigrateTo3(object):
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation # let them be copied as part of a tree copy operation
try: try:
if f in {'learned_embeds.bin','pytorch_lora_weights.bin'}: if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
continue continue
model = Path(root, f) model = Path(root, f)
if model.parent in directories_scanned: if model.parent in directories_scanned:
@ -154,97 +154,95 @@ class MigrateTo3(object):
logger.error(str(e)) logger.error(str(e))
def migrate_support_models(self): def migrate_support_models(self):
''' """
Copy the clipseg, upscaler, and restoration models to their new Copy the clipseg, upscaler, and restoration models to their new
locations. locations.
''' """
dest_directory = self.dest_models dest_directory = self.dest_models
if (self.root_directory / 'models/clipseg').exists(): if (self.root_directory / "models/clipseg").exists():
self.copy_dir(self.root_directory / 'models/clipseg', dest_directory / 'core/misc/clipseg') self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
if (self.root_directory / 'models/realesrgan').exists(): if (self.root_directory / "models/realesrgan").exists():
self.copy_dir(self.root_directory / 'models/realesrgan', dest_directory / 'core/upscaling/realesrgan') self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
for d in ['codeformer','gfpgan']: for d in ["codeformer", "gfpgan"]:
path = self.root_directory / 'models' / d path = self.root_directory / "models" / d
if path.exists(): if path.exists():
self.copy_dir(path,dest_directory / f'core/face_restoration/{d}') self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
def migrate_tuning_models(self): def migrate_tuning_models(self):
''' """
Migrate the embeddings, loras and controlnets directories to their new homes. Migrate the embeddings, loras and controlnets directories to their new homes.
''' """
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]: for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
if not src: if not src:
continue continue
if src.is_dir(): if src.is_dir():
logger.info(f'Scanning {src}') logger.info(f"Scanning {src}")
self.migrate_models(src) self.migrate_models(src)
else: else:
logger.info(f'{src} directory not found; skipping') logger.info(f"{src} directory not found; skipping")
continue continue
def migrate_conversion_models(self): def migrate_conversion_models(self):
''' """
Migrate all the models that are needed by the ckpt_to_diffusers conversion Migrate all the models that are needed by the ckpt_to_diffusers conversion
script. script.
''' """
dest_directory = self.dest_models dest_directory = self.dest_models
kwargs = dict( kwargs = dict(
cache_dir = self.root_directory / 'models/hub', cache_dir=self.root_directory / "models/hub",
# local_files_only = True # local_files_only = True
) )
try: try:
logger.info('Migrating core tokenizers and text encoders') logger.info("Migrating core tokenizers and text encoders")
target_dir = dest_directory / 'core' / 'convert' target_dir = dest_directory / "core" / "convert"
self._migrate_pretrained(BertTokenizerFast, self._migrate_pretrained(
repo_id='bert-base-uncased', BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
dest = target_dir / 'bert-base-uncased', )
**kwargs)
# sd-1 # sd-1
repo_id = 'openai/clip-vit-large-patch14' repo_id = "openai/clip-vit-large-patch14"
self._migrate_pretrained(CLIPTokenizer, self._migrate_pretrained(
repo_id= repo_id, CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
dest= target_dir / 'clip-vit-large-patch14', )
**kwargs) self._migrate_pretrained(
self._migrate_pretrained(CLIPTextModel, CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
repo_id = repo_id, )
dest = target_dir / 'clip-vit-large-patch14',
force = True,
**kwargs)
# sd-2 # sd-2
repo_id = "stabilityai/stable-diffusion-2" repo_id = "stabilityai/stable-diffusion-2"
self._migrate_pretrained(CLIPTokenizer, self._migrate_pretrained(
CLIPTokenizer,
repo_id=repo_id, repo_id=repo_id,
dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer', dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
**{'subfolder':'tokenizer',**kwargs} **{"subfolder": "tokenizer", **kwargs},
) )
self._migrate_pretrained(CLIPTextModel, self._migrate_pretrained(
CLIPTextModel,
repo_id=repo_id, repo_id=repo_id,
dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder', dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
**{'subfolder':'text_encoder',**kwargs} **{"subfolder": "text_encoder", **kwargs},
) )
# VAE # VAE
logger.info('Migrating stable diffusion VAE') logger.info("Migrating stable diffusion VAE")
self._migrate_pretrained(AutoencoderKL, self._migrate_pretrained(
repo_id = 'stabilityai/sd-vae-ft-mse', AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
dest = target_dir / 'sd-vae-ft-mse', )
**kwargs)
# safety checking # safety checking
logger.info('Migrating safety checker') logger.info("Migrating safety checker")
repo_id = "CompVis/stable-diffusion-safety-checker" repo_id = "CompVis/stable-diffusion-safety-checker"
self._migrate_pretrained(AutoFeatureExtractor, self._migrate_pretrained(
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
)
self._migrate_pretrained(
StableDiffusionSafetyChecker,
repo_id=repo_id, repo_id=repo_id,
dest = target_dir / 'stable-diffusion-safety-checker', dest=target_dir / "stable-diffusion-safety-checker",
**kwargs) **kwargs,
self._migrate_pretrained(StableDiffusionSafetyChecker, )
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-safety-checker',
**kwargs)
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
except Exception as e: except Exception as e:
@ -255,7 +253,7 @@ class MigrateTo3(object):
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs): def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
if dest.exists() and not force: if dest.exists() and not force:
logger.info(f'Skipping existing {dest}') logger.info(f"Skipping existing {dest}")
return return
model = model_class.from_pretrained(repo_id, **kwargs) model = model_class.from_pretrained(repo_id, **kwargs)
self._save_pretrained(model, dest, overwrite=force) self._save_pretrained(model, dest, overwrite=force)
@ -265,22 +263,22 @@ class MigrateTo3(object):
if overwrite: if overwrite:
model.save_pretrained(dest, safe_serialization=True) model.save_pretrained(dest, safe_serialization=True)
else: else:
download_path = dest.with_name(f'{model_name}.downloading') download_path = dest.with_name(f"{model_name}.downloading")
model.save_pretrained(download_path, safe_serialization=True) model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest) download_path.replace(dest)
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path: def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder) vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
info = ModelProbe().heuristic_probe(vae) info = ModelProbe().heuristic_probe(vae)
_, model_name = repo_id.split('/') _, model_name = repo_id.split("/")
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info) dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
vae.save_pretrained(dest, safe_serialization=True) vae.save_pretrained(dest, safe_serialization=True)
return dest return dest
def _vae_path(self, vae: Union[str, dict]) -> Path: def _vae_path(self, vae: Union[str, dict]) -> Path:
''' """
Convert 2.3 VAE stanza to a straight path. Convert 2.3 VAE stanza to a straight path.
''' """
vae_path = None vae_path = None
# First get a path # First get a path
@ -288,14 +286,14 @@ class MigrateTo3(object):
vae_path = vae vae_path = vae
elif isinstance(vae, DictConfig): elif isinstance(vae, DictConfig):
if p := vae.get('path'): if p := vae.get("path"):
vae_path = p vae_path = p
elif repo_id := vae.get('repo_id'): elif repo_id := vae.get("repo_id"):
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
vae_path = 'models/core/convert/sd-vae-ft-mse' vae_path = "models/core/convert/sd-vae-ft-mse"
return vae_path return vae_path
else: else:
vae_path = self._download_vae(repo_id, vae.get('subfolder')) vae_path = self._download_vae(repo_id, vae.get("subfolder"))
assert vae_path is not None, "Couldn't find VAE for this model" assert vae_path is not None, "Couldn't find VAE for this model"
@ -314,56 +312,53 @@ class MigrateTo3(object):
if vae_path.is_relative_to(self.dest_models): if vae_path.is_relative_to(self.dest_models):
rel_path = vae_path.relative_to(self.dest_models) rel_path = vae_path.relative_to(self.dest_models)
return Path('models',rel_path) return Path("models", rel_path)
else: else:
return vae_path return vae_path
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config): def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
''' """
Migrate a locally-cached diffusers pipeline identified with a repo_id Migrate a locally-cached diffusers pipeline identified with a repo_id
''' """
dest_dir = self.dest_models dest_dir = self.dest_models
cache = self.root_directory / 'models/hub' cache = self.root_directory / "models/hub"
kwargs = dict( kwargs = dict(
cache_dir=cache, cache_dir=cache,
safety_checker=None, safety_checker=None,
# local_files_only = True, # local_files_only = True,
) )
owner,repo_name = repo_id.split('/') owner, repo_name = repo_id.split("/")
model_name = model_name or repo_name model_name = model_name or repo_name
model = cache / '--'.join(['models',owner,repo_name]) model = cache / "--".join(["models", owner, repo_name])
if len(list(model.glob('snapshots/**/model_index.json')))==0: if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
return return
revisions = [x.name for x in model.glob('refs/*')] revisions = [x.name for x in model.glob("refs/*")]
# if an fp16 is available we use that # if an fp16 is available we use that
revision = 'fp16' if len(revisions) > 1 and 'fp16' in revisions else revisions[0] revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
repo_id,
revision=revision,
**kwargs)
info = ModelProbe().heuristic_probe(pipeline) info = ModelProbe().heuristic_probe(pipeline)
if not info: if not info:
return return
if self.mgr.model_exists(model_name, info.base_type, info.model_type): if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.') logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
return return
dest = self._model_probe_to_path(info) / model_name dest = self._model_probe_to_path(info) / model_name
self._save_pretrained(pipeline, dest) self._save_pretrained(pipeline, dest)
rel_path = Path('models',dest.relative_to(dest_dir)) rel_path = Path("models", dest.relative_to(dest_dir))
self._add_model(model_name, info, rel_path, **extra_config) self._add_model(model_name, info, rel_path, **extra_config)
def migrate_path(self, location: Path, model_name: str = None, **extra_config): def migrate_path(self, location: Path, model_name: str = None, **extra_config):
''' """
Migrate a model referred to using 'weights' or 'path' Migrate a model referred to using 'weights' or 'path'
''' """
# handle relative paths # handle relative paths
dest_dir = self.dest_models dest_dir = self.dest_models
@ -375,7 +370,7 @@ class MigrateTo3(object):
return return
if self.mgr.model_exists(model_name, info.base_type, info.model_type): if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.') logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
return return
# uh oh, weights is in the old models directory - move it into the new one # uh oh, weights is in the old models directory - move it into the new one
@ -385,15 +380,11 @@ class MigrateTo3(object):
self.copy_dir(location, dest) self.copy_dir(location, dest)
else: else:
self.copy_file(location, dest) self.copy_file(location, dest)
location = Path('models', info.base_type.value, info.model_type.value, location.name) location = Path("models", info.base_type.value, info.model_type.value, location.name)
self._add_model(model_name, info, location, **extra_config) self._add_model(model_name, info, location, **extra_config)
def _add_model(self, def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
model_name: str,
info: ModelProbeInfo,
location: Path,
**extra_config):
if info.model_type != ModelType.Main: if info.model_type != ModelType.Main:
return return
@ -403,49 +394,48 @@ class MigrateTo3(object):
model_type=info.model_type, model_type=info.model_type,
clobber=True, clobber=True,
model_attributes={ model_attributes={
'path': str(location), "path": str(location),
'description': f'A {info.base_type.value} {info.model_type.value} model', "description": f"A {info.base_type.value} {info.model_type.value} model",
'model_format': info.format, "model_format": info.format,
'variant': info.variant_type.value, "variant": info.variant_type.value,
**extra_config, **extra_config,
} },
) )
def migrate_defined_models(self): def migrate_defined_models(self):
''' """
Migrate models defined in models.yaml Migrate models defined in models.yaml
''' """
# find any models referred to in old models.yaml # find any models referred to in old models.yaml
conf = OmegaConf.load(self.root_directory / 'configs/models.yaml') conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
for model_name, stanza in conf.items(): for model_name, stanza in conf.items():
try: try:
passthru_args = {} passthru_args = {}
if vae := stanza.get('vae'): if vae := stanza.get("vae"):
try: try:
passthru_args['vae'] = str(self._vae_path(vae)) passthru_args["vae"] = str(self._vae_path(vae))
except Exception as e: except Exception as e:
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"') logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
logger.warning(str(e)) logger.warning(str(e))
if config := stanza.get('config'): if config := stanza.get("config"):
passthru_args['config'] = config passthru_args["config"] = config
if description:= stanza.get('description'): if description := stanza.get("description"):
passthru_args['description'] = description passthru_args["description"] = description
if repo_id := stanza.get('repo_id'): if repo_id := stanza.get("repo_id"):
logger.info(f'Migrating diffusers model {model_name}') logger.info(f"Migrating diffusers model {model_name}")
self.migrate_repo_id(repo_id, model_name, **passthru_args) self.migrate_repo_id(repo_id, model_name, **passthru_args)
elif location := stanza.get('weights'): elif location := stanza.get("weights"):
logger.info(f'Migrating checkpoint model {model_name}') logger.info(f"Migrating checkpoint model {model_name}")
self.migrate_path(Path(location), model_name, **passthru_args) self.migrate_path(Path(location), model_name, **passthru_args)
elif location := stanza.get('path'): elif location := stanza.get("path"):
logger.info(f'Migrating diffusers model {model_name}') logger.info(f"Migrating diffusers model {model_name}")
self.migrate_path(Path(location), model_name, **passthru_args) self.migrate_path(Path(location), model_name, **passthru_args)
except KeyboardInterrupt: except KeyboardInterrupt:
@ -461,44 +451,46 @@ class MigrateTo3(object):
self.migrate_tuning_models() self.migrate_tuning_models()
self.migrate_defined_models() self.migrate_defined_models()
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths: def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
''' """
Returns tuple of (embedding_path, lora_path, controlnet_path) Returns tuple of (embedding_path, lora_path, controlnet_path)
''' """
parser = argparse.ArgumentParser(fromfile_prefix_chars='@') parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument( parser.add_argument(
'--embedding_directory', "--embedding_directory",
'--embedding_path', "--embedding_path",
type=Path, type=Path,
dest='embedding_path', dest="embedding_path",
default=Path('embeddings'), default=Path("embeddings"),
) )
parser.add_argument( parser.add_argument(
'--lora_directory', "--lora_directory",
dest='lora_path', dest="lora_path",
type=Path, type=Path,
default=Path('loras'), default=Path("loras"),
) )
opt,_ = parser.parse_known_args([f'@{str(initfile)}']) opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
return ModelPaths( return ModelPaths(
models = root / 'models', models=root / "models",
embeddings=root / str(opt.embedding_path).strip('"'), embeddings=root / str(opt.embedding_path).strip('"'),
loras=root / str(opt.lora_path).strip('"'), loras=root / str(opt.lora_path).strip('"'),
controlnets = root / 'controlnets', controlnets=root / "controlnets",
) )
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths: def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
''' """
Returns tuple of (embedding_path, lora_path, controlnet_path) Returns tuple of (embedding_path, lora_path, controlnet_path)
''' """
# Don't use the config object because it is unforgiving of version updates # Don't use the config object because it is unforgiving of version updates
# Just use omegaconf directly # Just use omegaconf directly
opt = OmegaConf.load(initfile) opt = OmegaConf.load(initfile)
paths = opt.InvokeAI.Paths paths = opt.InvokeAI.Paths
models = paths.get('models_dir','models') models = paths.get("models_dir", "models")
embeddings = paths.get('embedding_dir','embeddings') embeddings = paths.get("embedding_dir", "embeddings")
loras = paths.get('lora_dir','loras') loras = paths.get("lora_dir", "loras")
controlnets = paths.get('controlnet_dir','controlnets') controlnets = paths.get("controlnet_dir", "controlnets")
return ModelPaths( return ModelPaths(
models=root / models, models=root / models,
embeddings=root / embeddings, embeddings=root / embeddings,
@ -506,22 +498,24 @@ def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
controlnets=root / controlnets, controlnets=root / controlnets,
) )
def get_legacy_embeddings(root: Path) -> ModelPaths: def get_legacy_embeddings(root: Path) -> ModelPaths:
path = root / 'invokeai.init' path = root / "invokeai.init"
if path.exists(): if path.exists():
return _parse_legacy_initfile(root, path) return _parse_legacy_initfile(root, path)
path = root / 'invokeai.yaml' path = root / "invokeai.yaml"
if path.exists(): if path.exists():
return _parse_legacy_yamlfile(root, path) return _parse_legacy_yamlfile(root, path)
def do_migrate(src_directory: Path, dest_directory: Path): def do_migrate(src_directory: Path, dest_directory: Path):
""" """
Migrate models from src to dest InvokeAI root directories Migrate models from src to dest InvokeAI root directories
""" """
config_file = dest_directory / 'configs' / 'models.yaml.3' config_file = dest_directory / "configs" / "models.yaml.3"
dest_models = dest_directory / 'models.3' dest_models = dest_directory / "models.3"
version_3 = (dest_directory / 'models' / 'core').exists() version_3 = (dest_directory / "models" / "core").exists()
# Here we create the destination models.yaml file. # Here we create the destination models.yaml file.
# If we are writing into a version 3 directory and the # If we are writing into a version 3 directory and the
@ -530,80 +524,80 @@ def do_migrate(src_directory: Path, dest_directory: Path):
# create a new empty one. # create a new empty one.
if version_3: # write into the dest directory if version_3: # write into the dest directory
try: try:
shutil.copy(dest_directory / 'configs' / 'models.yaml', config_file) shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
except: except:
MigrateTo3.initialize_yaml(config_file) MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
(dest_directory / 'models').replace(dest_models) (dest_directory / "models").replace(dest_models)
else: else:
MigrateTo3.initialize_yaml(config_file) MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file) mgr = ModelManager(config_file)
paths = get_legacy_embeddings(src_directory) paths = get_legacy_embeddings(src_directory)
migrator = MigrateTo3( migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
from_root = src_directory,
to_models = dest_models,
model_manager = mgr,
src_paths = paths
)
migrator.migrate() migrator.migrate()
print("Migration successful.") print("Migration successful.")
if not version_3: if not version_3:
(dest_directory / 'models').replace(src_directory / 'models.orig') (dest_directory / "models").replace(src_directory / "models.orig")
print(f'Original models directory moved to {dest_directory}/models.orig') print(f"Original models directory moved to {dest_directory}/models.orig")
(dest_directory / 'configs' / 'models.yaml').replace(src_directory / 'configs' / 'models.yaml.orig') (dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
print(f'Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig') print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
config_file.replace(config_file.with_suffix(""))
dest_models.replace(dest_models.with_suffix(""))
config_file.replace(config_file.with_suffix(''))
dest_models.replace(dest_models.with_suffix(''))
def main(): def main():
parser = argparse.ArgumentParser(prog="invokeai-migrate3", parser = argparse.ArgumentParser(
prog="invokeai-migrate3",
description=""" description="""
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a '--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively. The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
script, which will perform a full upgrade in place.""" script, which will perform a full upgrade in place.""",
) )
parser.add_argument('--from-directory', parser.add_argument(
dest='src_root', "--from-directory",
dest="src_root",
type=Path, type=Path,
required=True, required=True,
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")' help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
) )
parser.add_argument('--to-directory', parser.add_argument(
dest='dest_root', "--to-directory",
dest="dest_root",
type=Path, type=Path,
required=True, required=True,
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")' help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
) )
args = parser.parse_args() args = parser.parse_args()
src_root = args.src_root src_root = args.src_root
assert src_root.is_dir(), f"{src_root} is not a valid directory" assert src_root.is_dir(), f"{src_root} is not a valid directory"
assert (src_root / 'models').is_dir(), f"{src_root} does not contain a 'models' subdirectory" assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
assert (src_root / 'models' / 'hub').exists(), f"{src_root} does not contain a version 2.3 models directory" assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
assert (src_root / 'invokeai.init').exists() or (src_root / 'invokeai.yaml').exists(), f"{src_root} does not contain an InvokeAI init file." assert (src_root / "invokeai.init").exists() or (
src_root / "invokeai.yaml"
).exists(), f"{src_root} does not contain an InvokeAI init file."
dest_root = args.dest_root dest_root = args.dest_root
assert dest_root.is_dir(), f"{dest_root} is not a valid directory" assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
config.parse_args(['--root',str(dest_root)]) config.parse_args(["--root", str(dest_root)])
# TODO: revisit - don't rely on invokeai.yaml to exist yet! # TODO: revisit - don't rely on invokeai.yaml to exist yet!
dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists() dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
if not dest_is_setup: if not dest_is_setup:
import invokeai.frontend.install.invokeai_configure import invokeai.frontend.install.invokeai_configure
from invokeai.backend.install.invokeai_configure import initialize_rootdir from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True) initialize_rootdir(dest_root, True)
do_migrate(src_root, dest_root) do_migrate(src_root, dest_root)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -7,7 +7,7 @@ import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import List, Dict, Callable, Union, Set from typing import List, Dict, Callable, Union, Set, Optional
import requests import requests
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
@ -28,7 +28,7 @@ warnings.filterwarnings("ignore")
# --------------------------globals----------------------- # --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.getLogger(name='InvokeAI') logger = InvokeAILogger.getLogger(name="InvokeAI")
# the initial "configs" dir is now bundled in the `invokeai.configs` package # the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
@ -45,51 +45,63 @@ Config_preamble = """
LEGACY_CONFIGS = { LEGACY_CONFIGS = {
BaseModelType.StableDiffusion1: { BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: 'v1-inference.yaml', ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml', ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
}, },
BaseModelType.StableDiffusion2: { BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: { ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: 'v2-inference.yaml', SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml', SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
}, },
ModelVariantType.Inpaint: { ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml', SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml', SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
} },
} },
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
} }
@dataclass @dataclass
class ModelInstallList: class ModelInstallList:
'''Class for listing models to be installed/removed''' """Class for listing models to be installed/removed"""
install_models: List[str] = field(default_factory=list) install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list) remove_models: List[str] = field(default_factory=list)
@dataclass @dataclass
class InstallSelections(): class InstallSelections:
install_models: List[str] = field(default_factory=list) install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list) remove_models: List[str] = field(default_factory=list)
@dataclass @dataclass
class ModelLoadInfo(): class ModelLoadInfo:
name: str name: str
model_type: ModelType model_type: ModelType
base_type: BaseModelType base_type: BaseModelType
path: Path = None path: Path = None
repo_id: str = None repo_id: str = None
description: str = '' description: str = ""
installed: bool = False installed: bool = False
recommended: bool = False recommended: bool = False
default: bool = False default: bool = False
class ModelInstall(object): class ModelInstall(object):
def __init__(self, def __init__(
self,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
model_manager: ModelManager = None, model_manager: ModelManager = None,
access_token:str = None): access_token: str = None,
):
self.config = config self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path) self.mgr = model_manager or ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path) self.datasets = OmegaConf.load(Dataset_path)
@ -98,30 +110,32 @@ class ModelInstall(object):
self.reverse_paths = self._reverse_paths(self.datasets) self.reverse_paths = self._reverse_paths(self.datasets)
def all_models(self) -> Dict[str, ModelLoadInfo]: def all_models(self) -> Dict[str, ModelLoadInfo]:
''' """
Return dict of model_key=>ModelLoadInfo objects. Return dict of model_key=>ModelLoadInfo objects.
This method consolidates and simplifies the entries in both This method consolidates and simplifies the entries in both
models.yaml and INITIAL_MODELS.yaml so that they can models.yaml and INITIAL_MODELS.yaml so that they can
be treated uniformly. It also sorts the models alphabetically be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat. by their name, to improve the display somewhat.
''' """
model_dict = dict() model_dict = dict()
# first populate with the entries in INITIAL_MODELS.yaml # first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items(): for key, value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key) name, base, model_type = ModelManager.parse_key(key)
value['name'] = name value["name"] = name
value['base_type'] = base value["base_type"] = base
value['model_type'] = model_type value["model_type"] = model_type
model_dict[key] = ModelLoadInfo(**value) model_dict[key] = ModelLoadInfo(**value)
# supplement with entries in models.yaml # supplement with entries in models.yaml
installed_models = self.mgr.list_models() installed_models = [x for x in self.mgr.list_models()]
# suppresses autoloaded models
# installed_models = [x for x in self.mgr.list_models() if not self._is_autoloaded(x)]
for md in installed_models: for md in installed_models:
base = md['base_model'] base = md["base_model"]
model_type = md['model_type'] model_type = md["model_type"]
name = md['model_name'] name = md["model_name"]
key = ModelManager.create_key(name, base, model_type) key = ModelManager.create_key(name, base, model_type)
if key in model_dict: if key in model_dict:
model_dict[key].installed = True model_dict[key].installed = True
@ -130,32 +144,44 @@ class ModelInstall(object):
name=name, name=name,
base_type=base, base_type=base,
model_type=model_type, model_type=model_type,
path = value.get('path'), path=value.get("path"),
installed=True, installed=True,
) )
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())} return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
def _is_autoloaded(self, model_info: dict) -> bool:
path = model_info.get("path")
if not path:
return False
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
if autodir_path := getattr(self.config, autodir):
autodir_path = self.config.root_path / autodir_path
if Path(path).is_relative_to(autodir_path):
return True
return False
def list_models(self, model_type): def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type) installed = self.mgr.list_models(model_type=model_type)
print(f'Installed models of type `{model_type}`:') print(f"Installed models of type `{model_type}`:")
for i in installed: for i in installed:
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}") print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
def starter_models(self)->Set[str]: # logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set() models = set()
for key, value in self.datasets.items(): for key, value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key) name, base, model_type = ModelManager.parse_key(key)
if model_type==ModelType.Main: if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key) models.add(key)
return models return models
def recommended_models(self) -> Set[str]: def recommended_models(self) -> Set[str]:
starters = self.starter_models() starters = self.starter_models(all_models=True)
return set([x for x in starters if self.datasets[x].get('recommended',False)]) return set([x for x in starters if self.datasets[x].get("recommended", False)])
def default_model(self) -> str: def default_model(self) -> str:
starters = self.starter_models() starters = self.starter_models()
defaults = [x for x in starters if self.datasets[x].get('default',False)] defaults = [x for x in starters if self.datasets[x].get("default", False)]
return defaults[0] return defaults[0]
def install(self, selections: InstallSelections): def install(self, selections: InstallSelections):
@ -168,7 +194,7 @@ class ModelInstall(object):
# remove requested models # remove requested models
for key in selections.remove_models: for key in selections.remove_models:
name, base, mtype = self.mgr.parse_key(key) name, base, mtype = self.mgr.parse_key(key)
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]') logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
try: try:
self.mgr.del_model(name, base, mtype) self.mgr.del_model(name, base, mtype)
except FileNotFoundError as e: except FileNotFoundError as e:
@ -177,7 +203,7 @@ class ModelInstall(object):
# add requested models # add requested models
for path in selections.install_models: for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]') logger.info(f"Installing {path} [{job}/{jobs}]")
try: try:
self.heuristic_import(path) self.heuristic_import(path)
except (ValueError, KeyError) as e: except (ValueError, KeyError) as e:
@ -187,15 +213,16 @@ class ModelInstall(object):
dlogging.set_verbosity(verbosity) dlogging.set_verbosity(verbosity)
self.mgr.commit() self.mgr.commit()
def heuristic_import(self, def heuristic_import(
self,
model_path_id_or_url: Union[str, Path], model_path_id_or_url: Union[str, Path],
models_installed: Set[Path] = None, models_installed: Set[Path] = None,
) -> Dict[str, AddModelResult]: ) -> Dict[str, AddModelResult]:
''' """
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation :param models_installed: Set of installed models, used for recursive invocation
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
''' """
if not models_installed: if not models_installed:
models_installed = dict() models_installed = dict()
@ -208,8 +235,10 @@ class ModelInstall(object):
models_installed.update({str(path): self._install_path(path)}) models_installed.update({str(path): self._install_path(path)})
# folders style or similar # folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in \ elif path.is_dir() and any(
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'} [
(path / x).exists()
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
] ]
): ):
models_installed.update({str(model_path_id_or_url): self._install_path(path)}) models_installed.update({str(model_path_id_or_url): self._install_path(path)})
@ -220,7 +249,7 @@ class ModelInstall(object):
self.heuristic_import(child, models_installed=models_installed) self.heuristic_import(child, models_installed=models_installed)
# huggingface repo # huggingface repo
elif len(str(model_path_id_or_url).split('/')) == 2: elif len(str(model_path_id_or_url).split("/")) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL # a URL
@ -228,7 +257,7 @@ class ModelInstall(object):
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
else: else:
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
return models_installed return models_installed
@ -237,13 +266,14 @@ class ModelInstall(object):
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult: def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper) info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
if not info: if not info:
logger.warning(f'Unable to parse format of {path}') logger.warning(f"Unable to parse format of {path}")
return None return None
model_name = path.stem if path.is_file() else path.name model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type): if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.') raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path, info) attributes = self._make_attributes(path, info)
return self.mgr.add_model(model_name = model_name, return self.mgr.add_model(
model_name=model_name,
base_model=info.base_type, base_model=info.base_type,
model_type=info.model_type, model_type=info.model_type,
model_attributes=attributes, model_attributes=attributes,
@ -253,9 +283,10 @@ class ModelInstall(object):
with TemporaryDirectory(dir=self.config.models_path) as staging: with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url, Path(staging)) location = download_with_resume(url, Path(staging))
if not location: if not location:
logger.error(f'Unable to download {url}. Skipping.') logger.error(f"Unable to download {url}. Skipping.")
info = ModelProbe().heuristic_probe(location) info = ModelProbe().heuristic_probe(location)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
dest.parent.mkdir(parents=True, exist_ok=True)
models_path = shutil.move(location, dest) models_path = shutil.move(location, dest)
# staged version will be garbage-collected at this time # staged version will be garbage-collected at this time
@ -271,42 +302,49 @@ class ModelInstall(object):
with TemporaryDirectory(dir=self.config.models_path) as staging: with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging) staging = Path(staging)
if 'model_index.json' in files: if "model_index.json" in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline location = self._download_hf_pipeline(repo_id, staging) # pipeline
else: else:
for suffix in ['safetensors','bin']: for suffix in ["safetensors", "bin"]:
if f'pytorch_lora_weights.{suffix}' in files: if f"pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA location = self._download_hf_model(repo_id, ["pytorch_lora_weights.bin"], staging) # LoRA
break break
elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone elif (
files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}'] self.config.precision == "float16" and f"diffusion_pytorch_model.fp16.{suffix}" in files
): # vae, controlnet or some other standalone
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
location = self._download_hf_model(repo_id, files, staging) location = self._download_hf_model(repo_id, files, staging)
break break
elif f'diffusion_pytorch_model.{suffix}' in files: elif f"diffusion_pytorch_model.{suffix}" in files:
files = ['config.json', f'diffusion_pytorch_model.{suffix}'] files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging) location = self._download_hf_model(repo_id, files, staging)
break break
elif f'learned_embeds.{suffix}' in files: elif f"learned_embeds.{suffix}" in files:
location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging) location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
break break
if not location: if not location:
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.') logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
return {} return {}
info = ModelProbe().heuristic_probe(location, self.prediction_helper) info = ModelProbe().heuristic_probe(location, self.prediction_helper)
if not info: if not info:
logger.warning(f'Could not probe {location}. Skipping install.') logger.warning(f"Could not probe {location}. Skipping install.")
return {} return {}
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location) dest = (
self.config.models_path
/ info.base_type.value
/ info.model_type.value
/ self._get_model_name(repo_id, location)
)
if dest.exists(): if dest.exists():
shutil.rmtree(dest) shutil.rmtree(dest)
shutil.copytree(location, dest) shutil.copytree(location, dest)
return self._install_path(dest, info) return self._install_path(dest, info)
def _get_model_name(self, path_name: str, location: Path) -> str: def _get_model_name(self, path_name: str, location: Path) -> str:
''' """
Calculate a name for the model - primitive implementation. Calculate a name for the model - primitive implementation.
''' """
if key := self.reverse_paths.get(path_name): if key := self.reverse_paths.get(path_name):
(name, base, mtype) = ModelManager.parse_key(key) (name, base, mtype) = ModelManager.parse_key(key)
return name return name
@ -317,53 +355,65 @@ class ModelInstall(object):
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict: def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
model_name = path.name if path.is_dir() else path.stem model_name = path.name if path.is_dir() else path.stem
description = f'{info.base_type.value} {info.model_type.value} model {model_name}' description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
if key := self.reverse_paths.get(self.current_id): if key := self.reverse_paths.get(self.current_id):
if key in self.datasets: if key in self.datasets:
description = self.datasets[key].get('description') or description description = self.datasets[key].get("description") or description
rel_path = self.relative_to_root(path) rel_path = self.relative_to_root(path, self.config.models_path)
attributes = dict( attributes = dict(
path=str(rel_path), path=str(rel_path),
description=str(description), description=str(description),
model_format=info.format, model_format=info.format,
) )
legacy_conf = None
if info.model_type == ModelType.Main: if info.model_type == ModelType.Main:
attributes.update(dict(variant = info.variant_type,)) attributes.update(
dict(
variant=info.variant_type,
)
)
if info.format == "checkpoint": if info.format == "checkpoint":
try: try:
possible_conf = path.with_suffix('.yaml') possible_conf = path.with_suffix(".yaml")
if possible_conf.exists(): if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf)) legacy_conf = str(self.relative_to_root(possible_conf))
elif info.base_type == BaseModelType.StableDiffusion2: elif info.base_type == BaseModelType.StableDiffusion2:
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type]) legacy_conf = Path(
self.config.legacy_conf_dir,
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
)
else: else:
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]) legacy_conf = Path(
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
)
except KeyError: except KeyError:
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
attributes.update( if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
dict( possible_conf = path.with_suffix(".yaml")
config = str(legacy_conf) if possible_conf.exists():
) legacy_conf = str(self.relative_to_root(possible_conf))
)
if legacy_conf:
attributes.update(dict(config=str(legacy_conf)))
return attributes return attributes
def relative_to_root(self, path: Path)->Path: def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
root = self.config.root_path root = root or self.config.root_path
if path.is_relative_to(root): if path.is_relative_to(root):
return path.relative_to(root) return path.relative_to(root)
else: else:
return path return path
def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path: def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path:
''' """
This retrieves a StableDiffusion model from cache or remote and then This retrieves a StableDiffusion model from cache or remote and then
does a save_pretrained() to the indicated staging area. does a save_pretrained() to the indicated staging area.
''' """
_, name = repo_id.split("/") _, name = repo_id.split("/")
revisions = ['fp16','main'] if self.config.precision=='float16' else ['main'] revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
model = None model = None
for revision in revisions: for revision in revisions:
try: try:
@ -373,7 +423,7 @@ class ModelInstall(object):
if model: if model:
break break
if not model: if not model:
logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.') logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
return None return None
model.save_pretrained(staging / name, safe_serialization=True) model.save_pretrained(staging / name, safe_serialization=True)
return staging / name return staging / name
@ -383,24 +433,23 @@ class ModelInstall(object):
location = staging / name location = staging / name
paths = list() paths = list()
for filename in files: for filename in files:
p = hf_download_with_resume(repo_id, p = hf_download_with_resume(
model_dir=location, repo_id, model_dir=location, model_name=filename, access_token=self.access_token
model_name=filename,
access_token = self.access_token
) )
if p: if p:
paths.append(p) paths.append(p)
else: else:
logger.warning(f'Could not download {filename} from {repo_id}.') logger.warning(f"Could not download {filename} from {repo_id}.")
return location if len(paths) > 0 else None return location if len(paths) > 0 else None
@classmethod @classmethod
def _reverse_paths(cls, datasets) -> dict: def _reverse_paths(cls, datasets) -> dict:
''' """
Reverse mapping from repo_id/path to destination name. Reverse mapping from repo_id/path to destination name.
''' """
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()} return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
# ------------------------------------- # -------------------------------------
def yes_or_no(prompt: str, default_yes=True): def yes_or_no(prompt: str, default_yes=True):
@ -411,12 +460,11 @@ def yes_or_no(prompt: str, default_yes=True):
else: else:
return response[0] in ("y", "Y") return response[0] in ("y", "Y")
# --------------------------------------------- # ---------------------------------------------
def hf_download_from_pretrained( def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
model_class: object, model_name: str, destination: Path, **kwargs logger = InvokeAILogger.getLogger("InvokeAI")
): logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
logger = InvokeAILogger.getLogger('InvokeAI')
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
model = model_class.from_pretrained( model = model_class.from_pretrained(
model_name, model_name,
@ -426,6 +474,7 @@ def hf_download_from_pretrained(
model.save_pretrained(destination, safe_serialization=True) model.save_pretrained(destination, safe_serialization=True)
return destination return destination
# --------------------------------------------- # ---------------------------------------------
def hf_download_with_resume( def hf_download_with_resume(
repo_id: str, repo_id: str,
@ -451,9 +500,7 @@ def hf_download_with_resume(
resp = requests.get(url, headers=header, stream=True) resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get("content-length", 0)) total = int(resp.headers.get("content-length", 0))
if ( if resp.status_code == 416: # "range not satisfiable", which means nothing to return
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
logger.info(f"{model_name}: complete file found. Skipping.") logger.info(f"{model_name}: complete file found. Skipping.")
return model_dest return model_dest
elif resp.status_code == 404: elif resp.status_code == 404:
@ -482,5 +529,3 @@ def hf_download_with_resume(
logger.error(f"An error occurred while downloading {model_name}: {str(e)}") logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
return None return None
return model_dest return model_dest

View File

@ -3,6 +3,12 @@ Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException, DuplicateModelException from .models import (
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
ModelNotFoundException,
DuplicateModelException,
)
from .model_merge import ModelMerger, MergeInterpolationMethod from .model_merge import ModelMerger, MergeInterpolationMethod

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,7 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
class LoRALayerBase: class LoRALayerBase:
# rank: Optional[int] # rank: Optional[int]
# alpha: Optional[float] # alpha: Optional[float]
@ -31,11 +32,7 @@ class LoRALayerBase:
else: else:
self.alpha = None self.alpha = None
if ( if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
"bias_indices" in values
and "bias_values" in values
and "bias_size" in values
):
self.bias = torch.sparse_coo_tensor( self.bias = torch.sparse_coo_tensor(
values["bias_indices"], values["bias_indices"],
values["bias_values"], values["bias_values"],
@ -71,12 +68,16 @@ class LoRALayerBase:
bias = self.bias if self.bias is not None else 0 bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return op( return (
op(
*input_h, *input_h,
(weight + bias).view(module.weight.shape), (weight + bias).view(module.weight.shape),
None, None,
**extra_args, **extra_args,
) * multiplier * scale )
* multiplier
* scale
)
def get_weight(self): def get_weight(self):
raise NotImplementedError() raise NotImplementedError()
@ -187,12 +188,8 @@ class LoHALayer(LoRALayerBase):
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else: else:
rebuild1 = torch.einsum( rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
"i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
)
rebuild2 = torch.einsum(
"i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
)
weight = rebuild1 * rebuild2 weight = rebuild1 * rebuild2
return weight return weight
@ -278,7 +275,7 @@ class LoKRLayer(LoRALayerBase):
if self.t2 is None: if self.t2 is None:
w2 = self.w2_a @ self.w2_b w2 = self.w2_a @ self.w2_b
else: else:
w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b) w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4: if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2) w1 = w1.unsqueeze(2).unsqueeze(2)
@ -392,7 +389,6 @@ class LoRAModel: #(torch.nn.Module):
state_dict = cls._group_state(state_dict) state_dict = cls._group_state(state_dict)
for layer_key, values in state_dict.items(): for layer_key, values in state_dict.items():
# lora and locon # lora and locon
if "lora_down.weight" in values: if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values) layer = LoRALayer(layer_key, values)
@ -407,9 +403,7 @@ class LoRAModel: #(torch.nn.Module):
else: else:
# TODO: diff/ia3/... format # TODO: diff/ia3/... format
print( print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
)
return return
# lower memory consumption by removing already parsed layer values # lower memory consumption by removing already parsed layer values
@ -443,9 +437,10 @@ with LoRAHelper.apply_lora_unet(unet, loras):
# unmodified unet # unmodified unet
""" """
# TODO: rename smth like ModelPatcher and add TI method? # TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher: class ModelPatcher:
@staticmethod @staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key assert "." not in lora_key
@ -455,7 +450,7 @@ class ModelPatcher:
module = model module = model
module_key = "" module_key = ""
key_parts = lora_key[len(prefix):].split('_') key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0) submodule_name = key_parts.pop(0)
@ -477,7 +472,6 @@ class ModelPatcher:
applied_loras: List[Tuple[LoRAModel, float]], applied_loras: List[Tuple[LoRAModel, float]],
layer_name: str, layer_name: str,
): ):
def lora_forward(module, input_h, output): def lora_forward(module, input_h, output):
if len(applied_loras) == 0: if len(applied_loras) == 0:
return output return output
@ -491,7 +485,6 @@ class ModelPatcher:
return lora_forward return lora_forward
@classmethod @classmethod
@contextmanager @contextmanager
def apply_lora_unet( def apply_lora_unet(
@ -502,7 +495,6 @@ class ModelPatcher:
with cls.apply_lora(unet, loras, "lora_unet_"): with cls.apply_lora(unet, loras, "lora_unet_"):
yield yield
@classmethod @classmethod
@contextmanager @contextmanager
def apply_lora_text_encoder( def apply_lora_text_encoder(
@ -513,7 +505,6 @@ class ModelPatcher:
with cls.apply_lora(text_encoder, loras, "lora_te_"): with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield yield
@classmethod @classmethod
@contextmanager @contextmanager
def apply_lora( def apply_lora(
@ -554,7 +545,6 @@ class ModelPatcher:
for module_key, weight in original_weights.items(): for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight) model.get_submodule(module_key).weight.copy_(weight)
@classmethod @classmethod
@contextmanager @contextmanager
def apply_ti( def apply_ti(
@ -602,7 +592,9 @@ class ModelPatcher:
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}." f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
) )
model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype) model_embeddings.weight.data[token_id] = embedding.to(
device=text_encoder.device, dtype=text_encoder.dtype
)
ti_tokens.append(token_id) ti_tokens.append(token_id)
if len(ti_tokens) > 1: if len(ti_tokens) > 1:
@ -614,7 +606,6 @@ class ModelPatcher:
if init_tokens_count and new_tokens_added: if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count) text_encoder.resize_token_embeddings(init_tokens_count)
@classmethod @classmethod
@contextmanager @contextmanager
def apply_clip_skip( def apply_clip_skip(
@ -633,6 +624,7 @@ class ModelPatcher:
while len(skipped_layers) > 0: while len(skipped_layers) > 0:
text_encoder.text_model.encoder.layers.append(skipped_layers.pop()) text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
class TextualInversionModel: class TextualInversionModel:
name: str name: str
embedding: torch.Tensor # [n, 768]|[n, 1280] embedding: torch.Tensor # [n, 768]|[n, 1280]
@ -659,7 +651,9 @@ class TextualInversionModel:
# difference mostly in metadata # difference mostly in metadata
if "string_to_param" in state_dict: if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1: if len(state_dict["string_to_param"]) > 1:
print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.") print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
)
result.embedding = next(iter(state_dict["string_to_param"].values())) result.embedding = next(iter(state_dict["string_to_param"].values()))
@ -688,10 +682,7 @@ class TextualInversionManager(BaseTextualInversionManager):
self.pad_tokens = dict() self.pad_tokens = dict()
self.tokenizer = tokenizer self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary( def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
self, token_ids: list[int]
) -> list[int]:
if len(self.pad_tokens) == 0: if len(self.pad_tokens) == 0:
return token_ids return token_ids
@ -707,4 +698,3 @@ class TextualInversionManager(BaseTextualInversionManager):
new_token_ids.extend(self.pad_tokens[token_id]) new_token_ids.extend(self.pad_tokens[token_id])
return new_token_ids return new_token_ids

Some files were not shown because too many files have changed in this diff Show More