From 0aa1106c96681bb12d37499faa625187d684fb16 Mon Sep 17 00:00:00 2001 From: mauwii Date: Mon, 13 Feb 2023 22:54:14 +0100 Subject: [PATCH 01/57] update .editorconfig --- .editorconfig | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.editorconfig b/.editorconfig index d4b0972eda..e03231eaca 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,3 +1,5 @@ +root = true + # All files [*] charset = utf-8 From 35518542f828152f365741b966e0a41d3e73ff2e Mon Sep 17 00:00:00 2001 From: mauwii Date: Mon, 13 Feb 2023 22:56:49 +0100 Subject: [PATCH 02/57] add .vscode files --- .gitignore | 1 - .vscode/extensions.json | 14 ++++++++++ .vscode/launch.json | 16 ++++++++++++ .vscode/settings.json | 57 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 .vscode/extensions.json create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 9adb0be85a..8b9b495261 100644 --- a/.gitignore +++ b/.gitignore @@ -201,7 +201,6 @@ checkpoints # Scratch folder .scratch/ -.vscode/ gfpgan/ models/ldm/stable-diffusion-v1/*.sha256 diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000000..45964d6921 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,14 @@ +{ + "recommendations": [ + "editorconfig.editorconfig", + "github.vscode-pull-request-github", + "ms-python.black-formatter", + "ms-python.flake8", + "ms-python.isort", + "ms-python.python", + "ms-python.vscode-pylance", + "redhat.vscode-yaml", + "tamasfe.even-better-toml", + "eamodio.gitlens" + ] +} diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..cde7f9b4c3 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "invokeai --web", + "type": "python", + "request": "launch", + "program": ".venv/bin/invokeai", + "args": ["--web"], + "justMyCode": true + } + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000000..67ea3be92a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,57 @@ +{ + "[json]": { + "editor.defaultFormatter": "vscode.json-language-features", + "editor.quickSuggestions": { + "strings": true + }, + "editor.suggest.insertMode": "replace", + "files.insertFinalNewline": true, + "gitlens.codeLens.scopes": ["document"] + }, + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true, + "editor.formatOnSaveMode": "file" + }, + "[toml]": { + "editor.defaultFormatter": "tamasfe.even-better-toml", + "editor.formatOnSave": true, + "editor.formatOnSaveMode": "modificationsIfAvailable" + }, + "[yaml]": { + "editor.defaultFormatter": "redhat.vscode-yaml", + "editor.formatOnSave": true, + "editor.formatOnSaveMode": "modificationsIfAvailable" + }, + "editor.rulers": [88], + "editor.defaultFormatter": "esbenp.prettier-vscode", + "evenBetterToml.formatter.alignEntries": false, + "evenBetterToml.formatter.allowedBlankLines": 1, + "evenBetterToml.formatter.arrayAutoExpand": true, + "evenBetterToml.formatter.arrayTrailingComma": true, + "evenBetterToml.formatter.arrayAutoCollapse": true, + "evenBetterToml.formatter.columnWidth": 88, + "evenBetterToml.formatter.compactArrays": true, + "evenBetterToml.formatter.compactInlineTables": true, + "evenBetterToml.formatter.indentEntries": false, + "evenBetterToml.formatter.inlineTableExpand": true, + "evenBetterToml.formatter.reorderArrays": true, + "evenBetterToml.formatter.reorderKeys": true, + "evenBetterToml.formatter.compactEntries": true, + "evenBetterToml.schema.enabled": true, + "isort.check": true, + "isort.args": ["--profile=black", "--filter-files", "--color"], + "python.analysis.typeCheckingMode": "basic", + "python.formatting.provider": "black", + "python.languageServer": "Pylance", + "python.linting.enabled": true, + "python.linting.flake8Enabled": true, + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "python.testing.pytestArgs": [ + "tests", + "--cov=ldm", + "--cov-branch", + "--cov-report=term:skip-covered" + ] +} From 39715017f9127ebc0a4ced39dc0cbbaff29d7801 Mon Sep 17 00:00:00 2001 From: mauwii Date: Mon, 13 Feb 2023 22:57:00 +0100 Subject: [PATCH 03/57] update pyproject.toml --- pyproject.toml | 175 ++++++++++++++++++++++++++++++------------------- 1 file changed, 106 insertions(+), 69 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6357d25653..8e65f89efb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,44 +1,37 @@ [build-system] -requires = ["setuptools~=65.5", "pip~=22.3", "wheel"] -build-backend = "setuptools.build_meta" +build-backend="setuptools.build_meta" +requires=["setuptools ~= 67.1", "wheel"] [project] -name = "InvokeAI" -description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process" -requires-python = ">=3.9, <3.11" -readme = { content-type = "text/markdown", file = "README.md" } -keywords = ["stable-diffusion", "AI"] -dynamic = ["version"] -license = { file = "LICENSE" } -authors = [{ name = "The InvokeAI Project", email = "lincoln.stein@gmail.com" }] -classifiers = [ - 
'Development Status :: 4 - Beta', - 'Environment :: GPU', - 'Environment :: GPU :: NVIDIA CUDA', - 'Environment :: MacOS X', - 'Intended Audience :: End Users/Desktop', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Operating System :: POSIX :: Linux', - 'Operating System :: MacOS', - 'Operating System :: Microsoft :: Windows', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Topic :: Artistic Software', - 'Topic :: Internet :: WWW/HTTP :: WSGI :: Application', - 'Topic :: Internet :: WWW/HTTP :: WSGI :: Server', - 'Topic :: Multimedia :: Graphics', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Scientific/Engineering :: Image Processing', +authors=[{name="The InvokeAI Project", email="lincoln.stein@gmail.com"}] +classifiers=[ + "Development Status :: 4 - Beta", + "Environment :: GPU :: NVIDIA CUDA", + "Environment :: GPU", + "Environment :: MacOS X", + "Intended Audience :: Developers", + "Intended Audience :: End Users/Desktop", + "License :: OSI Approved :: MIT License", + "Operating System :: MacOS", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python", + "Topic :: Artistic Software", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Application", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Server", + "Topic :: Multimedia :: Graphics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Processing", ] -dependencies = [ +dependencies=[ "accelerate", "albumentations", "click", - "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", - "compel==0.1.7", + "clip_anytorch", "datasets", "diffusers[torch]~=0.13", "dnspython==2.2.1", @@ -54,7 +47,7 @@ dependencies = [ "huggingface-hub>=0.11.1", "imageio", "imageio-ffmpeg", - "k-diffusion", # replacing "k-diffusion @ https://github.com/Birch-san/k-diffusion/archive/refs/heads/mps.zip", + "k-diffusion", "kornia", "npyscreen", "numpy<1.24", @@ -62,8 +55,8 @@ dependencies = [ "opencv-python", "picklescan", "pillow", - "pudb", "prompt-toolkit", + "pudb", "pypatchmatch", "pyreadline3", "pytorch-lightning==1.7.7", @@ -75,62 +68,106 @@ dependencies = [ "streamlit", "taming-transformers-rom1504", "test-tube>=0.7.5", - "torch>=1.13.1", "torch-fidelity", - "torchvision>=0.14.1", + "torch>=1.13.1", "torchmetrics", + "torchvision>=0.14.1", "transformers~=4.25", "windows-curses; sys_platform=='win32'", ] +description="An implementation of Stable Diffusion which provides various new features and options to aid the image generation process" +dynamic=["version"] +keywords=["AI", "stable-diffusion"] +license={text="MIT"} +name="InvokeAI" +readme={content-type="text/markdown", file="README.md"} +requires-python=">=3.9, <3.11" [project.optional-dependencies] -"dist" = ["pip-tools", "pipdeptree", "twine"] -"docs" = [ - "mkdocs-material<9.0", +"dev"=["black", "flake8", "flake8-bugbear", "isort", "pre-commit"] +"dist"=["pip-tools", "pipdeptree", "twine"] +"docs"=[ "mkdocs-git-revision-date-localized-plugin", + "mkdocs-material<9.0", "mkdocs-redirects==1.2.0", ] -"test" = ["pytest>6.0.0", "pytest-cov"] -"xformers" = 
[ - "xformers~=0.0.16; sys_platform!='darwin'", - "triton; sys_platform=='linux'", -] +"test"=["pytest-cov", "pytest>6.0.0"] +"xformers"=["triton; sys_platform=='linux'", "xformers~=0.0.16; sys_platform!='darwin'"] [project.scripts] # legacy entrypoints; provided for backwards compatibility -"invoke.py" = "ldm.invoke.CLI:main" -"configure_invokeai.py" = "ldm.invoke.config.invokeai_configure:main" -"textual_inversion.py" = "ldm.invoke.training.textual_inversion:main" -"merge_embeddings.py" = "ldm.invoke.merge_diffusers:main" +"configure_invokeai.py"="ldm.invoke.config.invokeai_configure:main" +"invoke.py"="ldm.invoke.CLI:main" +"merge_embeddings.py"="ldm.invoke.merge_diffusers:main" +"textual_inversion.py"="ldm.invoke.training.textual_inversion:main" # modern entrypoints -"invokeai" = "ldm.invoke.CLI:main" -"invokeai-configure" = "ldm.invoke.config.invokeai_configure:main" -"invokeai-merge" = "ldm.invoke.merge_diffusers:main" # note name munging -"invokeai-ti" = "ldm.invoke.training.textual_inversion:main" -"invokeai-model-install" = "ldm.invoke.config.model_install:main" -"invokeai-update" = "ldm.invoke.config.invokeai_update:main" +"invokeai"="ldm.invoke.CLI:main" +"invokeai-configure"="ldm.invoke.config.invokeai_configure:main" +"invokeai-merge"="ldm.invoke.merge_diffusers:main" +"invokeai-ti"="ldm.invoke.training.textual_inversion:main" [project.urls] -"Homepage" = "https://invoke-ai.github.io/InvokeAI/" -"Documentation" = "https://invoke-ai.github.io/InvokeAI/" -"Source" = "https://github.com/invoke-ai/InvokeAI/" -"Bug Reports" = "https://github.com/invoke-ai/InvokeAI/issues" -"Discord" = "https://discord.gg/ZmtBAhwWhy" +"Bug Reports"="https://github.com/invoke-ai/InvokeAI/issues" +"Discord"="https://discord.gg/ZmtBAhwWhy" +"Documentation"="https://invoke-ai.github.io/InvokeAI/" +"Homepage"="https://invoke-ai.github.io/InvokeAI/" +"Source"="https://github.com/invoke-ai/InvokeAI/" + +[tool.setuptools] +license-files=["LICENSE"] [tool.setuptools.dynamic] -version = { attr = "ldm.invoke.__version__" } +version={attr="ldm.invoke.__version__"} [tool.setuptools.packages.find] -"where" = ["."] -"include" = ["invokeai.assets.web*", "invokeai.backend*", "invokeai.frontend.dist*", "invokeai.configs*", "ldm*"] +"include"=[ + "invokeai.assets.web*", + "invokeai.backend*", + "invokeai.configs*", + "invokeai.frontend.dist*", + "ldm*", +] +"where"=["."] [tool.setuptools.package-data] -"invokeai.assets.web" = ["**.png"] -"invokeai.backend" = ["**.png"] -"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"] -"invokeai.frontend.dist" = ["**"] +"invokeai.assets.web"=["**.png"] +"invokeai.configs"=["**.example", "**.txt", "**.yaml"] +"invokeai.frontend.dist"=["**"] + +[tool.black] +exclude=''' +/( + .git + | .tox + | .venv + | _build + | build + | dist + | node_modules +)/ +''' +include='.pyi?$' +line-length=88 +source=['invokeai/backend', 'ldm/invoke'] +target-version=['py39'] + +[tool.isort] +extend_ignore=["scripts"] +profile="black" +py_version=39 + +[tool.coverage.run] +branch=true +parallel=true + +[tool.coverage.report] +skip_covered=true +skip_empty=true + +[tool.coverage.paths] +source=["invokeai/backend", "ldm/invoke"] [tool.pytest.ini_options] -addopts = "-p pytest_cov --junitxml=junit/test-results.xml --cov-report=term:skip-covered --cov=ldm/invoke --cov=backend --cov-branch" +addopts=["--cov=invokeai/backend", "--cov=ldm/invoke"] From 02e84c9565e264aa4daebdddcba922764a2248ec Mon Sep 17 00:00:00 2001 From: mauwii Date: Mon, 13 Feb 2023 22:57:11 +0100 Subject: [PATCH 04/57] add .flake8 --- 
.flake8 | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..35fb6f8ea2 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 88 +extend-select = C,E,F,W,B,B9 +# B905 should be enabled when we drop support for 3.9 +extend-ignore = E203, E501, B905 From 2a739890a3caf19da70c81a2df5c3ae24b55cd8e Mon Sep 17 00:00:00 2001 From: mauwii Date: Mon, 13 Feb 2023 22:57:23 +0100 Subject: [PATCH 05/57] add .pre-commit-config.yaml --- .pre-commit-config.yaml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..c1be8103b6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,28 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-ast + - id: check-yaml + args: [--allow-multiple-documents] + - id: check-added-large-files + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + - repo: https://github.com/pre-commit/mirrors-prettier + rev: 'v2.7.1' + hooks: + - id: prettier From 6c11e8ee068c481e10ccc36fa42e76e00fa2ccf6 Mon Sep 17 00:00:00 2001 From: mauwii Date: Thu, 16 Feb 2023 21:09:56 +0100 Subject: [PATCH 06/57] update mkdocs.yml - add feature `content.tabs.link` --- mkdocs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/mkdocs.yml b/mkdocs.yml index ebd9ec0acf..f8d5f4631e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -33,6 +33,7 @@ theme: icon: material/lightbulb-outline name: Switch to light mode features: + - content.tabs.link - navigation.instant - navigation.tabs - navigation.top From 87b466302667067ff908fd2282eb34c528dd9829 Mon Sep 17 00:00:00 2001 From: mauwii Date: Thu, 16 Feb 2023 21:10:46 +0100 Subject: [PATCH 07/57] add `/docs/.markdownlint.jsonc` - for now only disable `MD046` --- docs/.markdownlint.jsonc | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 docs/.markdownlint.jsonc diff --git a/docs/.markdownlint.jsonc b/docs/.markdownlint.jsonc new file mode 100644 index 0000000000..9a24fcfc80 --- /dev/null +++ b/docs/.markdownlint.jsonc @@ -0,0 +1,3 @@ +{ + "MD046": false +} From 519a9071a84b95181d6b06dc297383db8cb7f5ce Mon Sep 17 00:00:00 2001 From: mauwii Date: Thu, 16 Feb 2023 21:13:22 +0100 Subject: [PATCH 08/57] add "How to contribute" to docs - not yet finished --- .../010_HOW_TO_CONTRIBUTE.md | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md diff --git a/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md b/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md new file mode 100644 index 0000000000..a613c0b3b6 --- /dev/null +++ b/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md @@ -0,0 +1,94 @@ +--- +title: How to Contribute +--- + +## pre-requirements + +To follow the steps in this tutorial you will need the following: + +- [git](https://git-scm.com/downloads) +- [GitHub](https://github.com) account +- A Code Editor (personally I use Visual Studio Code) + +## Fork Repository + +The first 
step, if you want to contribute to InvokeAI, is to fork the
+repository.
+
+The easiest way to do so is by clicking
+[here](https://github.com/invoke-ai/InvokeAI/fork). It is also possible by
+opening [InvokeAI](https://github.com/invoke-ai/InvokeAI) and clicking on the
+"Fork" button in the top right.
+
+## Clone your fork
+
+After you have forked the repository, you should clone it to your dev machine:
+
+=== "Linux/MacOS"
+
+    ```sh
+    git clone https://github.com/<username>/InvokeAI \
+    && cd InvokeAI
+    ```
+
+=== "Windows"
+
+    ```powershell
+    git clone https://github.com/<username>/InvokeAI `
+    && cd InvokeAI
+    ```
+
+## Install in Editable Mode
+
+To install InvokeAI in editable mode, we (as always) recommend creating and
+activating a venv first. Afterwards you can install the InvokeAI package,
+including dev and docs extras in editable mode, followed by the installation of
+the pre-commit hook:
+
+=== "Linux/MacOS"
+
+    ```sh
+    python -m venv .venv \
+        --prompt InvokeAI \
+        --upgrade-deps \
+    && source .venv/bin/activate \
+    && pip install \
+        --upgrade \
+        --use-pep517 \
+        --editable=".[dev,docs]" \
+    && pre-commit install
+    ```
+
+=== "Windows"
+
+    ```powershell
+    python -m venv .venv `
+        --prompt InvokeAI `
+        --upgrade-deps `
+    && .venv/scripts/activate.ps1 `
+    && pip install `
+        --upgrade `
+        --use-pep517 `
+        --editable=".[dev,docs]" `
+    && pre-commit install
+    ```
+
+## Create a branch
+
+Make sure you are on the main branch, then create your feature branch:
+
+=== "Linux/MacOS"
+
+    ```sh
+    git checkout main \
+    && git pull \
+    && git checkout -B <branch name>
+    ```
+
+=== "Windows"
+
+    ```powershell
+    git checkout main `
+    && git pull `
+    && git checkout -B <branch name>
+    ```

From c3f533f20f4a30c842f85f2b5789d7575bd468b9 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Fri, 17 Feb 2023 02:52:21 +0100
Subject: [PATCH 09/57] update .pre-commit-config.yaml

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c1be8103b6..bd38a6128b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ repos:
       - id: end-of-file-fixer
       - id: check-ast
       - id: check-yaml
-        args: [--allow-multiple-documents]
+        args: [--unsafe]
       - id: check-added-large-files
   - repo: https://github.com/psf/black

From c134161a45f5c86584eae5dce76d73376ea1fcc3 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Fri, 17 Feb 2023 02:52:28 +0100
Subject: [PATCH 10/57] update .editorconfig

---
 .editorconfig | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/.editorconfig b/.editorconfig
index e03231eaca..87deaa4477 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -2,6 +2,7 @@ root = true

 # All files
 [*]
+max_line_length = 80
 charset = utf-8
 end_of_line = lf
 indent_size = 2
@@ -11,4 +12,13 @@ trim_trailing_whitespace = true

 # Python
 [*.py]
+max_line_length = 88
+profile = black
+
+# css
+[*.css]
+indent_size = 4
+
+# flake8
+[.flake8]
 indent_size = 4

From 32314999925396159162c8bbca442d1f6631fea2 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Fri, 17 Feb 2023 02:55:41 +0100
Subject: [PATCH 11/57] update .vscode settings and extensions

---
 .vscode/extensions.json | 10 +++++++++-
 .vscode/settings.json   | 14 +++++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/.vscode/extensions.json b/.vscode/extensions.json
index 45964d6921..9c5bb5d049 100644
--- a/.vscode/extensions.json
+++ b/.vscode/extensions.json
@@ -9,6 +9,14 @@
     "ms-python.vscode-pylance",
     "redhat.vscode-yaml",
     "tamasfe.even-better-toml",
-    "eamodio.gitlens"
+    
"eamodio.gitlens", + "foxundermoon.shell-format", + "timonwong.shellcheck", + "esbenp.prettier-vscode", + "davidanson.vscode-markdownlint", + "yzhang.markdown-all-in-one", + "bierner.github-markdown-preview", + "ms-azuretools.vscode-docker", + "mads-hartmann.bash-ide-vscode" ] } diff --git a/.vscode/settings.json b/.vscode/settings.json index 67ea3be92a..6423669115 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,6 @@ { "[json]": { - "editor.defaultFormatter": "vscode.json-language-features", + "editor.defaultFormatter": "esbenp.prettier-vscode", "editor.quickSuggestions": { "strings": true }, @@ -23,6 +23,18 @@ "editor.formatOnSave": true, "editor.formatOnSaveMode": "modificationsIfAvailable" }, + "[markdown]": { + "editor.rulers": [80], + "editor.unicodeHighlight.ambiguousCharacters": false, + "editor.unicodeHighlight.invisibleCharacters": false, + "diffEditor.ignoreTrimWhitespace": false, + "editor.wordWrap": "on", + "editor.quickSuggestions": { + "comments": "off", + "strings": "off", + "other": "off" + } + }, "editor.rulers": [88], "editor.defaultFormatter": "esbenp.prettier-vscode", "evenBetterToml.formatter.alignEntries": false, From 4e0fe4ad6e4005b338fd70ad97dda6b6d150b324 Mon Sep 17 00:00:00 2001 From: mauwii Date: Fri, 17 Feb 2023 06:37:17 +0100 Subject: [PATCH 12/57] update black / flake8 related settings - add flake8-black to dev extras - update `.flake8` - update flake8 pre-commit hook --- .flake8 | 50 +++++++++++++++++++++++++++++++++++++---- .pre-commit-config.yaml | 3 +++ pyproject.toml | 28 ++++++++++++++--------- 3 files changed, 67 insertions(+), 14 deletions(-) diff --git a/.flake8 b/.flake8 index 35fb6f8ea2..5137321b60 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,47 @@ [flake8] -max-line-length = 88 -extend-select = C,E,F,W,B,B9 -# B905 should be enabled when we drop support for 3.9 -extend-ignore = E203, E501, B905 +# line length is intentionally set to 80 here because black uses Bugbear +max-line-length = 80 +extend-ignore = + # See https://github.com/PyCQA/pycodestyle/issues/373 + E203, + # use Bugbear's B950 instead + E501, + # from black repo https://github.com/psf/black/blob/main/.flake8 + E266, W503, B907 +extend-select = + # Bugbear line length + B950 +exclude = + .venv, + .git, + .tox, + dist, + doc, + *lib/python*, + *egg, + build + scripts/orig_scripts/* + ldm/models/* + ldm/modules/* + ldm/data/* + ldm/generate.py + ldm/util.py + ldm/simplet2i.py +per-file-ignores = + # B950 line too long + # W605 invalid escape sequence + # F841 assigned to but never used + # F401 imported but unused + tests/test_prompt_parser.py: B950, W605, F401 + tests/test_textual_inversion.py: F841 + # B023 Function definition does not bind loop variable + scripts/legacy_api.py: F401, B950, B023, F841 + ldm/invoke/__init__.py: F401 + # B010 Do not call setattr with a constant attribute value + ldm/invoke/server_legacy.py: B010 + +# ===================== +# flake-quote settings: +# ===================== +# Set this to match black style: +inline-quotes = double diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bd38a6128b..3dcfb7a2c0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,6 +22,9 @@ repos: rev: 6.0.0 hooks: - id: flake8 + additional_dependencies: + - flake8-black + - flake8-bugbear - repo: https://github.com/pre-commit/mirrors-prettier rev: 'v2.7.1' hooks: diff --git a/pyproject.toml b/pyproject.toml index 8e65f89efb..d8101b29e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,14 @@ 
readme={content-type="text/markdown", file="README.md"} requires-python=">=3.9, <3.11" [project.optional-dependencies] -"dev"=["black", "flake8", "flake8-bugbear", "isort", "pre-commit"] +"dev"=[ + "black[jupyter]", + "flake8", + "flake8-black", + "flake8-bugbear", + "isort", + "pre-commit", +] "dist"=["pip-tools", "pipdeptree", "twine"] "docs"=[ "mkdocs-git-revision-date-localized-plugin", @@ -139,24 +146,25 @@ version={attr="ldm.invoke.__version__"} [tool.black] exclude=''' /( - .git - | .tox - | .venv - | _build - | build - | dist - | node_modules + .git + | .tox + | .venv + | _build + | build + | dist + | node_modules )/ ''' include='.pyi?$' line-length=88 -source=['invokeai/backend', 'ldm/invoke'] +source=["installer", "invokeai/backend", "ldm/invoke"] target-version=['py39'] [tool.isort] -extend_ignore=["scripts"] profile="black" py_version=39 +skip_gitignore=true +skip_glob=["scripts/orig_scripts/*"] [tool.coverage.run] branch=true From b4fd02b91045e5bab2c2e2b6d5e959542843a1ad Mon Sep 17 00:00:00 2001 From: mauwii Date: Sat, 18 Feb 2023 03:01:05 +0100 Subject: [PATCH 13/57] add more hooks, reorder hooks, update .flake8 --- .flake8 | 6 ++---- .pre-commit-config.yaml | 30 +++++++++++++++++++++--------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/.flake8 b/.flake8 index 5137321b60..57ac5f4c01 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,5 @@ [flake8] -# line length is intentionally set to 80 here because black uses Bugbear -max-line-length = 80 +max-line-length = 88 extend-ignore = # See https://github.com/PyCQA/pycodestyle/issues/373 E203, @@ -33,13 +32,12 @@ per-file-ignores = # F841 assigned to but never used # F401 imported but unused tests/test_prompt_parser.py: B950, W605, F401 - tests/test_textual_inversion.py: F841 + tests/test_textual_inversion.py: F841, B950 # B023 Function definition does not bind loop variable scripts/legacy_api.py: F401, B950, B023, F841 ldm/invoke/__init__.py: F401 # B010 Do not call setattr with a constant attribute value ldm/invoke/server_legacy.py: B010 - # ===================== # flake-quote settings: # ===================== diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3dcfb7a2c0..bf4fb65b39 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,20 +4,36 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - - id: check-ast - - id: check-yaml - args: [--unsafe] - id: check-added-large-files + - id: check-ast + - id: check-executables-have-shebangs + - id: check-json + - id: check-merge-conflict + - id: check-symlinks + - id: check-toml + - id: check-yaml + args: ['--unsafe'] + - id: end-of-file-fixer + # files: \.(py|sh|rst|md|yml|yaml)$ + exclude: \.(json|jsonc|js|map)$ + - id: trailing-whitespace + exclude: \.(json|jsonc|js|map)$ + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: 'v3.0.0-alpha.4' + hooks: + - id: prettier + - repo: https://github.com/psf/black rev: 23.1.0 hooks: - id: black + - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort + - repo: https://github.com/PyCQA/flake8 rev: 6.0.0 hooks: @@ -25,7 +41,3 @@ repos: additional_dependencies: - flake8-black - flake8-bugbear - - repo: https://github.com/pre-commit/mirrors-prettier - rev: 'v2.7.1' - hooks: - - id: prettier From 0443befd2fdc0902fc4b738309138ade2037a79e Mon Sep 17 00:00:00 2001 From: mauwii Date: Sat, 18 Feb 2023 03:01:55 +0100 Subject: [PATCH 14/57] update pyproject.toml and vscode settings --- 
.vscode/settings.json | 2 +- pyproject.toml | 111 +++++++++++++++++++----------------------- 2 files changed, 50 insertions(+), 63 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 6423669115..f888881623 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -49,7 +49,7 @@ "evenBetterToml.formatter.inlineTableExpand": true, "evenBetterToml.formatter.reorderArrays": true, "evenBetterToml.formatter.reorderKeys": true, - "evenBetterToml.formatter.compactEntries": true, + "evenBetterToml.formatter.compactEntries": false, "evenBetterToml.schema.enabled": true, "isort.check": true, "isort.args": ["--profile=black", "--filter-files", "--color"], diff --git a/pyproject.toml b/pyproject.toml index d8101b29e4..d1cd1d1938 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ [build-system] -build-backend="setuptools.build_meta" -requires=["setuptools ~= 67.1", "wheel"] +build-backend = "setuptools.build_meta" +requires = ["setuptools ~= 67.1", "wheel"] [project] -authors=[{name="The InvokeAI Project", email="lincoln.stein@gmail.com"}] -classifiers=[ +authors = [{name = "The InvokeAI Project", email = "lincoln.stein@gmail.com"}] +classifiers = [ "Development Status :: 4 - Beta", "Environment :: GPU :: NVIDIA CUDA", "Environment :: GPU", @@ -27,7 +27,7 @@ classifiers=[ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Image Processing", ] -dependencies=[ +dependencies = [ "accelerate", "albumentations", "click", @@ -75,16 +75,16 @@ dependencies=[ "transformers~=4.25", "windows-curses; sys_platform=='win32'", ] -description="An implementation of Stable Diffusion which provides various new features and options to aid the image generation process" -dynamic=["version"] -keywords=["AI", "stable-diffusion"] -license={text="MIT"} -name="InvokeAI" -readme={content-type="text/markdown", file="README.md"} -requires-python=">=3.9, <3.11" +description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process" +dynamic = ["version"] +keywords = ["AI", "stable-diffusion"] +license = {text = "MIT"} +name = "InvokeAI" +readme = {content-type = "text/markdown", file = "README.md"} +requires-python = ">=3.9, <3.11" [project.optional-dependencies] -"dev"=[ +"dev" = [ "black[jupyter]", "flake8", "flake8-black", @@ -92,90 +92,77 @@ requires-python=">=3.9, <3.11" "isort", "pre-commit", ] -"dist"=["pip-tools", "pipdeptree", "twine"] -"docs"=[ +"dist" = ["pip-tools", "pipdeptree", "twine"] +"docs" = [ "mkdocs-git-revision-date-localized-plugin", "mkdocs-material<9.0", "mkdocs-redirects==1.2.0", ] -"test"=["pytest-cov", "pytest>6.0.0"] -"xformers"=["triton; sys_platform=='linux'", "xformers~=0.0.16; sys_platform!='darwin'"] +"test" = ["pytest-cov", "pytest>6.0.0"] +"xformers" = [ + "triton; sys_platform=='linux'", + "xformers~=0.0.16; sys_platform!='darwin'", +] [project.scripts] # legacy entrypoints; provided for backwards compatibility -"configure_invokeai.py"="ldm.invoke.config.invokeai_configure:main" -"invoke.py"="ldm.invoke.CLI:main" -"merge_embeddings.py"="ldm.invoke.merge_diffusers:main" -"textual_inversion.py"="ldm.invoke.training.textual_inversion:main" +"configure_invokeai.py" = "ldm.invoke.config.invokeai_configure:main" +"invoke.py" = "ldm.invoke.CLI:main" +"merge_embeddings.py" = "ldm.invoke.merge_diffusers:main" +"textual_inversion.py" = "ldm.invoke.training.textual_inversion:main" # modern entrypoints -"invokeai"="ldm.invoke.CLI:main" 
-"invokeai-configure"="ldm.invoke.config.invokeai_configure:main" -"invokeai-merge"="ldm.invoke.merge_diffusers:main" -"invokeai-ti"="ldm.invoke.training.textual_inversion:main" +"invokeai" = "ldm.invoke.CLI:main" +"invokeai-configure" = "ldm.invoke.config.invokeai_configure:main" +"invokeai-merge" = "ldm.invoke.merge_diffusers:main" +"invokeai-ti" = "ldm.invoke.training.textual_inversion:main" [project.urls] -"Bug Reports"="https://github.com/invoke-ai/InvokeAI/issues" -"Discord"="https://discord.gg/ZmtBAhwWhy" -"Documentation"="https://invoke-ai.github.io/InvokeAI/" -"Homepage"="https://invoke-ai.github.io/InvokeAI/" -"Source"="https://github.com/invoke-ai/InvokeAI/" +"Bug Reports" = "https://github.com/invoke-ai/InvokeAI/issues" +"Discord" = "https://discord.gg/ZmtBAhwWhy" +"Documentation" = "https://invoke-ai.github.io/InvokeAI/" +"Homepage" = "https://invoke-ai.github.io/InvokeAI/" +"Source" = "https://github.com/invoke-ai/InvokeAI/" [tool.setuptools] -license-files=["LICENSE"] +license-files = ["LICENSE"] [tool.setuptools.dynamic] -version={attr="ldm.invoke.__version__"} +version = {attr = "ldm.invoke.__version__"} [tool.setuptools.packages.find] -"include"=[ +"include" = [ "invokeai.assets.web*", "invokeai.backend*", "invokeai.configs*", "invokeai.frontend.dist*", "ldm*", ] -"where"=["."] +"where" = ["."] [tool.setuptools.package-data] -"invokeai.assets.web"=["**.png"] -"invokeai.configs"=["**.example", "**.txt", "**.yaml"] -"invokeai.frontend.dist"=["**"] +"invokeai.assets.web" = ["**.png"] +"invokeai.configs" = ["**.example", "**.txt", "**.yaml"] +"invokeai.frontend.dist" = ["**"] [tool.black] -exclude=''' -/( - .git - | .tox - | .venv - | _build - | build - | dist - | node_modules -)/ -''' -include='.pyi?$' -line-length=88 -source=["installer", "invokeai/backend", "ldm/invoke"] -target-version=['py39'] +line-length = 88 +target-version = ['py310'] [tool.isort] -profile="black" -py_version=39 -skip_gitignore=true -skip_glob=["scripts/orig_scripts/*"] +profile = "black" [tool.coverage.run] -branch=true -parallel=true +branch = true +parallel = true [tool.coverage.report] -skip_covered=true -skip_empty=true +skip_covered = true +skip_empty = true [tool.coverage.paths] -source=["invokeai/backend", "ldm/invoke"] +source = ["invokeai/backend", "ldm/invoke"] [tool.pytest.ini_options] -addopts=["--cov=invokeai/backend", "--cov=ldm/invoke"] +addopts = ["--cov=invokeai/backend", "--cov=ldm/invoke"] From ee3d695e2e97509badfe45bd517f57916039a063 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sat, 18 Feb 2023 03:02:41 +0100 Subject: [PATCH 15/57] remove command from json to be compliant --- .vscode/launch.json | 3 --- 1 file changed, 3 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index cde7f9b4c3..2b8f22f75c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,7 +1,4 @@ { - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. 
- // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ { From d77dc68119c281b3943b40a304ff4c7a162af60d Mon Sep 17 00:00:00 2001 From: mauwii Date: Sat, 18 Feb 2023 14:42:03 +0100 Subject: [PATCH 16/57] better config of pre-commit hooks: - better order of hooks - add flake8-comprehensions and flake8-simplify - remove unecesarry hooks which are covered by previous hooks - add hooks - check-executables-have-shebangs - check-shebang-scripts-are-executable --- .pre-commit-config.yaml | 42 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bf4fb65b39..636600ba4d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,29 +1,6 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: check-added-large-files - - id: check-ast - - id: check-executables-have-shebangs - - id: check-json - - id: check-merge-conflict - - id: check-symlinks - - id: check-toml - - id: check-yaml - args: ['--unsafe'] - - id: end-of-file-fixer - # files: \.(py|sh|rst|md|yml|yaml)$ - exclude: \.(json|jsonc|js|map)$ - - id: trailing-whitespace - exclude: \.(json|jsonc|js|map)$ - - - repo: https://github.com/pre-commit/mirrors-prettier - rev: 'v3.0.0-alpha.4' - hooks: - - id: prettier - - repo: https://github.com/psf/black rev: 23.1.0 hooks: @@ -41,3 +18,22 @@ repos: additional_dependencies: - flake8-black - flake8-bugbear + - flake8-comprehensions + - flake8-simplify + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: 'v3.0.0-alpha.4' + hooks: + - id: prettier + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-added-large-files + - id: check-executables-have-shebangs + - id: check-shebang-scripts-are-executable + - id: check-merge-conflict + - id: check-symlinks + - id: check-toml + - id: end-of-file-fixer + - id: trailing-whitespace From e3f906e90d9c25bc47df42f4ff58419fce1f4705 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sat, 18 Feb 2023 14:42:42 +0100 Subject: [PATCH 17/57] update .flake8 - use extend-exclude so that default excludes are not overwritten --- .flake8 | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/.flake8 b/.flake8 index 57ac5f4c01..81d8d82bfb 100644 --- a/.flake8 +++ b/.flake8 @@ -10,15 +10,7 @@ extend-ignore = extend-select = # Bugbear line length B950 -exclude = - .venv, - .git, - .tox, - dist, - doc, - *lib/python*, - *egg, - build +extend-exclude = scripts/orig_scripts/* ldm/models/* ldm/modules/* From bb1769ababcc107e792a34dea1ab89a88d0fd7b7 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sat, 18 Feb 2023 14:43:01 +0100 Subject: [PATCH 18/57] remove non working .editorconfig entrys --- .editorconfig | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.editorconfig b/.editorconfig index 87deaa4477..e447617b01 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,8 +12,7 @@ trim_trailing_whitespace = true # Python [*.py] -max_line_length = 88 -profile = black +indent_size = 4 # css [*.css] From bfe64b1510dfeec58bf6814940d315e3c217c8cf Mon Sep 17 00:00:00 2001 From: mauwii Date: Sat, 18 Feb 2023 14:43:41 +0100 Subject: [PATCH 19/57] allign prettierrc with config in frontend --- .prettierignore | 14 ++++++++++++++ .prettierrc.yaml | 14 +++++++++----- 2 files changed, 23 
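+
+# Once installed via `pre-commit install`, these hooks run on every commit;
+# the whole suite can also be run by hand with `pre-commit run --all-files`.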
From e3f906e90d9c25bc47df42f4ff58419fce1f4705 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sat, 18 Feb 2023 14:42:42 +0100
Subject: [PATCH 17/57] update .flake8

- use extend-exclude so that default excludes are not overwritten

---
 .flake8 | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/.flake8 b/.flake8
index 57ac5f4c01..81d8d82bfb 100644
--- a/.flake8
+++ b/.flake8
@@ -10,15 +10,7 @@ extend-ignore =
 extend-select =
     # Bugbear line length
    B950
-exclude =
-    .venv,
-    .git,
-    .tox,
-    dist,
-    doc,
-    *lib/python*,
-    *egg,
-    build
+extend-exclude =
     scripts/orig_scripts/*
     ldm/models/*

From bb1769ababcc107e792a34dea1ab89a88d0fd7b7 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sat, 18 Feb 2023 14:43:01 +0100
Subject: [PATCH 18/57] remove non-working .editorconfig entries

---
 .editorconfig | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/.editorconfig b/.editorconfig
index 87deaa4477..e447617b01 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -12,8 +12,7 @@ trim_trailing_whitespace = true

 # Python
 [*.py]
-max_line_length = 88
-profile = black
+indent_size = 4

 # css
 [*.css]

From bfe64b1510dfeec58bf6814940d315e3c217c8cf Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sat, 18 Feb 2023 14:43:41 +0100
Subject: [PATCH 19/57] align prettierrc with config in frontend

---
 .prettierignore  | 14 ++++++++++++++
 .prettierrc.yaml | 14 +++++++++-----
 2 files changed, 23 insertions(+), 5 deletions(-)
 create mode 100644 .prettierignore

diff --git a/.prettierignore b/.prettierignore
new file mode 100644
index 0000000000..2ef13d5aae
--- /dev/null
+++ b/.prettierignore
@@ -0,0 +1,14 @@
+invokeai/frontend/.husky
+invokeai/frontend/patches
+
+# Ignore artifacts:
+build
+coverage
+static
+invokeai/frontend/dist
+
+# Ignore all HTML files:
+*.html
+
+# Ignore deprecated docs
+docs/installation/deprecated_documentation

diff --git a/.prettierrc.yaml b/.prettierrc.yaml
index ce4b99a07b..8050e4cbb7 100644
--- a/.prettierrc.yaml
+++ b/.prettierrc.yaml
@@ -1,9 +1,10 @@
-endOfLine: lf
-tabWidth: 2
-useTabs: false
-singleQuote: true
-quoteProps: as-needed
 embeddedLanguageFormatting: auto
+endOfLine: lf
+singleQuote: true
+semi: true
+tabWidth: 2
+trailingComma: es5
+useTabs: false
 overrides:
   - files: '*.md'
     options:
@@ -11,3 +12,6 @@ overrides:
       printWidth: 80
       parser: markdown
       cursorOffset: -1
+  - files: 'invokeai/frontend/public/locales/*.json'
+    options:
+      tabWidth: 4

From e80160f8dd84b8919871db655ae0e7d4589a2900 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sat, 18 Feb 2023 14:47:36 +0100
Subject: [PATCH 20/57] update config of black and isort

black:
- extend-exclude legacy scripts
- config for python 3.9 as long as we support it

isort:
- set atomic to true to only apply if no syntax errors are introduced
- config for python 3.9 as long as we support it
- extend_skip_glob legacy scripts
- filter_files
- match line_length with black
- remove_redundant_aliases
- skip_gitignore
- set src paths
- include virtual_env to detect third party modules

---
 pyproject.toml | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index d1cd1d1938..cfc4ad364b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -147,11 +147,26 @@ version = {attr = "ldm.invoke.__version__"}
 "invokeai.frontend.dist" = ["**"]

 [tool.black]
+extend-exclude = '''
+/(
+  # skip legacy scripts
+  | scripts/orig_scripts
+)/
+'''
 line-length = 88
-target-version = ['py310']
+target-version = ['py39']

 [tool.isort]
+atomic = true
+extend_skip_glob = ["scripts/orig_scripts/*"]
+filter_files = true
+line_length = 88
 profile = "black"
+py_version = 39
+remove_redundant_aliases = true
+skip_gitignore = true
+src_paths = ["installer", "invokeai", "ldm", "tests"]
+virtual_env = ".venv"

 [tool.coverage.run]
 branch = true

From ed06a70eca3f944f3f99a1c51a5cfdf43abf220a Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sat, 18 Feb 2023 15:04:25 +0100
Subject: [PATCH 21/57] add pre-commit hook `no-commit-to-branch`

additional layer to prevent accidental commits directly to main branch

---
 .pre-commit-config.yaml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 636600ba4d..e5d024eaee 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -36,4 +36,6 @@ repos:
       - id: check-symlinks
       - id: check-toml
       - id: end-of-file-fixer
+      - id: no-commit-to-branch
+        args: ['--branch', 'main']
       - id: trailing-whitespace

From 2aa5688d9021ebeab1ae4317119e6114ebadff07 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 01:37:02 +0100
Subject: [PATCH 22/57] update `docs/.markdownlint.jsonc`

- disable ul-indent
- disable list-marker-space

---
 docs/.markdownlint.jsonc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/docs/.markdownlint.jsonc b/docs/.markdownlint.jsonc
index 9a24fcfc80..c6b91b533f 100644
--- a/docs/.markdownlint.jsonc
+++ b/docs/.markdownlint.jsonc
@@ -1,3 +1,5 @@
 {
"MD046": false, + "MD007": false, + "MD030": false } From 2f25363d767bfad5d60afd87b13f0491b5553093 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 01:42:52 +0100 Subject: [PATCH 23/57] update "how to contribute" doc and md indentation --- .editorconfig | 6 ++ .prettierrc.yaml | 4 +- .../010_HOW_TO_CONTRIBUTE.md | 95 ++++++++++++++----- 3 files changed, 79 insertions(+), 26 deletions(-) diff --git a/.editorconfig b/.editorconfig index e447617b01..fe9b4a61d1 100644 --- a/.editorconfig +++ b/.editorconfig @@ -21,3 +21,9 @@ indent_size = 4 # flake8 [.flake8] indent_size = 4 + +# Markdown MkDocs +[docs/**/*.md] +max_line_length = 80 +indent_size = 4 +indent_style = unset diff --git a/.prettierrc.yaml b/.prettierrc.yaml index 8050e4cbb7..457c8267f6 100644 --- a/.prettierrc.yaml +++ b/.prettierrc.yaml @@ -2,7 +2,6 @@ embeddedLanguageFormatting: auto endOfLine: lf singleQuote: true semi: true -tabWidth: 2 trailingComma: es5 useTabs: false overrides: @@ -12,6 +11,9 @@ overrides: printWidth: 80 parser: markdown cursorOffset: -1 + - files: docs/**/*.md + options: + tabWidth: 4 - files: 'invokeai/frontend/public/locales/*.json' options: tabWidth: 4 diff --git a/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md b/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md index a613c0b3b6..d83833d063 100644 --- a/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md +++ b/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md @@ -2,52 +2,62 @@ title: How to Contribute --- -## pre-requirements +There are different ways how you can contribute to +[InvokeAI](https://github.com/invoke-ai/InvokeAI), like Translations, opening +Issues for Bugs or ideas how to improve. -To follow the steps in this tutorial you will need the following: +## Pull Requests -- [git](https://git-scm.com/downloads) -- [GitHub](https://github.com) account -- A Code Editor (personally I use Visual Studio Code) +### pre-requirements -## Fork Repository +To follow the steps in this tutorial you will need: + +- [GitHub](https://github.com) account +- [git](https://git-scm.com/downloads) source controll +- Text / Code Editor (personally I preffer + [Visual Studio Code](https://code.visualstudio.com/Download)) +- Terminal: + - If you are on Linux/MacOS you can use bash or zsh + - for Windows Users the commands are written for PowerShell + +### Fork Repository The first step to be done if you want to contribute to InvokeAI, is to fork the rpeository. -The easiest way to do so is by clicking -[here](https://github.com/invoke-ai/InvokeAI/fork). It is also possible by -opening [InvokeAI](https://github.com/invoke-ai/InvoekAI) and click on the -"Fork" Button in the top right. +Since you are already reading this doc, the easiest way to do so is by clicking +[here](https://github.com/invoke-ai/InvokeAI/fork). You could also open +[InvokeAI](https://github.com/invoke-ai/InvoekAI) and click on the "Fork" Button +in the top right. 
-## Clone your fork +### Clone your fork After you forked the Repository, you should clone it to your dev machine: -=== "Linux/MacOS" +=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:" - ```sh + ``` sh git clone https://github.com//InvokeAI \ && cd InvokeAI ``` -=== "Windows" +=== "Windows:fontawesome-brands-windows:" - ```powershell + ``` powershell git clone https://github.com//InvokeAI ` && cd InvokeAI ``` -## Install in Editable Mode +### Install in Editable Mode To install InvokeAI in editable mode, (as always) we recommend to create and activate a venv first. Afterwards you can install the InvokeAI Package, including dev and docs extras in editable mode, follwed by the installation of the pre-commit hook: -=== "Linux/MacOS" +=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:" - ```sh + ``` sh python -m venv .venv \ --prompt InvokeAI \ --upgrade-deps \ @@ -59,9 +69,9 @@ the pre-commit hook: && pre-commit install ``` -=== "Windows" +=== "Windows:fontawesome-brands-windows:" - ```powershell + ``` powershell python -m venv .venv ` --prompt InvokeAI ` --upgrade-deps ` @@ -73,22 +83,57 @@ the pre-commit hook: && pre-commit install ``` -## Create a branch +### Create a branch Make sure you are on main branch, from there create your feature branch: -=== "Linux/MacOS" +=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:" - ```sh + ``` sh git checkout main \ && git pull \ && git checkout -B ``` -=== "Windows" +=== "Windows:fontawesome-brands-windows:" - ```powershell + ``` powershell git checkout main ` && git pull ` && git checkout -B ``` + +### Commit your changes + +When you are done with adding / updating content, you need to commit those +changes to your repository before you can actually open an PR: + +```{ .sh .annotate } +git add # (1)! +git commit -m "A commit message which describes your change" +git push +``` + +1. Replace this with a space seperated list of the files you changed, like: + `README.md foo.sh bar.json baz` + +### Create a Pull Request + +After pushing your changes, you are ready to create a Pull Request. just head +over to your fork on [GitHub](https://github.com), which should already show you +a message that there have been recent changes on your feature branch and a green +button which you could use to create the PR. + +The default target for your PRs would be the main branch of +[invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI) + +Another way would be to create it in VS-Code or via the GitHub CLI (or even via +the GitHub CLI in a VS-Code Terminal Window 🤭): + +```sh +gh pr create +``` + +The CLI will inform you if there are still unpushed commits on your branch. It +will also prompt you for things like the the Title and the Body (Description) if +you did not already pass them as arguments. 
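+
+If you already know what you want to pass, a non-interactive call could look
+like this (a minimal sketch with placeholder values, assuming the GitHub CLI
+is installed):
+
+```sh
+gh pr create \
+    --repo invoke-ai/InvokeAI \
+    --base main \
+    --title "A short title for your change" \
+    --body "A description of what your change does and why"
+```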
From bec81170b5aecafd899ae49d48cea60bc3a93de8 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 02:04:27 +0100
Subject: [PATCH 24/57] move contribution docs to help section, add index

---
 docs/help/contribute/010_PULL_REQUEST.md | 133 +++++++++++++++++++++++
 docs/help/contribute/index.md            |  10 ++
 2 files changed, 143 insertions(+)
 create mode 100644 docs/help/contribute/010_PULL_REQUEST.md
 create mode 100644 docs/help/contribute/index.md

diff --git a/docs/help/contribute/010_PULL_REQUEST.md b/docs/help/contribute/010_PULL_REQUEST.md
new file mode 100644
index 0000000000..129df49349
--- /dev/null
+++ b/docs/help/contribute/010_PULL_REQUEST.md
@@ -0,0 +1,133 @@
+---
+title: Pull Requests
+---
+
+## pre-requirements
+
+To follow the steps in this tutorial you will need:
+
+- [GitHub](https://github.com) account
+- [git](https://git-scm.com/downloads) source control
+- Text / Code Editor (personally I prefer
+  [Visual Studio Code](https://code.visualstudio.com/Download))
+- Terminal:
+  - If you are on Linux/MacOS you can use bash or zsh
+  - for Windows users the commands are written for PowerShell
+
+## Fork Repository
+
+The first step, if you want to contribute to InvokeAI, is to fork the
+repository.
+
+Since you are already reading this doc, the easiest way to do so is by clicking
+[here](https://github.com/invoke-ai/InvokeAI/fork). You could also open
+[InvokeAI](https://github.com/invoke-ai/InvokeAI) and click on the "Fork"
+button in the top right.
+
+## Clone your fork
+
+After you have forked the repository, you should clone it to your dev machine:
+
+=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:"
+
+    ``` sh
+    git clone https://github.com/<username>/InvokeAI \
+    && cd InvokeAI
+    ```
+
+=== "Windows:fontawesome-brands-windows:"
+
+    ``` powershell
+    git clone https://github.com/<username>/InvokeAI `
+    && cd InvokeAI
+    ```
+
+## Install in Editable Mode
+
+To install InvokeAI in editable mode, we (as always) recommend creating and
+activating a venv first. Afterwards you can install the InvokeAI package,
+including dev and docs extras in editable mode, followed by the installation of
+the pre-commit hook:
+
+=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:"
+
+    ``` sh
+    python -m venv .venv \
+        --prompt InvokeAI \
+        --upgrade-deps \
+    && source .venv/bin/activate \
+    && pip install \
+        --upgrade \
+        --use-pep517 \
+        --editable=".[dev,docs]" \
+    && pre-commit install
+    ```
+
+=== "Windows:fontawesome-brands-windows:"
+
+    ``` powershell
+    python -m venv .venv `
+        --prompt InvokeAI `
+        --upgrade-deps `
+    && .venv/scripts/activate.ps1 `
+    && pip install `
+        --upgrade `
+        --use-pep517 `
+        --editable=".[dev,docs]" `
+    && pre-commit install
+    ```
+
+## Create a branch
+
+Make sure you are on the main branch, then create your feature branch:
+
+=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:"
+
+    ``` sh
+    git checkout main \
+    && git pull \
+    && git checkout -B <branch name>
+    ```
+
+=== "Windows:fontawesome-brands-windows:"
+
+    ``` powershell
+    git checkout main `
+    && git pull `
+    && git checkout -B <branch name>
+    ```
+
+## Commit your changes
+
+When you are done with adding / updating content, you need to commit those
+changes to your repository before you can actually open a PR:
+
+```{ .sh .annotate }
+git add <changed files> # (1)!
+git commit -m "A commit message which describes your change"
+git push
+```
+
+1. Replace this with a space separated list of the files you changed, like:
+   `README.md foo.sh bar.json baz`
+
+## Create a Pull Request
+
+After pushing your changes, you are ready to create a Pull Request. Just head
+over to your fork on [GitHub](https://github.com), which should already show you
+a message that there have been recent changes on your feature branch and a green
+button which you could use to create the PR.
+
+The default target for your PRs would be the main branch of
+[invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)
+
+Another way would be to create it in VS-Code or via the GitHub CLI (or even via
+the GitHub CLI in a VS-Code Terminal Window 🤭):
+
+```sh
+gh pr create
+```
+
+The CLI will inform you if there are still unpushed commits on your branch. It
+will also prompt you for things like the Title and the Body (Description) if
+you did not already pass them as arguments.

diff --git a/docs/help/contribute/index.md b/docs/help/contribute/index.md
new file mode 100644
index 0000000000..19189b3104
--- /dev/null
+++ b/docs/help/contribute/index.md
@@ -0,0 +1,10 @@
+---
+title: Contribute
+---
+
+There are different ways you can contribute to
+[InvokeAI](https://github.com/invoke-ai/InvokeAI), like translations, opening
+issues for bugs, or ideas on how to improve.
+
+This section of the docs will explain some of the different ways you can
+contribute, to make it easier for newcomers as well as advanced users :nerd:

From 8a233174de91a6beaf6b8731d08781f350331848 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 02:10:15 +0100
Subject: [PATCH 25/57] update MkDocs-Material to v9

---
 docs/requirements-mkdocs.txt |  4 +---
 mkdocs.yml                   | 18 +++++++++++-------
 pyproject.toml               |  6 +++---
 3 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/docs/requirements-mkdocs.txt b/docs/requirements-mkdocs.txt
index a637622954..6fec332cad 100644
--- a/docs/requirements-mkdocs.txt
+++ b/docs/requirements-mkdocs.txt
@@ -1,5 +1,3 @@
-mkdocs
-mkdocs-material>=8, <9
+mkdocs-material=="9.*"
 mkdocs-git-revision-date-localized-plugin
 mkdocs-redirects==1.2.0
-

diff --git a/mkdocs.yml b/mkdocs.yml
index f8d5f4631e..fdc694108f 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -19,7 +19,8 @@ theme:
   name: material
   icon:
     repo: fontawesome/brands/github
-    edit: material/file-document-edit-outline
+    edit: material/pencil
+    view: material/eye
   palette:
     - media: '(prefers-color-scheme: light)'
       scheme: default
@@ -34,7 +35,10 @@ theme:
       name: Switch to light mode
   features:
     - content.tabs.link
+    - content.action.edit
+    - content.action.view
     - navigation.instant
+    - navigation.indexes
     - navigation.tabs
     - navigation.top
     - navigation.tracking
@@ -90,9 +94,9 @@ plugins:
       enable_creation_date: true
   - redirects:
       redirect_maps:
-        'installation/INSTALL_AUTOMATED.md': 'installation/010_INSTALL_AUTOMATED.md'
-        'installation/INSTALL_MANUAL.md': 'installation/020_INSTALL_MANUAL.md'
-        'installation/INSTALL_SOURCE.md': 'installation/020_INSTALL_MANUAL.md'
-        'installation/INSTALL_DOCKER.md': 'installation/040_INSTALL_DOCKER.md'
-        'installation/INSTALLING_MODELS.md': 'installation/050_INSTALLING_MODELS.md'
-        'installation/INSTALL_PATCHMATCH.md': 'installation/060_INSTALL_PATCHMATCH.md'
+        'installation/INSTALL_AUTOMATED.md': 'installation/010_INSTALL_AUTOMATED.md'
+        'installation/INSTALL_MANUAL.md': 'installation/020_INSTALL_MANUAL.md'
+        'installation/INSTALL_SOURCE.md': 'installation/020_INSTALL_MANUAL.md'
+        'installation/INSTALL_DOCKER.md': 'installation/040_INSTALL_DOCKER.md'
+        'installation/INSTALLING_MODELS.md': 'installation/050_INSTALLING_MODELS.md'
+        'installation/INSTALL_PATCHMATCH.md': 'installation/060_INSTALL_PATCHMATCH.md'

diff --git a/pyproject.toml b/pyproject.toml
index cfc4ad364b..793ed62d4f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -95,7 +95,7 @@ requires-python = ">=3.9, <3.11"
 "dist" = ["pip-tools", "pipdeptree", "twine"]
 "docs" = [
   "mkdocs-git-revision-date-localized-plugin",
-  "mkdocs-material<9.0",
+  "mkdocs-material==9.*",
   "mkdocs-redirects==1.2.0",
 ]
 "test" = ["pytest-cov", "pytest>6.0.0"]
@@ -133,8 +133,8 @@ version = {attr = "ldm.invoke.__version__"}

 [tool.setuptools.packages.find]
 "include" = [
-  "invokeai.assets.web*",
-  "invokeai.backend*",
+  "invokeai.assets.web",
+  "invokeai.backend",
   "invokeai.configs*",
   "invokeai.frontend.dist*",
   "ldm*",
 ]

From 5b5898827c0fea559f59a5f384975919bea9b53a Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 02:24:49 +0100
Subject: [PATCH 26/57] update vscode settings

---
 .vscode/settings.json | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index f888881623..81234cfc8c 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -24,6 +24,7 @@
   "[markdown]": {
+    "editor.defaultFormatter": "esbenp.prettier-vscode",
     "editor.rulers": [80],
     "editor.unicodeHighlight.ambiguousCharacters": false,
     "editor.unicodeHighlight.invisibleCharacters": false,
@@ -51,8 +52,6 @@
     "evenBetterToml.formatter.reorderKeys": true,
     "evenBetterToml.formatter.compactEntries": false,
     "evenBetterToml.schema.enabled": true,
-    "isort.check": true,
-    "isort.args": ["--profile=black", "--filter-files", "--color"],
     "python.analysis.typeCheckingMode": "basic",
     "python.formatting.provider": "black",
@@ -64,5 +63,8 @@
       "--cov-branch",
       "--cov-report=term:skip-covered"
-    ]
+    ],
+    "yaml.schemas": {
+      "https://json.schemastore.org/prettierrc.json": "${workspaceFolder}/.prettierrc"
+    }
 }

From d32819875aef85d41281b96f6289f4b3ea272ec6 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 02:32:21 +0100
Subject: [PATCH 27/57] fix `docs/requirements-mkdocs.txt`

---
 docs/requirements-mkdocs.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/requirements-mkdocs.txt b/docs/requirements-mkdocs.txt
index 6fec332cad..cc7db0dd9f 100644
--- a/docs/requirements-mkdocs.txt
+++ b/docs/requirements-mkdocs.txt
@@ -1,3 +1,3 @@
-mkdocs-material=="9.*"
+mkdocs-material==9.*
 mkdocs-git-revision-date-localized-plugin
 mkdocs-redirects==1.2.0

From ace7032067246bade8d6c960386e41bc569f11fb Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 03:15:30 +0100
Subject: [PATCH 28/57] add `docs/help/contribute/issues`, update index

---
 docs/help/contribute/020_ISSUES.md | 24 ++++++++++++++++++++++++
 docs/help/contribute/index.md      |  4 ++++
 2 files changed, 28 insertions(+)
 create mode 100644 docs/help/contribute/020_ISSUES.md

diff --git a/docs/help/contribute/020_ISSUES.md b/docs/help/contribute/020_ISSUES.md
new file mode 100644
index 0000000000..f40d65c8d6
--- /dev/null
+++ b/docs/help/contribute/020_ISSUES.md
@@ -0,0 +1,24 @@
+---
+title: Issues
+---
+
+## :fontawesome-solid-bug: Report a bug
+
+If you stumble over a bug while using InvokeAI, we would appreciate it a lot if
+you
+[open an issue](https://github.com/invoke-ai/InvokeAI/issues/new?assignees=&labels=bug&template=BUG_REPORT.yml&title=%5Bbug%5D%3A+)
+to inform us about the details so
+that our developers can look into it.
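+
+If you prefer the terminal and have the [GitHub CLI](https://cli.github.com/)
+installed, a bug report can also be opened from there. A minimal sketch, with
+placeholder values:
+
+```sh
+gh issue create \
+    --repo invoke-ai/InvokeAI \
+    --label bug \
+    --title "[bug]: a short description of what went wrong" \
+    --body "Steps to reproduce, expected behaviour, and what happened instead"
+```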
+
+If you also know how to fix the bug, take a look [here](010_PULL_REQUEST.md) to
+find out how to create a Pull Request.
+
+## Request a feature
+
+If you have an idea for a new feature which you would like to see in
+InvokeAI, there is a
+[feature request](https://github.com/invoke-ai/InvokeAI/issues/new?assignees=&labels=bug&template=BUG_REPORT.yml&title=%5Bbug%5D%3A+)
+available in the issues section of the repository.
+
+If you are just curious which features have already been requested, you can
+find the overview of open requests
+[here](https://github.com/invoke-ai/InvokeAI/labels/enhancement)

diff --git a/docs/help/contribute/index.md b/docs/help/contribute/index.md
index 19189b3104..9a1e3691fb 100644
--- a/docs/help/contribute/index.md
+++ b/docs/help/contribute/index.md
@@ -8,3 +8,7 @@ issues for bugs, or ideas on how to improve.

 This section of the docs will explain some of the different ways you can
 contribute, to make it easier for newcomers as well as advanced users :nerd:
+
+If you want to contribute code, but you do not have an exact idea yet, take a
+look at the currently open
+[:fontawesome-solid-bug: Bug Reports](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen+label%3Abug)

From f3d669319e1e9d08f4ec2e6170f8ba3fdc409881 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 03:29:45 +0100
Subject: [PATCH 29/57] get rid of `requirements-mkdocs.txt`

---
 .github/workflows/mkdocs-material.yml | 4 +++-
 docs/requirements-mkdocs.txt          | 3 ---
 2 files changed, 3 insertions(+), 4 deletions(-)
 delete mode 100644 docs/requirements-mkdocs.txt

diff --git a/.github/workflows/mkdocs-material.yml b/.github/workflows/mkdocs-material.yml
index 26a46c1328..e553dde8c4 100644
--- a/.github/workflows/mkdocs-material.yml
+++ b/.github/workflows/mkdocs-material.yml
@@ -19,11 +19,13 @@ jobs:
         uses: actions/setup-python@v4
         with:
           python-version: '3.10'
+          cache: pip
+          cache-dependency-path: pyproject.toml

       - name: install requirements
         run: |
           python -m \
-            pip install -r docs/requirements-mkdocs.txt
+            pip install ".[docs]"

       - name: confirm buildability
         run: |

diff --git a/docs/requirements-mkdocs.txt b/docs/requirements-mkdocs.txt
deleted file mode 100644
index cc7db0dd9f..0000000000
--- a/docs/requirements-mkdocs.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-mkdocs-material==9.*
-mkdocs-git-revision-date-localized-plugin
-mkdocs-redirects==1.2.0

From 8744dd0c4683209a6cd9108c866f6a0e7ca21d8d Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 04:06:38 +0100
Subject: [PATCH 30/57] fix edit_uri in `mkdocs.yml`

---
 mkdocs.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mkdocs.yml b/mkdocs.yml
index fdc694108f..4eb9920e6c 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -9,7 +9,7 @@ dev_addr: '127.0.0.1:8080'
 # Repository
 repo_name: 'invoke-ai/InvokeAI'
 repo_url: 'https://github.com/invoke-ai/InvokeAI'
-edit_uri: edit/main/docs/
+edit_uri: blob/main/docs/

 # Copyright
 copyright: Copyright © 2022 InvokeAI Team

From f514f17e924d72eb6aa4cdfa98adbdf22aeedb73 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 04:28:07 +0100
Subject: [PATCH 31/57] add variables to define:

- repo_url
- repo_name
- site_url

---
 .github/workflows/mkdocs-material.yml | 4 ++++
 mkdocs.yml                            | 6 +++---
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/mkdocs-material.yml b/.github/workflows/mkdocs-material.yml
index e553dde8c4..df4fa0a01e 100644
--- a/.github/workflows/mkdocs-material.yml
+++ b/.github/workflows/mkdocs-material.yml
@@ -9,6 +9,10 @@ jobs:
   mkdocs-material:
     if: github.event.pull_request.draft == false
     runs-on: ubuntu-latest
+    env:
+      REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
+      REPO_NAME: '${{ github.repository }}'
+      SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
     steps:
       - name: checkout sources
         uses: actions/checkout@v3

diff --git a/mkdocs.yml b/mkdocs.yml
index 4eb9920e6c..f9acbbc41a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -2,13 +2,13 @@

 # General
 site_name: InvokeAI Stable Diffusion Toolkit Docs
-site_url: https://invoke-ai.github.io/InvokeAI
+site_url: !ENV [SITE_URL, 'https://invoke-ai.github.io/InvokeAI']
 site_author: mauwii
 dev_addr: '127.0.0.1:8080'

 # Repository
-repo_name: 'invoke-ai/InvokeAI'
-repo_url: 'https://github.com/invoke-ai/InvokeAI'
+repo_name: !ENV [REPO_NAME, 'invoke-ai/InvokeAI']
+repo_url: !ENV [REPO_URL, 'https://github.com/invoke-ai/InvokeAI']
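+# Note: !ENV falls back to the bracketed default when the variable is unset,
+# so local builds keep working; a fork could override them, e.g. with
+# illustrative values: REPO_NAME='<user>/InvokeAI' mkdocs build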
b/.github/workflows/mkdocs-material.yml @@ -9,6 +9,10 @@ jobs: mkdocs-material: if: github.event.pull_request.draft == false runs-on: ubuntu-latest + env: + REPO_URL: '${{ github.server_url }}/${{ github.repository }}' + REPO_NAME: '${{ github.repository }}' + SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI' steps: - name: checkout sources uses: actions/checkout@v3 diff --git a/mkdocs.yml b/mkdocs.yml index 4eb9920e6c..f9acbbc41a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -2,13 +2,13 @@ # General site_name: InvokeAI Stable Diffusion Toolkit Docs -site_url: https://invoke-ai.github.io/InvokeAI +site_url: !ENV [SITE_URL, 'https://invoke-ai.github.io/InvokeAI'] site_author: mauwii dev_addr: '127.0.0.1:8080' # Repository -repo_name: 'invoke-ai/InvokeAI' -repo_url: 'https://github.com/invoke-ai/InvokeAI' +repo_name: !ENV [REPO_NAME, 'invoke-ai/InvokeAI'] +repo_url: !ENV [REPO_URL, 'https://github.com/invoke-ai/InvokeAI'] edit_uri: blob/main/docs/ # Copyright From f901645c129b5d2c1667a68ad39eb9177a5e1957 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 04:37:28 +0100 Subject: [PATCH 32/57] use pip517 --- .github/workflows/mkdocs-material.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/mkdocs-material.yml b/.github/workflows/mkdocs-material.yml index df4fa0a01e..bddccd4f1b 100644 --- a/.github/workflows/mkdocs-material.yml +++ b/.github/workflows/mkdocs-material.yml @@ -27,6 +27,8 @@ jobs: cache-dependency-path: pyproject.toml - name: install requirements + env: + PIP_USE_PEP517: 1 run: | python -m \ pip install ".[docs]" From ce98fdc5c48bd1de77aab41f8d6d32febe91f1a4 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 09:02:22 +0100 Subject: [PATCH 33/57] after some complaints reomove .vscode I still think they would be beneficial, but to lazy to re-discuss this --- .vscode/extensions.json | 22 ------------- .vscode/launch.json | 13 -------- .vscode/settings.json | 71 ----------------------------------------- 3 files changed, 106 deletions(-) delete mode 100644 .vscode/extensions.json delete mode 100644 .vscode/launch.json delete mode 100644 .vscode/settings.json diff --git a/.vscode/extensions.json b/.vscode/extensions.json deleted file mode 100644 index 9c5bb5d049..0000000000 --- a/.vscode/extensions.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "recommendations": [ - "editorconfig.editorconfig", - "github.vscode-pull-request-github", - "ms-python.black-formatter", - "ms-python.flake8", - "ms-python.isort", - "ms-python.python", - "ms-python.vscode-pylance", - "redhat.vscode-yaml", - "tamasfe.even-better-toml", - "eamodio.gitlens", - "foxundermoon.shell-format", - "timonwong.shellcheck", - "esbenp.prettier-vscode", - "davidanson.vscode-markdownlint", - "yzhang.markdown-all-in-one", - "bierner.github-markdown-preview", - "ms-azuretools.vscode-docker", - "mads-hartmann.bash-ide-vscode" - ] -} diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 2b8f22f75c..0000000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "version": "0.2.0", - "configurations": [ - { - "name": "invokeai --web", - "type": "python", - "request": "launch", - "program": ".venv/bin/invokeai", - "args": ["--web"], - "justMyCode": true - } - ] -} diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 81234cfc8c..0000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,71 +0,0 @@ -{ - "[json]": { - "editor.defaultFormatter": "esbenp.prettier-vscode", - "editor.quickSuggestions": { - 
"strings": true - }, - "editor.suggest.insertMode": "replace", - "files.insertFinalNewline": true, - "gitlens.codeLens.scopes": ["document"] - }, - "[python]": { - "editor.defaultFormatter": "ms-python.black-formatter", - "editor.formatOnSave": true, - "editor.formatOnSaveMode": "file" - }, - "[toml]": { - "editor.defaultFormatter": "tamasfe.even-better-toml", - "editor.formatOnSave": true, - "editor.formatOnSaveMode": "modificationsIfAvailable" - }, - "[yaml]": { - "editor.defaultFormatter": "redhat.vscode-yaml", - "editor.formatOnSave": true, - "editor.formatOnSaveMode": "modificationsIfAvailable" - }, - "[markdown]": { - "editor.defaultFormatter": "esbenp.prettier-vscode", - "editor.rulers": [80], - "editor.unicodeHighlight.ambiguousCharacters": false, - "editor.unicodeHighlight.invisibleCharacters": false, - "diffEditor.ignoreTrimWhitespace": false, - "editor.wordWrap": "on", - "editor.quickSuggestions": { - "comments": "off", - "strings": "off", - "other": "off" - } - }, - "editor.rulers": [88], - "editor.defaultFormatter": "esbenp.prettier-vscode", - "evenBetterToml.formatter.alignEntries": false, - "evenBetterToml.formatter.allowedBlankLines": 1, - "evenBetterToml.formatter.arrayAutoExpand": true, - "evenBetterToml.formatter.arrayTrailingComma": true, - "evenBetterToml.formatter.arrayAutoCollapse": true, - "evenBetterToml.formatter.columnWidth": 88, - "evenBetterToml.formatter.compactArrays": true, - "evenBetterToml.formatter.compactInlineTables": true, - "evenBetterToml.formatter.indentEntries": false, - "evenBetterToml.formatter.inlineTableExpand": true, - "evenBetterToml.formatter.reorderArrays": true, - "evenBetterToml.formatter.reorderKeys": true, - "evenBetterToml.formatter.compactEntries": false, - "evenBetterToml.schema.enabled": true, - "python.analysis.typeCheckingMode": "basic", - "python.formatting.provider": "black", - "python.languageServer": "Pylance", - "python.linting.enabled": true, - "python.linting.flake8Enabled": true, - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, - "python.testing.pytestArgs": [ - "tests", - "--cov=ldm", - "--cov-branch", - "--cov-report=term:skip-covered" - ], - "yaml.schemas": { - "https://json.schemastore.org/prettierrc.json": "${workspaceFolder}/.prettierrc" - } -} From 57daa3e1c26ee50b0d3bf3869c3654db13623fe6 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 09:02:38 +0100 Subject: [PATCH 34/57] re-ignore .vscode --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8b9b495261..9adb0be85a 100644 --- a/.gitignore +++ b/.gitignore @@ -201,6 +201,7 @@ checkpoints # Scratch folder .scratch/ +.vscode/ gfpgan/ models/ldm/stable-diffusion-v1/*.sha256 From 9c6af7455699fa1c2cde72cdd0ab00bfc41da05c Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 09:09:26 +0100 Subject: [PATCH 35/57] add `docs/help/IDE-Settings` --- docs/help/IDE-Settings/index.md | 4 + docs/help/IDE-Settings/vs-code.md | 149 ++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 docs/help/IDE-Settings/index.md create mode 100644 docs/help/IDE-Settings/vs-code.md diff --git a/docs/help/IDE-Settings/index.md b/docs/help/IDE-Settings/index.md new file mode 100644 index 0000000000..bdb045f800 --- /dev/null +++ b/docs/help/IDE-Settings/index.md @@ -0,0 +1,4 @@ +# :octicons-file-code-16: IDE-Settings + +Here we will share settings for IDEs used by our developers, maybe you can find +something interestening which will help to boost your development efficency 🔥 
diff --git a/docs/help/IDE-Settings/vs-code.md b/docs/help/IDE-Settings/vs-code.md new file mode 100644 index 0000000000..a2ae9b0dd9 --- /dev/null +++ b/docs/help/IDE-Settings/vs-code.md @@ -0,0 +1,149 @@ +--- +title: Visual Studio Code +--- + +# :material-microsoft-visual-studio-code:Visual Studio Code + +The Workspace Settings are stored in the project (repository) root and get +higher priorized than your user settings. + +This helps to have different settings for different projects, while the user +settings get used as a default value if no workspace settings are provided. + +## launch.json + +It is asumed that you have created a virtual environment as `.venv`: + +```sh +python -m venv .venv --prompt="InvokeAI" --upgrade-deps +``` + +This is the most simplified version of launching `invokeai --web` with the +debugger attached: + +```json title=".vscode/launch.json" +{ + "version": "0.2.0", + "configurations": [ + { + "name": "invokeai --web", + "type": "python", + "request": "launch", + "program": ".venv/bin/invokeai", + "args": ["--web"], + "justMyCode": true + } + ] +} +``` + +Then you only need to hit ++F5++ and the fun begins :nerd: + +## extensions.json + +A list of recommended vscode-extensions to make your life easier: + +```json title=".vscode/extensions.json" +{ + "recommendations": [ + "editorconfig.editorconfig", + "github.vscode-pull-request-github", + "ms-python.black-formatter", + "ms-python.flake8", + "ms-python.isort", + "ms-python.python", + "ms-python.vscode-pylance", + "redhat.vscode-yaml", + "tamasfe.even-better-toml", + "eamodio.gitlens", + "foxundermoon.shell-format", + "timonwong.shellcheck", + "esbenp.prettier-vscode", + "davidanson.vscode-markdownlint", + "yzhang.markdown-all-in-one", + "bierner.github-markdown-preview", + "ms-azuretools.vscode-docker", + "mads-hartmann.bash-ide-vscode" + ] +} +``` + +## settings.json + +With those settings your files already get formated when you save them, which +will help you to not run into trouble with the pre-commit hooks, which will +prevent you from commiting if the formaters are failing + +```json title=".vscode/settings.json" +{ + "[json]": { + "editor.defaultFormatter": "esbenp.prettier-vscode", + "editor.quickSuggestions": { + "strings": true + }, + "editor.suggest.insertMode": "replace", + "files.insertFinalNewline": true, + "gitlens.codeLens.scopes": ["document"] + }, + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true, + "editor.formatOnSaveMode": "file" + }, + "[toml]": { + "editor.defaultFormatter": "tamasfe.even-better-toml", + "editor.formatOnSave": true, + "editor.formatOnSaveMode": "modificationsIfAvailable" + }, + "[yaml]": { + "editor.defaultFormatter": "redhat.vscode-yaml", + "editor.formatOnSave": true, + "editor.formatOnSaveMode": "modificationsIfAvailable" + }, + "[markdown]": { + "editor.defaultFormatter": "esbenp.prettier-vscode", + "editor.rulers": [80], + "editor.unicodeHighlight.ambiguousCharacters": false, + "editor.unicodeHighlight.invisibleCharacters": false, + "diffEditor.ignoreTrimWhitespace": false, + "editor.wordWrap": "on", + "editor.quickSuggestions": { + "comments": "off", + "strings": "off", + "other": "off" + } + }, + "editor.rulers": [88], + "editor.defaultFormatter": "esbenp.prettier-vscode", + "evenBetterToml.formatter.alignEntries": false, + "evenBetterToml.formatter.allowedBlankLines": 1, + "evenBetterToml.formatter.arrayAutoExpand": true, + "evenBetterToml.formatter.arrayTrailingComma": true, + 
"evenBetterToml.formatter.arrayAutoCollapse": true, + "evenBetterToml.formatter.columnWidth": 88, + "evenBetterToml.formatter.compactArrays": true, + "evenBetterToml.formatter.compactInlineTables": true, + "evenBetterToml.formatter.indentEntries": false, + "evenBetterToml.formatter.inlineTableExpand": true, + "evenBetterToml.formatter.reorderArrays": true, + "evenBetterToml.formatter.reorderKeys": true, + "evenBetterToml.formatter.compactEntries": false, + "evenBetterToml.schema.enabled": true, + "python.analysis.typeCheckingMode": "basic", + "python.formatting.provider": "black", + "python.languageServer": "Pylance", + "python.linting.enabled": true, + "python.linting.flake8Enabled": true, + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "python.testing.pytestArgs": [ + "tests", + "--cov=ldm", + "--cov-branch", + "--cov-report=term:skip-covered" + ], + "yaml.schemas": { + "https://json.schemastore.org/prettierrc.json": "${workspaceFolder}/.prettierrc" + } +} +``` From 7c7c1ba02d720363964bf8ed7a1d2416f3ba1af9 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 11:49:14 +0100 Subject: [PATCH 36/57] add `docs/help/index.md` --- docs/help/index.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 docs/help/index.md diff --git a/docs/help/index.md b/docs/help/index.md new file mode 100644 index 0000000000..b4236e8244 --- /dev/null +++ b/docs/help/index.md @@ -0,0 +1,12 @@ +# :material-help:Help + +If you are looking for help with the installation of InvokeAI, please take a +look into the [Installation](../installation/index.md) section of the docs. + +Here you will find help to topics like + +- how to contribute +- configuration recommendation for IDEs + +If you have an Idea about what's missing and aren't scared from contributing, +just take a look at [DOCS](./contribute/030_DOCS.md) to find out how to do so. From 72e25d99c754b5be0bf4f9a197422f90c96a66e0 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 11:50:14 +0100 Subject: [PATCH 37/57] add `docs/help/contribute/030_DOCS.md` --- docs/help/contribute/030_DOCS.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 docs/help/contribute/030_DOCS.md diff --git a/docs/help/contribute/030_DOCS.md b/docs/help/contribute/030_DOCS.md new file mode 100644 index 0000000000..12b9c428e7 --- /dev/null +++ b/docs/help/contribute/030_DOCS.md @@ -0,0 +1,24 @@ +--- +title: docs +--- + +# :simple-readthedocs: MkDocs-Material + +If you want to contribute to the docs, there is a easy way to verify the results +of your changes before commiting them. + +Just follow the steps in the [Pull-Requests](010_PULL_REQUEST.md) docs, there we +already +[create a venv and install the docs extras](010_PULL_REQUEST.md#install-in-editable-mode). +When installed it's as simple as: + +```sh +mkdocs serve +``` + +This will build the docs locally and serve them on your local host, even +auto-refresh is included, so you can just update a doc, save it and tab to the +browser, without the needs of restarting the `mkdocs serve`. + +More information about the "mkdocs flavored markdown syntax" can be found +[here](https://squidfunk.github.io/mkdocs-material/reference/). 
From 774230f7b9a44759cf76d02e23cdacf1c7030cb9 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 12:02:26 +0100
Subject: [PATCH 38/57] re-format `docs/features/index.md`

---
 docs/features/index.md | 104 ++++++++++++++++++++++++-----------------
 1 file changed, 62 insertions(+), 42 deletions(-)

diff --git a/docs/features/index.md b/docs/features/index.md
index d9b0e1fd7c..50fc8bc77c 100644
--- a/docs/features/index.md
+++ b/docs/features/index.md
@@ -2,62 +2,82 @@
 title: Overview
 ---

-Here you can find the documentation for InvokeAI's various features.
+- The Basics

-## The Basics
-### * The [Web User Interface](WEB.md)
-Guide to the Web interface. Also see the [WebUI Hotkeys Reference Guide](WEBUIHOTKEYS.md)
+  - The [Web User Interface](WEB.md)

-### * The [Unified Canvas](UNIFIED_CANVAS.md)
-Build complex scenes by combine and modifying multiple images in a stepwise
-fashion. This feature combines img2img, inpainting and outpainting in
-a single convenient digital artist-optimized user interface.
+    Guide to the Web interface. Also see the
+    [WebUI Hotkeys Reference Guide](WEBUIHOTKEYS.md)

-### * The [Command Line Interface (CLI)](CLI.md)
-Scriptable access to InvokeAI's features.
+  - The [Unified Canvas](UNIFIED_CANVAS.md)

-## Image Generation
-### * [Prompt Engineering](PROMPTS.md)
-Get the images you want with the InvokeAI prompt engineering language.
+    Build complex scenes by combining and modifying multiple images in a
+    stepwise fashion. This feature combines img2img, inpainting and
+    outpainting in a single convenient digital artist-optimized user
+    interface.

-## * [Post-Processing](POSTPROCESS.md)
-Restore mangled faces and make images larger with upscaling. Also see the [Embiggen Upscaling Guide](EMBIGGEN.md).
+  - The [Command Line Interface (CLI)](CLI.md)

-## * The [Concepts Library](CONCEPTS.md)
-Add custom subjects and styles using HuggingFace's repository of embeddings.
+    Scriptable access to InvokeAI's features.

-### * [Image-to-Image Guide for the CLI](IMG2IMG.md)
-Use a seed image to build new creations in the CLI.
+- Image Generation

-### * [Inpainting Guide for the CLI](INPAINTING.md)
-Selectively erase and replace portions of an existing image in the CLI.
+  - [Prompt Engineering](PROMPTS.md)

-### * [Outpainting Guide for the CLI](OUTPAINTING.md)
-Extend the borders of the image with an "outcrop" function within the CLI.
+    Get the images you want with the InvokeAI prompt engineering language.

-### * [Generating Variations](VARIATIONS.md)
-Have an image you like and want to generate many more like it? Variations
-are the ticket.
+  - [Post-Processing](POSTPROCESS.md)

-## Model Management
+    Restore mangled faces and make images larger with upscaling. Also see
+    the [Embiggen Upscaling Guide](EMBIGGEN.md).

-## * [Model Installation](../installation/050_INSTALLING_MODELS.md)
-Learn how to import third-party models and switch among them. This
-guide also covers optimizing models to load quickly.
+  - The [Concepts Library](CONCEPTS.md)

-## * [Merging Models](MODEL_MERGING.md)
-Teach an old model new tricks. Merge 2-3 models together to create a
-new model that combines characteristics of the originals.
+    Add custom subjects and styles using HuggingFace's repository of
+    embeddings.

-## * [Textual Inversion](TEXTUAL_INVERSION.md)
-Personalize models by adding your own style or subjects.
+  - [Image-to-Image Guide for the CLI](IMG2IMG.md)

-# Other Features
+    Use a seed image to build new creations in the CLI.
-## * [The NSFW Checker](NSFW.md) -Prevent InvokeAI from displaying unwanted racy images. + - [Inpainting Guide for the CLI](INPAINTING.md) -## * [Miscellaneous](OTHER.md) -Run InvokeAI on Google Colab, generate images with repeating patterns, -batch process a file of prompts, increase the "creativity" of image -generation by adding initial noise, and more! + Selectively erase and replace portions of an existing image in the CLI. + + - [Outpainting Guide for the CLI](OUTPAINTING.md) + + Extend the borders of the image with an "outcrop" function within the + CLI. + + - [Generating Variations](VARIATIONS.md) + + Have an image you like and want to generate many more like it? + Variations are the ticket. + +- Model Management + + - [Model Installation](../installation/050_INSTALLING_MODELS.md) + + Learn how to import third-party models and switch among them. This guide + also covers optimizing models to load quickly. + + - [Merging Models](MODEL_MERGING.md) + + Teach an old model new tricks. Merge 2-3 models together to create a new + model that combines characteristics of the originals. + + - [Textual Inversion](TEXTUAL_INVERSION.md) + + Personalize models by adding your own style or subjects. + +- Other Features + + - [The NSFW Checker](NSFW.md) + + Prevent InvokeAI from displaying unwanted racy images. + + - [Miscellaneous](OTHER.md) + + Run InvokeAI on Google Colab, generate images with repeating patterns, + batch process a file of prompts, increase the "creativity" of image + generation by adding initial noise, and more! From f5aadbc200d8814c94486ee9ed15924c3888ab62 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 16:14:24 +0100 Subject: [PATCH 39/57] rename `docs/help/contributing`` - update vs-code.md - update 30_DOCS.md --- docs/help/IDE-Settings/vs-code.md | 128 +++++++++++++++--- .../010_PULL_REQUEST.md | 0 .../020_ISSUES.md | 0 .../{contribute => contributing}/030_DOCS.md | 8 ++ .../{contribute => contributing}/index.md | 0 5 files changed, 118 insertions(+), 18 deletions(-) rename docs/help/{contribute => contributing}/010_PULL_REQUEST.md (100%) rename docs/help/{contribute => contributing}/020_ISSUES.md (100%) rename docs/help/{contribute => contributing}/030_DOCS.md (70%) rename docs/help/{contribute => contributing}/index.md (100%) diff --git a/docs/help/IDE-Settings/vs-code.md b/docs/help/IDE-Settings/vs-code.md index a2ae9b0dd9..ea952b192e 100644 --- a/docs/help/IDE-Settings/vs-code.md +++ b/docs/help/IDE-Settings/vs-code.md @@ -10,34 +10,111 @@ higher priorized than your user settings. This helps to have different settings for different projects, while the user settings get used as a default value if no workspace settings are provided. -## launch.json +## tasks.json -It is asumed that you have created a virtual environment as `.venv`: +First we will create a task configuration which will create a virtual +environment and update the deps (pip, setuptools and wheel). -```sh -python -m venv .venv --prompt="InvokeAI" --upgrade-deps +Into this venv we will then install the pyproject.toml in editable mode with +dev, docs and test dependencies. 
+ +```json +{ + // See https://go.microsoft.com/fwlink/?LinkId=733558 + // for the documentation about the tasks.json format + "version": "2.0.0", + "tasks": [ + { + "label": "Create virtual environment", + "detail": "Create .venv and upgrade pip, setuptools and wheel", + "command": "python3", + "args": [ + "-m", + "venv", + ".venv", + "--prompt", + "InvokeAI", + "--upgrade-deps" + ], + "runOptions": { + "instanceLimit": 1, + "reevaluateOnRerun": true + }, + "group": { + "kind": "build" + }, + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "shared", + "showReuseMessage": true, + "clear": false + } + }, + { + "label": "build InvokeAI", + "detail": "Build pyproject.toml with extras dev, docs and test", + "command": "${workspaceFolder}/.venv/bin/python", + "args": ["-m", "pip", "install", "-e", ".[dev,docs,test]"], + "dependsOn": "Create venv", + "dependsOrder": "sequence", + "group": { + "kind": "build", + "isDefault": true + }, + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "shared", + "showReuseMessage": true, + "clear": false + } + } + ] +} ``` -This is the most simplified version of launching `invokeai --web` with the -debugger attached: +## launch.json + +This file is used to define debugger configurations, so that you can one-click +launch and monitor the application, set halt points to inspect specific states, +... ```json title=".vscode/launch.json" { "version": "0.2.0", "configurations": [ { - "name": "invokeai --web", + "name": "invokeai web", "type": "python", "request": "launch", "program": ".venv/bin/invokeai", - "args": ["--web"], + "justMyCode": true + }, + { + "name": "invokeai cli", + "type": "python", + "request": "launch", + "program": ".venv/bin/invokeai", + "justMyCode": true + }, + { + "name": "mkdocs serve", + "type": "python", + "request": "launch", + "program": ".venv/bin/mkdocs", + "args": ["serve"], "justMyCode": true } ] } ``` -Then you only need to hit ++F5++ and the fun begins :nerd: +Then you only need to hit ++f5++ and the fun begins :nerd: (It is asumed that +you have created a virtual environment via the [tasks](#tasksjson) from the +previous step.) ## extensions.json @@ -70,21 +147,29 @@ A list of recommended vscode-extensions to make your life easier: ## settings.json -With those settings your files already get formated when you save them, which -will help you to not run into trouble with the pre-commit hooks, which will -prevent you from commiting if the formaters are failing +With bellow settings your files already get formated when you save them (only +your modifications if available), which will help you to not run into trouble +with the pre-commit hooks. 
If the hooks fail, they will prevent you from
committing, but most hooks directly add a fixed version, so that you just need
to stage and commit them:

```json title=".vscode/settings.json"
{
  "[json]": {
    "editor.defaultFormatter": "esbenp.prettier-vscode",
    "editor.quickSuggestions": {
      "comments": false,
      "strings": true,
      "other": true
    },
    "editor.suggest.insertMode": "replace",
    "gitlens.codeLens.scopes": ["document"]
  },
  "[jsonc]": {
    "editor.defaultFormatter": "esbenp.prettier-vscode",
    "editor.formatOnSave": true,
    "editor.formatOnSaveMode": "modificationsIfAvailable"
  },
  "[python]": {
    "editor.defaultFormatter": "ms-python.black-formatter",
    "editor.formatOnSave": true,
    "editor.formatOnSaveMode": "file"
  },
  "[toml]": {
    "editor.defaultFormatter": "tamasfe.even-better-toml",
    "editor.formatOnSave": true,
    "editor.formatOnSaveMode": "modificationsIfAvailable"
  },
  "[yaml]": {
    "editor.defaultFormatter": "esbenp.prettier-vscode",
    "editor.formatOnSave": true,
    "editor.formatOnSaveMode": "modificationsIfAvailable"
  },
  "[markdown]": {
    "editor.defaultFormatter": "esbenp.prettier-vscode",
    "editor.rulers": [80],
    "editor.unicodeHighlight.ambiguousCharacters": false,
    "editor.unicodeHighlight.invisibleCharacters": false,
    "diffEditor.ignoreTrimWhitespace": false,
    "editor.wordWrap": "on",
    "editor.quickSuggestions": {
      "comments": "off",
      "strings": "off",
      "other": "off"
    },
    "editor.formatOnSave": true,
    "editor.formatOnSaveMode": "modificationsIfAvailable"
  },
  "[shellscript]": {
    "editor.defaultFormatter": "foxundermoon.shell-format"
  },
  "[ignore]": {
    "editor.defaultFormatter": "foxundermoon.shell-format"
  },
  "editor.rulers": [88],
  "evenBetterToml.formatter.alignEntries": false,
  "evenBetterToml.formatter.allowedBlankLines": 1,
  "evenBetterToml.formatter.arrayAutoExpand": true,
  "evenBetterToml.formatter.arrayTrailingComma": true,
  "evenBetterToml.formatter.arrayAutoCollapse": true,
  "evenBetterToml.formatter.columnWidth": 88,
  "evenBetterToml.formatter.compactArrays": true,
  "evenBetterToml.formatter.compactInlineTables": true,
  "evenBetterToml.formatter.indentEntries": false,
  "evenBetterToml.formatter.inlineTableExpand": true,
  "evenBetterToml.formatter.reorderArrays": true,
  "evenBetterToml.formatter.reorderKeys": true,
  "evenBetterToml.formatter.compactEntries": false,
  "evenBetterToml.schema.enabled": true,
  "python.analysis.typeCheckingMode": "basic",
  "python.formatting.provider": "black",
  "python.languageServer": "Pylance",
  "python.linting.enabled": true,
  "python.linting.flake8Enabled": true,
  "python.testing.unittestEnabled": false,
  "python.testing.pytestEnabled": true,
  "python.testing.pytestArgs": [
    "tests",
    "--cov=ldm",
    "--cov-branch",
    "--cov-report=term:skip-covered"
  ],
  "yaml.schemas": {
    "https://json.schemastore.org/prettierrc.json": "${workspaceFolder}/.prettierrc.yaml"
  }
}
```
diff --git a/docs/help/contribute/010_PULL_REQUEST.md b/docs/help/contributing/010_PULL_REQUEST.md
similarity index 100%
rename from docs/help/contribute/010_PULL_REQUEST.md
rename to docs/help/contributing/010_PULL_REQUEST.md
diff --git a/docs/help/contribute/020_ISSUES.md b/docs/help/contributing/020_ISSUES.md
similarity index 100%
rename from docs/help/contribute/020_ISSUES.md
rename to docs/help/contributing/020_ISSUES.md
diff --git a/docs/help/contribute/030_DOCS.md b/docs/help/contributing/030_DOCS.md
similarity index 70%
rename from docs/help/contribute/030_DOCS.md
rename to docs/help/contributing/030_DOCS.md
index 12b9c428e7..f4ebfb9df4 100644
--- a/docs/help/contribute/030_DOCS.md
+++ b/docs/help/contributing/030_DOCS.md
@@ -22,3 +22,11 @@ browser, without the need to restart `mkdocs serve`.

 More information about the "mkdocs flavored markdown syntax" can be found
 [here](https://squidfunk.github.io/mkdocs-material/reference/).
+
+## :material-microsoft-visual-studio-code:VS-Code
+
+We also provide a
+[launch configuration for VS-Code](../IDE-Settings/vs-code.md#launchjson) which
+includes a `mkdocs serve` entrypoint as well. You also don't have to worry about
+the formatting since this is automated via prettier, but this is of course not
+limited to VS-Code.
diff --git a/docs/help/contribute/index.md b/docs/help/contributing/index.md similarity index 100% rename from docs/help/contribute/index.md rename to docs/help/contributing/index.md From 317165c410a19a92928888a853e6999027413b07 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 17:10:32 +0100 Subject: [PATCH 40/57] remove previous attempt for contributing docs --- .../010_HOW_TO_CONTRIBUTE.md | 139 ------------------ 1 file changed, 139 deletions(-) delete mode 100644 docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md diff --git a/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md b/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md deleted file mode 100644 index d83833d063..0000000000 --- a/docs/installation/Developers_documentation/010_HOW_TO_CONTRIBUTE.md +++ /dev/null @@ -1,139 +0,0 @@ ---- -title: How to Contribute ---- - -There are different ways how you can contribute to -[InvokeAI](https://github.com/invoke-ai/InvokeAI), like Translations, opening -Issues for Bugs or ideas how to improve. - -## Pull Requests - -### pre-requirements - -To follow the steps in this tutorial you will need: - -- [GitHub](https://github.com) account -- [git](https://git-scm.com/downloads) source controll -- Text / Code Editor (personally I preffer - [Visual Studio Code](https://code.visualstudio.com/Download)) -- Terminal: - - If you are on Linux/MacOS you can use bash or zsh - - for Windows Users the commands are written for PowerShell - -### Fork Repository - -The first step to be done if you want to contribute to InvokeAI, is to fork the -rpeository. - -Since you are already reading this doc, the easiest way to do so is by clicking -[here](https://github.com/invoke-ai/InvokeAI/fork). You could also open -[InvokeAI](https://github.com/invoke-ai/InvoekAI) and click on the "Fork" Button -in the top right. - -### Clone your fork - -After you forked the Repository, you should clone it to your dev machine: - -=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:" - - ``` sh - git clone https://github.com//InvokeAI \ - && cd InvokeAI - ``` - -=== "Windows:fontawesome-brands-windows:" - - ``` powershell - git clone https://github.com//InvokeAI ` - && cd InvokeAI - ``` - -### Install in Editable Mode - -To install InvokeAI in editable mode, (as always) we recommend to create and -activate a venv first. 
Afterwards you can install the InvokeAI Package, -including dev and docs extras in editable mode, follwed by the installation of -the pre-commit hook: - -=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:" - - ``` sh - python -m venv .venv \ - --prompt InvokeAI \ - --upgrade-deps \ - && source .venv/bin/activate \ - && pip install \ - --upgrade-deps \ - --use-pep517 \ - --editable=".[dev,docs]" \ - && pre-commit install - ``` - -=== "Windows:fontawesome-brands-windows:" - - ``` powershell - python -m venv .venv ` - --prompt InvokeAI ` - --upgrade-deps ` - && .venv/scripts/activate.ps1 ` - && pip install ` - --upgrade ` - --use-pep517 ` - --editable=".[dev,docs]" ` - && pre-commit install - ``` - -### Create a branch - -Make sure you are on main branch, from there create your feature branch: - -=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:" - - ``` sh - git checkout main \ - && git pull \ - && git checkout -B - ``` - -=== "Windows:fontawesome-brands-windows:" - - ``` powershell - git checkout main ` - && git pull ` - && git checkout -B - ``` - -### Commit your changes - -When you are done with adding / updating content, you need to commit those -changes to your repository before you can actually open an PR: - -```{ .sh .annotate } -git add # (1)! -git commit -m "A commit message which describes your change" -git push -``` - -1. Replace this with a space seperated list of the files you changed, like: - `README.md foo.sh bar.json baz` - -### Create a Pull Request - -After pushing your changes, you are ready to create a Pull Request. just head -over to your fork on [GitHub](https://github.com), which should already show you -a message that there have been recent changes on your feature branch and a green -button which you could use to create the PR. - -The default target for your PRs would be the main branch of -[invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI) - -Another way would be to create it in VS-Code or via the GitHub CLI (or even via -the GitHub CLI in a VS-Code Terminal Window 🤭): - -```sh -gh pr create -``` - -The CLI will inform you if there are still unpushed commits on your branch. It -will also prompt you for things like the the Title and the Body (Description) if -you did not already pass them as arguments. 
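
The removed page was an earlier attempt at the contribution guide; the same
workflow now lives in `docs/help/contributing/010_PULL_REQUEST.md`. Condensed
into one shell sketch (the branch name, file list and commit message are
placeholders):

```sh
# start a feature branch from an up-to-date main
git checkout main
git pull
git checkout -B my-feature-branch

# stage, commit and push your changes
git add path/to/changed-files
git commit -m "A commit message which describes your change"
git push

# open the pull request against invoke-ai/InvokeAI (requires the GitHub CLI)
gh pr create
```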
From f4940770035e5ff86bf24487563d9748db1f3cd2 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 17:45:05 +0100
Subject: [PATCH 41/57] enable `content.code.copy`

- to get a handy copy button in code blocks
- also sort the features alphabetically
---
 mkdocs.yml | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mkdocs.yml b/mkdocs.yml
index f9acbbc41a..0e9bf5687a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -34,11 +34,12 @@ theme:
         icon: material/lightbulb-outline
         name: Switch to light mode
   features:
-    - content.tabs.link
     - content.action.edit
     - content.action.view
-    - navigation.instant
+    - content.code.copy
+    - content.tabs.link
     - navigation.indexes
+    - navigation.instant
     - navigation.tabs
     - navigation.top
     - navigation.tracking

From 51956ba356caf75a79b07ed5b23e072926a5dc33 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 19 Feb 2023 17:50:22 +0100
Subject: [PATCH 42/57] update `vs-code.md`, fix `docs/help/index.md`

---
 docs/help/IDE-Settings/vs-code.md | 17 +++++++++++++----
 docs/help/index.md                |  2 +-
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/docs/help/IDE-Settings/vs-code.md b/docs/help/IDE-Settings/vs-code.md
index ea952b192e..25be691d0e 100644
--- a/docs/help/IDE-Settings/vs-code.md
+++ b/docs/help/IDE-Settings/vs-code.md
@@ -18,7 +18,7 @@ environment and update the deps (pip, setuptools and wheel).
 Into this venv we will then install the pyproject.toml in editable mode with
 dev, docs and test dependencies.

-```json
+```json title=".vscode/tasks.json"
 {
   // See https://go.microsoft.com/fwlink/?LinkId=733558
   // for the documentation about the tasks.json format
@@ -55,9 +55,16 @@ dev, docs and test dependencies.
     {
       "label": "build InvokeAI",
       "detail": "Build pyproject.toml with extras dev, docs and test",
-      "command": "${workspaceFolder}/.venv/bin/python",
-      "args": ["-m", "pip", "install", "-e", ".[dev,docs,test]"],
-      "dependsOn": "Create venv",
+      "command": "${workspaceFolder}/.venv/bin/python3",
+      "args": [
+        "-m",
+        "pip",
+        "install",
+        "--use-pep517",
+        "--editable",
+        ".[dev,docs,test]"
+      ],
+      "dependsOn": "Create virtual environment",
       "dependsOrder": "sequence",
       "group": {
         "kind": "build",
@@ -76,6 +83,8 @@ dev, docs and test dependencies.
 }
 ```

+The fastest way to build InvokeAI now is ++cmd+shift+b++
+
 ## launch.json

 This file is used to define debugger configurations, so that you can one-click
diff --git a/docs/help/index.md b/docs/help/index.md
index b4236e8244..fa56264486 100644
--- a/docs/help/index.md
+++ b/docs/help/index.md
@@ -9,4 +9,4 @@ Here you will find help on topics like
 - how to contribute
 - configuration recommendation for IDEs

 If you have an idea about what's missing and aren't scared of contributing,
-just take a look at [DOCS](./contribute/030_DOCS.md) to find out how to do so.
+just take a look at [DOCS](./contributing/030_DOCS.md) to find out how to do so.
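
For readers who skip the VS Code tasks, the two build tasks above boil down to
two shell commands; this sketch assumes a POSIX shell and simply mirrors the
task arguments:

```sh
# the "Create virtual environment" task:
# a venv with upgraded pip, setuptools and wheel
python3 -m venv .venv --prompt InvokeAI --upgrade-deps

# the "build InvokeAI" task:
# PEP 517 editable install with the dev, docs and test extras
.venv/bin/python3 -m pip install --use-pep517 --editable ".[dev,docs,test]"
```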
From b731b55de4591e02211337c4b146242dfa160609 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 21:17:09 +0100 Subject: [PATCH 43/57] update title in `docs/help/contributing/index.md` --- docs/help/contributing/index.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/help/contributing/index.md b/docs/help/contributing/index.md index 9a1e3691fb..9e33003ef2 100644 --- a/docs/help/contributing/index.md +++ b/docs/help/contributing/index.md @@ -1,7 +1,9 @@ --- -title: Contribute +title: Contributing --- +# :fontawesome-solid-code-commit: Contributing + There are different ways how you can contribute to [InvokeAI](https://github.com/invoke-ai/InvokeAI), like Translations, opening Issues for Bugs or ideas how to improve. From 7ef63161ba21b25c01e3a998b344127e143a2bbf Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 21:45:44 +0100 Subject: [PATCH 44/57] add icons to some docs - this also reformated `docs/index.md` --- docs/help/contributing/010_PULL_REQUEST.md | 4 +- docs/help/contributing/020_ISSUES.md | 2 + docs/index.md | 295 +++++++++++++++------ 3 files changed, 214 insertions(+), 87 deletions(-) diff --git a/docs/help/contributing/010_PULL_REQUEST.md b/docs/help/contributing/010_PULL_REQUEST.md index 129df49349..511a1fcc63 100644 --- a/docs/help/contributing/010_PULL_REQUEST.md +++ b/docs/help/contributing/010_PULL_REQUEST.md @@ -1,7 +1,9 @@ --- -title: Pull Requests +title: Pull-Request --- +# :octicons-git-pull-request-16: Pull-Request + ## pre-requirements To follow the steps in this tutorial you will need: diff --git a/docs/help/contributing/020_ISSUES.md b/docs/help/contributing/020_ISSUES.md index f40d65c8d6..576af82188 100644 --- a/docs/help/contributing/020_ISSUES.md +++ b/docs/help/contributing/020_ISSUES.md @@ -2,6 +2,8 @@ title: Issues --- +# :octicons-issue-opened-16: Issues + ## :fontawesome-solid-bug: Report a bug If you stumbled over a bug while using InvokeAI, we would apreciate it a lot if diff --git a/docs/index.md b/docs/index.md index 4587b08f18..ab89434c55 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,6 +2,8 @@ title: Home --- +# :octicons-home-16: Home + + ### The InvokeAI Command Line Interface -- [Command Line Interace Reference Guide](features/CLI.md) + +- [Command Line Interace Reference Guide](features/CLI.md) + ### Image Management -- [Image2Image](features/IMG2IMG.md) -- [Inpainting](features/INPAINTING.md) -- [Outpainting](features/OUTPAINTING.md) -- [Adding custom styles and subjects](features/CONCEPTS.md) -- [Upscaling and Face Reconstruction](features/POSTPROCESS.md) -- [Embiggen upscaling](features/EMBIGGEN.md) -- [Other Features](features/OTHER.md) + +- [Image2Image](features/IMG2IMG.md) +- [Inpainting](features/INPAINTING.md) +- [Outpainting](features/OUTPAINTING.md) +- [Adding custom styles and subjects](features/CONCEPTS.md) +- [Upscaling and Face Reconstruction](features/POSTPROCESS.md) +- [Embiggen upscaling](features/EMBIGGEN.md) +- [Other Features](features/OTHER.md) + ### Model Management -- [Installing](installation/050_INSTALLING_MODELS.md) -- [Model Merging](features/MODEL_MERGING.md) -- [Style/Subject Concepts and Embeddings](features/CONCEPTS.md) -- [Textual Inversion](features/TEXTUAL_INVERSION.md) -- [Not Safe for Work (NSFW) Checker](features/NSFW.md) + +- [Installing](installation/050_INSTALLING_MODELS.md) +- [Model Merging](features/MODEL_MERGING.md) +- [Style/Subject Concepts and Embeddings](features/CONCEPTS.md) +- [Textual Inversion](features/TEXTUAL_INVERSION.md) +- [Not Safe for 
Work (NSFW) Checker](features/NSFW.md) + ### Prompt Engineering -- [Prompt Syntax](features/PROMPTS.md) -- [Generating Variations](features/VARIATIONS.md) + +- [Prompt Syntax](features/PROMPTS.md) +- [Generating Variations](features/VARIATIONS.md) ## :octicons-log-16: Latest Changes @@ -162,84 +181,188 @@ This method is recommended for those familiar with running Docker containers #### Migration to Stable Diffusion `diffusers` models -Previous versions of InvokeAI supported the original model file format introduced with Stable Diffusion 1.4. In the original format, known variously as "checkpoint", or "legacy" format, there is a single large weights file ending with `.ckpt` or `.safetensors`. Though this format has served the community well, it has a number of disadvantages, including file size, slow loading times, and a variety of non-standard variants that require special-case code to handle. In addition, because checkpoint files are actually a bundle of multiple machine learning sub-models, it is hard to swap different sub-models in and out, or to share common sub-models. A new format, introduced by the StabilityAI company in collaboration with HuggingFace, is called `diffusers` and consists of a directory of individual models. The most immediate benefit of `diffusers` is that they load from disk very quickly. A longer term benefit is that in the near future `diffusers` models will be able to share common sub-models, dramatically reducing disk space when you have multiple fine-tune models derived from the same base. +Previous versions of InvokeAI supported the original model file format +introduced with Stable Diffusion 1.4. In the original format, known variously as +"checkpoint", or "legacy" format, there is a single large weights file ending +with `.ckpt` or `.safetensors`. Though this format has served the community +well, it has a number of disadvantages, including file size, slow loading times, +and a variety of non-standard variants that require special-case code to handle. +In addition, because checkpoint files are actually a bundle of multiple machine +learning sub-models, it is hard to swap different sub-models in and out, or to +share common sub-models. A new format, introduced by the StabilityAI company in +collaboration with HuggingFace, is called `diffusers` and consists of a +directory of individual models. The most immediate benefit of `diffusers` is +that they load from disk very quickly. A longer term benefit is that in the near +future `diffusers` models will be able to share common sub-models, dramatically +reducing disk space when you have multiple fine-tune models derived from the +same base. -When you perform a new install of version 2.3.0, you will be offered the option to install the `diffusers` versions of a number of popular SD models, including Stable Diffusion versions 1.5 and 2.1 (including the 768x768 pixel version of 2.1). These will act and work just like the checkpoint versions. Do not be concerned if you already have a lot of ".ckpt" or ".safetensors" models on disk! InvokeAI 2.3.0 can still load these and generate images from them without any extra intervention on your part. +When you perform a new install of version 2.3.0, you will be offered the option +to install the `diffusers` versions of a number of popular SD models, including +Stable Diffusion versions 1.5 and 2.1 (including the 768x768 pixel version of +2.1). These will act and work just like the checkpoint versions. 
Do not be +concerned if you already have a lot of ".ckpt" or ".safetensors" models on disk! +InvokeAI 2.3.0 can still load these and generate images from them without any +extra intervention on your part. -To take advantage of the optimized loading times of `diffusers` models, InvokeAI offers options to convert legacy checkpoint models into optimized `diffusers` models. If you use the `invokeai` command line interface, the relevant commands are: +To take advantage of the optimized loading times of `diffusers` models, InvokeAI +offers options to convert legacy checkpoint models into optimized `diffusers` +models. If you use the `invokeai` command line interface, the relevant commands +are: -* `!convert_model` -- Take the path to a local checkpoint file or a URL that is pointing to one, convert it into a `diffusers` model, and import it into InvokeAI's models registry file. -* `!optimize_model` -- If you already have a checkpoint model in your InvokeAI models file, this command will accept its short name and convert it into a like-named `diffusers` model, optionally deleting the original checkpoint file. -* `!import_model` -- Take the local path of either a checkpoint file or a `diffusers` model directory and import it into InvokeAI's registry file. You may also provide the ID of any diffusers model that has been published on the [HuggingFace models repository](https://huggingface.co/models?pipeline_tag=text-to-image&sort=downloads) and it will be downloaded and installed automatically. +- `!convert_model` -- Take the path to a local checkpoint file or a URL that + is pointing to one, convert it into a `diffusers` model, and import it into + InvokeAI's models registry file. +- `!optimize_model` -- If you already have a checkpoint model in your InvokeAI + models file, this command will accept its short name and convert it into a + like-named `diffusers` model, optionally deleting the original checkpoint + file. +- `!import_model` -- Take the local path of either a checkpoint file or a + `diffusers` model directory and import it into InvokeAI's registry file. You + may also provide the ID of any diffusers model that has been published on + the + [HuggingFace models repository](https://huggingface.co/models?pipeline_tag=text-to-image&sort=downloads) + and it will be downloaded and installed automatically. The WebGUI offers similar functionality for model management. -For advanced users, new command-line options provide additional functionality. Launching `invokeai` with the argument `--autoconvert ` takes the path to a directory of checkpoint files, automatically converts them into `diffusers` models and imports them. Each time the script is launched, the directory will be scanned for new checkpoint files to be loaded. Alternatively, the `--ckpt_convert` argument will cause any checkpoint or safetensors model that is already registered with InvokeAI to be converted into a `diffusers` model on the fly, allowing you to take advantage of future diffusers-only features without explicitly converting the model and saving it to disk. +For advanced users, new command-line options provide additional functionality. +Launching `invokeai` with the argument `--autoconvert ` takes +the path to a directory of checkpoint files, automatically converts them into +`diffusers` models and imports them. Each time the script is launched, the +directory will be scanned for new checkpoint files to be loaded. 
Alternatively, +the `--ckpt_convert` argument will cause any checkpoint or safetensors model +that is already registered with InvokeAI to be converted into a `diffusers` +model on the fly, allowing you to take advantage of future diffusers-only +features without explicitly converting the model and saving it to disk. -Please see [INSTALLING MODELS](https://invoke-ai.github.io/InvokeAI/installation/050_INSTALLING_MODELS/) for more information on model management in both the command-line and Web interfaces. +Please see +[INSTALLING MODELS](https://invoke-ai.github.io/InvokeAI/installation/050_INSTALLING_MODELS/) +for more information on model management in both the command-line and Web +interfaces. #### Support for the `XFormers` Memory-Efficient Crossattention Package -On CUDA (Nvidia) systems, version 2.3.0 supports the `XFormers` library. Once installed, the`xformers` package dramatically reduces the memory footprint of loaded Stable Diffusion models files and modestly increases image generation speed. `xformers` will be installed and activated automatically if you specify a CUDA system at install time. +On CUDA (Nvidia) systems, version 2.3.0 supports the `XFormers` library. Once +installed, the`xformers` package dramatically reduces the memory footprint of +loaded Stable Diffusion models files and modestly increases image generation +speed. `xformers` will be installed and activated automatically if you specify a +CUDA system at install time. -The caveat with using `xformers` is that it introduces slightly non-deterministic behavior, and images generated using the same seed and other settings will be subtly different between invocations. Generally the changes are unnoticeable unless you rapidly shift back and forth between images, but to disable `xformers` and restore fully deterministic behavior, you may launch InvokeAI using the `--no-xformers` option. This is most conveniently done by opening the file `invokeai/invokeai.init` with a text editor, and adding the line `--no-xformers` at the bottom. +The caveat with using `xformers` is that it introduces slightly +non-deterministic behavior, and images generated using the same seed and other +settings will be subtly different between invocations. Generally the changes are +unnoticeable unless you rapidly shift back and forth between images, but to +disable `xformers` and restore fully deterministic behavior, you may launch +InvokeAI using the `--no-xformers` option. This is most conveniently done by +opening the file `invokeai/invokeai.init` with a text editor, and adding the +line `--no-xformers` at the bottom. #### A Negative Prompt Box in the WebUI -There is now a separate text input box for negative prompts in the WebUI. This is convenient for stashing frequently-used negative prompts ("mangled limbs, bad anatomy"). The `[negative prompt]` syntax continues to work in the main prompt box as well. +There is now a separate text input box for negative prompts in the WebUI. This +is convenient for stashing frequently-used negative prompts ("mangled limbs, bad +anatomy"). The `[negative prompt]` syntax continues to work in the main prompt +box as well. -To see exactly how your prompts are being parsed, launch `invokeai` with the `--log_tokenization` option. The console window will then display the tokenization process for both positive and negative prompts. +To see exactly how your prompts are being parsed, launch `invokeai` with the +`--log_tokenization` option. 
The console window will then display the +tokenization process for both positive and negative prompts. #### Model Merging -Version 2.3.0 offers an intuitive user interface for merging up to three Stable Diffusion models using an intuitive user interface. Model merging allows you to mix the behavior of models to achieve very interesting effects. To use this, each of the models must already be imported into InvokeAI and saved in `diffusers` format, then launch the merger using a new menu item in the InvokeAI launcher script (`invoke.sh`, `invoke.bat`) or directly from the command line with `invokeai-merge --gui`. You will be prompted to select the models to merge, the proportions in which to mix them, and the mixing algorithm. The script will create a new merged `diffusers` model and import it into InvokeAI for your use. +Version 2.3.0 offers an intuitive user interface for merging up to three Stable +Diffusion models using an intuitive user interface. Model merging allows you to +mix the behavior of models to achieve very interesting effects. To use this, +each of the models must already be imported into InvokeAI and saved in +`diffusers` format, then launch the merger using a new menu item in the InvokeAI +launcher script (`invoke.sh`, `invoke.bat`) or directly from the command line +with `invokeai-merge --gui`. You will be prompted to select the models to merge, +the proportions in which to mix them, and the mixing algorithm. The script will +create a new merged `diffusers` model and import it into InvokeAI for your use. -See [MODEL MERGING](https://invoke-ai.github.io/InvokeAI/features/MODEL_MERGING/) for more details. +See +[MODEL MERGING](https://invoke-ai.github.io/InvokeAI/features/MODEL_MERGING/) +for more details. #### Textual Inversion Training -Textual Inversion (TI) is a technique for training a Stable Diffusion model to emit a particular subject or style when triggered by a keyword phrase. You can perform TI training by placing a small number of images of the subject or style in a directory, and choosing a distinctive trigger phrase, such as "pointillist-style". After successful training, The subject or style will be activated by including `` in your prompt. +Textual Inversion (TI) is a technique for training a Stable Diffusion model to +emit a particular subject or style when triggered by a keyword phrase. You can +perform TI training by placing a small number of images of the subject or style +in a directory, and choosing a distinctive trigger phrase, such as +"pointillist-style". After successful training, The subject or style will be +activated by including `` in your prompt. -Previous versions of InvokeAI were able to perform TI, but it required using a command-line script with dozens of obscure command-line arguments. Version 2.3.0 features an intuitive TI frontend that will build a TI model on top of any `diffusers` model. To access training you can launch from a new item in the launcher script or from the command line using `invokeai-ti --gui`. +Previous versions of InvokeAI were able to perform TI, but it required using a +command-line script with dozens of obscure command-line arguments. Version 2.3.0 +features an intuitive TI frontend that will build a TI model on top of any +`diffusers` model. To access training you can launch from a new item in the +launcher script or from the command line using `invokeai-ti --gui`. -See [TEXTUAL INVERSION](https://invoke-ai.github.io/InvokeAI/features/TEXTUAL_INVERSION/) for further details. 
+See +[TEXTUAL INVERSION](https://invoke-ai.github.io/InvokeAI/features/TEXTUAL_INVERSION/) +for further details. #### A New Installer Experience -The InvokeAI installer has been upgraded in order to provide a smoother and hopefully more glitch-free experience. In addition, InvokeAI is now packaged as a PyPi project, allowing developers and power-users to install InvokeAI with the command `pip install InvokeAI --use-pep517`. Please see [Installation](#installation) for details. +The InvokeAI installer has been upgraded in order to provide a smoother and +hopefully more glitch-free experience. In addition, InvokeAI is now packaged as +a PyPi project, allowing developers and power-users to install InvokeAI with the +command `pip install InvokeAI --use-pep517`. Please see +[Installation](#installation) for details. -Developers should be aware that the `pip` installation procedure has been simplified and that the `conda` method is no longer supported at all. Accordingly, the `environments_and_requirements` directory has been deleted from the repository. +Developers should be aware that the `pip` installation procedure has been +simplified and that the `conda` method is no longer supported at all. +Accordingly, the `environments_and_requirements` directory has been deleted from +the repository. #### Command-line name changes -All of InvokeAI's functionality, including the WebUI, command-line interface, textual inversion training and model merging, can all be accessed from the `invoke.sh` and `invoke.bat` launcher scripts. The menu of options has been expanded to add the new functionality. For the convenience of developers and power users, we have normalized the names of the InvokeAI command-line scripts: +All of InvokeAI's functionality, including the WebUI, command-line interface, +textual inversion training and model merging, can all be accessed from the +`invoke.sh` and `invoke.bat` launcher scripts. The menu of options has been +expanded to add the new functionality. For the convenience of developers and +power users, we have normalized the names of the InvokeAI command-line scripts: -* `invokeai` -- Command-line client -* `invokeai --web` -- Web GUI -* `invokeai-merge --gui` -- Model merging script with graphical front end -* `invokeai-ti --gui` -- Textual inversion script with graphical front end -* `invokeai-configure` -- Configuration tool for initializing the `invokeai` directory and selecting popular starter models. +- `invokeai` -- Command-line client +- `invokeai --web` -- Web GUI +- `invokeai-merge --gui` -- Model merging script with graphical front end +- `invokeai-ti --gui` -- Textual inversion script with graphical front end +- `invokeai-configure` -- Configuration tool for initializing the `invokeai` + directory and selecting popular starter models. -For backward compatibility, the old command names are also recognized, including `invoke.py` and `configure-invokeai.py`. However, these are deprecated and will eventually be removed. +For backward compatibility, the old command names are also recognized, including +`invoke.py` and `configure-invokeai.py`. However, these are deprecated and will +eventually be removed. -Developers should be aware that the locations of the script's source code has been moved. 
The new locations are: -* `invokeai` => `ldm/invoke/CLI.py` -* `invokeai-configure` => `ldm/invoke/config/configure_invokeai.py` -* `invokeai-ti`=> `ldm/invoke/training/textual_inversion.py` -* `invokeai-merge` => `ldm/invoke/merge_diffusers` +Developers should be aware that the locations of the script's source code has +been moved. The new locations are: -Developers are strongly encouraged to perform an "editable" install of InvokeAI using `pip install -e . --use-pep517` in the Git repository, and then to call the scripts using their 2.3.0 names, rather than executing the scripts directly. Developers should also be aware that the several important data files have been relocated into a new directory named `invokeai`. This includes the WebGUI's `frontend` and `backend` directories, and the `INITIAL_MODELS.yaml` files used by the installer to select starter models. Eventually all InvokeAI modules will be in subdirectories of `invokeai`. +- `invokeai` => `ldm/invoke/CLI.py` +- `invokeai-configure` => `ldm/invoke/config/configure_invokeai.py` +- `invokeai-ti`=> `ldm/invoke/training/textual_inversion.py` +- `invokeai-merge` => `ldm/invoke/merge_diffusers` -Please see [2.3.0 Release Notes](https://github.com/invoke-ai/InvokeAI/releases/tag/v2.3.0) for further details. -For older changelogs, please visit the +Developers are strongly encouraged to perform an "editable" install of InvokeAI +using `pip install -e . --use-pep517` in the Git repository, and then to call +the scripts using their 2.3.0 names, rather than executing the scripts directly. +Developers should also be aware that the several important data files have been +relocated into a new directory named `invokeai`. This includes the WebGUI's +`frontend` and `backend` directories, and the `INITIAL_MODELS.yaml` files used +by the installer to select starter models. Eventually all InvokeAI modules will +be in subdirectories of `invokeai`. + +Please see +[2.3.0 Release Notes](https://github.com/invoke-ai/InvokeAI/releases/tag/v2.3.0) +for further details. For older changelogs, please visit the **[CHANGELOG](CHANGELOG/#v223-2-december-2022)**. ## :material-target: Troubleshooting -Please check out our **[:material-frequently-asked-questions: -Troubleshooting -Guide](installation/010_INSTALL_AUTOMATED.md#troubleshooting)** to -get solutions for common installation problems and other issues. +Please check out our +**[:material-frequently-asked-questions: Troubleshooting Guide](installation/010_INSTALL_AUTOMATED.md#troubleshooting)** +to get solutions for common installation problems and other issues. ## :octicons-repo-push-24: Contributing @@ -265,8 +388,8 @@ thank them for their time, hard work and effort. For support, please use this repository's GitHub Issues tracking service. Feel free to send me an email if you use and like the script. -Original portions of the software are Copyright (c) 2022-23 -by [The InvokeAI Team](https://github.com/invoke-ai). +Original portions of the software are Copyright (c) 2022-23 by +[The InvokeAI Team](https://github.com/invoke-ai). 
## :octicons-book-24: Further Reading From 6082aace6ddc683c730b6a1f24f70d5fca9d7c8a Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 19 Feb 2023 22:33:49 +0100 Subject: [PATCH 45/57] update `docs/help/contributing/010_PULL_REQUEST` - prepend brand icons on tabs --- docs/help/contributing/010_PULL_REQUEST.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/help/contributing/010_PULL_REQUEST.md b/docs/help/contributing/010_PULL_REQUEST.md index 511a1fcc63..b6ac8dcab2 100644 --- a/docs/help/contributing/010_PULL_REQUEST.md +++ b/docs/help/contributing/010_PULL_REQUEST.md @@ -30,14 +30,14 @@ in the top right. After you forked the Repository, you should clone it to your dev machine: -=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:" +=== ":fontawesome-brands-linux:Linux / :simple-apple:macOS" ``` sh git clone https://github.com//InvokeAI \ && cd InvokeAI ``` -=== "Windows:fontawesome-brands-windows:" +=== ":fontawesome-brands-windows:Windows" ``` powershell git clone https://github.com//InvokeAI ` @@ -51,7 +51,7 @@ activate a venv first. Afterwards you can install the InvokeAI Package, including dev and docs extras in editable mode, follwed by the installation of the pre-commit hook: -=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:" +=== ":fontawesome-brands-linux:Linux / :simple-apple:macOS" ``` sh python -m venv .venv \ @@ -65,7 +65,7 @@ the pre-commit hook: && pre-commit install ``` -=== "Windows:fontawesome-brands-windows:" +=== ":fontawesome-brands-windows:Windows" ``` powershell python -m venv .venv ` @@ -83,7 +83,7 @@ the pre-commit hook: Make sure you are on main branch, from there create your feature branch: -=== "Linux:fontawesome-brands-linux: / MacOS:simple-apple:" +=== ":fontawesome-brands-linux:Linux / :simple-apple:macOS" ``` sh git checkout main \ @@ -91,7 +91,7 @@ Make sure you are on main branch, from there create your feature branch: && git checkout -B ``` -=== "Windows:fontawesome-brands-windows:" +=== ":fontawesome-brands-windows:Windows" ``` powershell git checkout main ` From fa391c0b7835dded461b23b21ce510f231e7f613 Mon Sep 17 00:00:00 2001 From: mauwii Date: Fri, 24 Feb 2023 00:31:38 +0100 Subject: [PATCH 46/57] fix pyproject.toml - add missing asterisk for backend package - remove old comment --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 793ed62d4f..b544b9eb9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "albumentations", "click", "clip_anytorch", + "compel>=0.1.6", "datasets", "diffusers[torch]~=0.13", "dnspython==2.2.1", @@ -134,7 +135,7 @@ version = {attr = "ldm.invoke.__version__"} [tool.setuptools.packages.find] "include" = [ "invokeai.assets.web", - "invokeai.backend", + "invokeai.backend*", "invokeai.configs*", "invokeai.frontend.dist*", "ldm*", From b0657d5fde36579e7c137b1e76c693549d06480d Mon Sep 17 00:00:00 2001 From: mauwii Date: Fri, 24 Feb 2023 16:13:04 +0100 Subject: [PATCH 47/57] just4fun --- .../contributing/090_NODE_TRANSFORMATION.md | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 docs/help/contributing/090_NODE_TRANSFORMATION.md diff --git a/docs/help/contributing/090_NODE_TRANSFORMATION.md b/docs/help/contributing/090_NODE_TRANSFORMATION.md new file mode 100644 index 0000000000..0a27022d51 --- /dev/null +++ b/docs/help/contributing/090_NODE_TRANSFORMATION.md @@ -0,0 +1,76 @@ +# Tranformation to nodes + +## Current state + +```mermaid +flowchart TD + 
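%% current state: every entry point calls directly into the monolithic generate object
+    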
web[WebUI]; + cli[CLI]; + web -- img2img --> generate; + web -- txt2img --> generate; + cli -- txt2img --> generate; + cli -- img2img --> generate; + generate --> model_manager; + generate --> generators; + generate --> ti_manager[TI Manager]; + generate --> etc; +``` + +## Transitional Architecture + +### first step + +```mermaid +flowchart TD + web[WebUI]; + cli[CLI]; + web -- img2img --> img2img_node[Img2img node]; + web -- txt2img --> generate; + img2img_node --> model_manager; + img2img_node --> generators; + cli -- txt2img --> generate; + cli -- img2img --> generate; + generate --> model_manager; + generate --> generators; + generate --> ti_manager[TI Manager]; + generate --> etc; +``` + +### second step + +```mermaid +flowchart TD + web[WebUI]; + cli[CLI]; + web -- img2img --> img2img_node[img2img node]; + img2img_node --> model_manager; + img2img_node --> generators; + web -- txt2img --> txt2img_node; + cli -- txt2img --> txt2img_node; + cli -- img2img --> generate; + generate --> model_manager; + generate --> generators; + generate --> ti_manager[TI Manager]; + generate --> etc; + txt2img_node --> model_manager; + txt2img_node --> generators; + txt2img_node --> ti_manager[TI Manager]; +``` + +## Final Architecture + +```mermaid +flowchart TD + web[WebUI]; + cli[CLI]; + web -- img2img --> img2img_node[img2img node]; + cli -- img2img --> img2img_node; + web -- txt2img --> txt2img_node; + cli -- txt2img --> txt2img_node; + img2img_node --> model_manager; + txt2img_node --> model_manager; + img2img_node --> generators; + txt2img_node --> generators; + img2img_node --> ti_manager[TI Manager]; + txt2img_node --> ti_manager[TI Manager]; +``` From 71ff75969261714363d090310cd0ffa7dfb045b4 Mon Sep 17 00:00:00 2001 From: mauwii Date: Fri, 24 Feb 2023 16:37:29 +0100 Subject: [PATCH 48/57] minor improvement to mermaid diagrams --- .../contributing/090_NODE_TRANSFORMATION.md | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/help/contributing/090_NODE_TRANSFORMATION.md b/docs/help/contributing/090_NODE_TRANSFORMATION.md index 0a27022d51..10fa89bd04 100644 --- a/docs/help/contributing/090_NODE_TRANSFORMATION.md +++ b/docs/help/contributing/090_NODE_TRANSFORMATION.md @@ -6,10 +6,10 @@ flowchart TD web[WebUI]; cli[CLI]; - web -- img2img --> generate; - web -- txt2img --> generate; - cli -- txt2img --> generate; - cli -- img2img --> generate; + web --> |img2img| generate(generate); + web --> |txt2img| generate(generate); + cli --> |txt2img| generate(generate); + cli --> |img2img| generate(generate); generate --> model_manager; generate --> generators; generate --> ti_manager[TI Manager]; @@ -24,12 +24,12 @@ flowchart TD flowchart TD web[WebUI]; cli[CLI]; - web -- img2img --> img2img_node[Img2img node]; - web -- txt2img --> generate; + web --> |img2img| img2img_node(Img2img node); + web --> |txt2img| generate(generate); img2img_node --> model_manager; img2img_node --> generators; - cli -- txt2img --> generate; - cli -- img2img --> generate; + cli --> |txt2img| generate; + cli --> |img2img| generate; generate --> model_manager; generate --> generators; generate --> ti_manager[TI Manager]; @@ -42,12 +42,12 @@ flowchart TD flowchart TD web[WebUI]; cli[CLI]; - web -- img2img --> img2img_node[img2img node]; + web --> |img2img| img2img_node(img2img node); img2img_node --> model_manager; img2img_node --> generators; - web -- txt2img --> txt2img_node; - cli -- txt2img --> txt2img_node; - cli -- img2img --> generate; + web --> |txt2img| txt2img_node(txt2img node); + 
cli --> |txt2img| txt2img_node; + cli --> |img2img| generate(generate); generate --> model_manager; generate --> generators; generate --> ti_manager[TI Manager]; @@ -63,10 +63,10 @@ flowchart TD flowchart TD web[WebUI]; cli[CLI]; - web -- img2img --> img2img_node[img2img node]; - cli -- img2img --> img2img_node; - web -- txt2img --> txt2img_node; - cli -- txt2img --> txt2img_node; + web --> |img2img|img2img_node(img2img node); + cli --> |img2img|img2img_node; + web --> |txt2img|txt2img_node(txt2img node); + cli --> |txt2img|txt2img_node; img2img_node --> model_manager; txt2img_node --> model_manager; img2img_node --> generators; From 357601e2d673786d42e920bb24d4a1cf55c66540 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Wed, 30 Nov 2022 21:33:20 -0800 Subject: [PATCH 49/57] parent 9eed1919c2071f9199996df747c8638c4a75e8fb author Kyle Schouviller 1669872800 -0800 committer Kyle Schouviller 1676240900 -0800 Adding base node architecture Fix type annotation errors Runs and generates, but breaks in saving session Fix default model value setting. Fix deprecation warning. Fixed node api Adding markdown docs Simplifying Generate construction in apps [nodes] A few minor changes (#2510) * Pin api-related requirements * Remove confusing extra CORS origins list * Adds response models for HTTP 200 [nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest. Minor typing fixes [nodes] Fix some small output query hookups [node] Fixing some additional typing issues [nodes] Move and expand graph code. Add base item storage and sqlite implementation. Update startup to match new code [nodes] Add callbacks to item storage [nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility [nodes] New execution model that handles iteration [nodes] Fixing the CLI [nodes] Adding a note to the CLI [nodes] Split processing thread into separate service [node] Add error message on node processing failure Removing old files and duplicated packages Adding python-multipart --- .coveragerc | 6 + .gitignore | 1 + .pytest.ini | 5 + docs/contributing/ARCHITECTURE.md | 93 ++ docs/contributing/INVOCATIONS.md | 105 +++ ldm/generate.py | 6 + ldm/invoke/app/api/dependencies.py | 83 ++ ldm/invoke/app/api/events.py | 54 ++ ldm/invoke/app/api/routers/images.py | 57 ++ ldm/invoke/app/api/routers/sessions.py | 232 +++++ ldm/invoke/app/api/sockets.py | 36 + ldm/invoke/app/api_app.py | 164 ++++ ldm/invoke/app/cli_app.py | 306 +++++++ ldm/invoke/app/invocations/__init__.py | 8 + ldm/invoke/app/invocations/baseinvocation.py | 74 ++ ldm/invoke/app/invocations/cv.py | 42 + ldm/invoke/app/invocations/generate.py | 160 ++++ ldm/invoke/app/invocations/image.py | 219 +++++ ldm/invoke/app/invocations/prompt.py | 9 + ldm/invoke/app/invocations/reconstruct.py | 36 + ldm/invoke/app/invocations/upscale.py | 38 + ldm/invoke/app/services/__init__.py | 0 ldm/invoke/app/services/events.py | 78 ++ .../app/services/generate_initializer.py | 233 +++++ ldm/invoke/app/services/graph.py | 797 ++++++++++++++++++ ldm/invoke/app/services/image_storage.py | 104 +++ ldm/invoke/app/services/invocation_queue.py | 46 + .../app/services/invocation_services.py | 20 + ldm/invoke/app/services/invoker.py | 109 +++ ldm/invoke/app/services/item_storage.py | 57 ++ ldm/invoke/app/services/processor.py | 78 ++ ldm/invoke/app/services/sqlite.py | 119 +++ pyproject.toml | 5 + scripts/invoke-new.py | 20 + static/dream_web/test.html | 206 +++++ tests/__init__.py | 0 tests/nodes/__init__.py | 0 
tests/nodes/test_graph_execution_state.py | 114 +++ tests/nodes/test_invoker.py | 85 ++ tests/nodes/test_node_graph.py | 501 +++++++++++ tests/nodes/test_nodes.py | 92 ++ tests/nodes/test_sqlite.py | 112 +++ 42 files changed, 4510 insertions(+) create mode 100644 .coveragerc create mode 100644 .pytest.ini create mode 100644 docs/contributing/ARCHITECTURE.md create mode 100644 docs/contributing/INVOCATIONS.md create mode 100644 ldm/invoke/app/api/dependencies.py create mode 100644 ldm/invoke/app/api/events.py create mode 100644 ldm/invoke/app/api/routers/images.py create mode 100644 ldm/invoke/app/api/routers/sessions.py create mode 100644 ldm/invoke/app/api/sockets.py create mode 100644 ldm/invoke/app/api_app.py create mode 100644 ldm/invoke/app/cli_app.py create mode 100644 ldm/invoke/app/invocations/__init__.py create mode 100644 ldm/invoke/app/invocations/baseinvocation.py create mode 100644 ldm/invoke/app/invocations/cv.py create mode 100644 ldm/invoke/app/invocations/generate.py create mode 100644 ldm/invoke/app/invocations/image.py create mode 100644 ldm/invoke/app/invocations/prompt.py create mode 100644 ldm/invoke/app/invocations/reconstruct.py create mode 100644 ldm/invoke/app/invocations/upscale.py create mode 100644 ldm/invoke/app/services/__init__.py create mode 100644 ldm/invoke/app/services/events.py create mode 100644 ldm/invoke/app/services/generate_initializer.py create mode 100644 ldm/invoke/app/services/graph.py create mode 100644 ldm/invoke/app/services/image_storage.py create mode 100644 ldm/invoke/app/services/invocation_queue.py create mode 100644 ldm/invoke/app/services/invocation_services.py create mode 100644 ldm/invoke/app/services/invoker.py create mode 100644 ldm/invoke/app/services/item_storage.py create mode 100644 ldm/invoke/app/services/processor.py create mode 100644 ldm/invoke/app/services/sqlite.py create mode 100644 scripts/invoke-new.py create mode 100644 static/dream_web/test.html create mode 100644 tests/__init__.py create mode 100644 tests/nodes/__init__.py create mode 100644 tests/nodes/test_graph_execution_state.py create mode 100644 tests/nodes/test_invoker.py create mode 100644 tests/nodes/test_node_graph.py create mode 100644 tests/nodes/test_nodes.py create mode 100644 tests/nodes/test_sqlite.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000..8232fc4b93 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,6 @@ +[run] +omit='.env/*' +source='.' + +[report] +show_missing = true diff --git a/.gitignore b/.gitignore index 9adb0be85a..9b33e07164 100644 --- a/.gitignore +++ b/.gitignore @@ -68,6 +68,7 @@ htmlcov/ .cache nosetests.xml coverage.xml +cov.xml *.cover *.py,cover .hypothesis/ diff --git a/.pytest.ini b/.pytest.ini new file mode 100644 index 0000000000..16ccfafe80 --- /dev/null +++ b/.pytest.ini @@ -0,0 +1,5 @@ +[pytest] +DJANGO_SETTINGS_MODULE = webtas.settings +; python_files = tests.py test_*.py *_tests.py + +addopts = --cov=. 
--cov-config=.coveragerc --cov-report xml:cov.xml
diff --git a/docs/contributing/ARCHITECTURE.md b/docs/contributing/ARCHITECTURE.md
new file mode 100644
index 0000000000..d74df94492
--- /dev/null
+++ b/docs/contributing/ARCHITECTURE.md
@@ -0,0 +1,93 @@
+# Invoke.AI Architecture
+
+```mermaid
+flowchart TB
+
+    subgraph apps[Applications]
+        webui[WebUI]
+        cli[CLI]
+
+        subgraph webapi[Web API]
+            api[HTTP API]
+            sio[Socket.IO]
+        end
+
+    end
+
+    subgraph invoke[Invoke]
+        direction LR
+        invoker
+        services
+        sessions
+        invocations
+    end
+
+    subgraph core[AI Core]
+        Generate
+    end
+
+    webui --> webapi
+    webapi --> invoke
+    cli --> invoke
+
+    invoker --> services & sessions
+    invocations --> services
+    sessions --> invocations
+
+    services --> core
+
+    %% Styles
+    classDef sg fill:#5028C8,font-weight:bold,stroke-width:2,color:#fff,stroke:#14141A
+    classDef default stroke-width:2px,stroke:#F6B314,color:#fff,fill:#14141A
+
+    class apps,webapi,invoke,core sg
+
+```
+
+## Applications
+
+Applications are built on top of the invoke framework. They should construct an `invoker` and then interact through it. They should avoid interacting directly with core code in order to support a variety of configurations.
+
+### Web UI
+
+The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.tiangolo.com/) and [Socket.IO](https://socket.io/). The frontend code is found in `/frontend` and the backend code is found in `/ldm/invoke/app/api_app.py` and `/ldm/invoke/app/api/`. The code is further organized as such:
+
+| Component | Description |
+| --- | --- |
+| api_app.py | Sets up the API app, annotates the OpenAPI spec with additional data, and runs the API |
+| dependencies | Creates all invoker services and the invoker, and provides them to the API |
+| events | An eventing system that could in the future be adapted to support horizontal scale-out |
+| sockets | The Socket.IO interface - handles listening to and emitting session events (events are defined in the events service module) |
+| routers | API definitions for different areas of API functionality |
+
+### CLI
+
+The CLI is built automatically from invocation metadata, and also supports invocation piping and auto-linking. Code is available in `/ldm/invoke/app/cli_app.py`.
+
+## Invoke
+
+The Invoke framework provides the interface to the underlying AI systems and is built with flexibility and extensibility in mind. There are four major concepts: invoker, sessions, invocations, and services.
+
+### Invoker
+
+The invoker (`/ldm/invoke/app/services/invoker.py`) is the primary interface through which applications interact with the framework. Its primary purpose is to create, manage, and invoke sessions. It also maintains two sets of services:
+- **invocation services**, which are used by invocations to interact with core functionality.
+- **invoker services**, which are used by the invoker to manage sessions and the invocation queue.
+
+### Sessions
+
+Invocations and links between them form a graph, which is maintained in a session. Sessions can be queued for invocation, which will execute their graph (either the next ready invocation, or all invocations). Sessions also maintain execution history for the graph (including storage of any outputs). An invocation may be added to a session at any time, and there is capability to add an entire graph at once, as well as to automatically link new invocations to previous invocations. Invocations cannot be deleted or modified once added. The session graph does not support looping. 
This is left as an application problem to prevent additional complexity in the graph.
+
+### Invocations
+
+Invocations represent individual units of execution, with inputs and outputs. All invocations are located in `/ldm/invoke/app/invocations`, and are automatically discovered and made available in the applications. These are the primary way to expose new functionality in Invoke.AI, and the [implementation guide](INVOCATIONS.md) explains how to add new invocations.
+
+### Services
+
+Services provide invocations access to AI Core functionality and other necessary functionality (e.g. image storage). These are available in `/ldm/invoke/app/services`. As a general rule, new services should provide an interface as an abstract base class, and may provide a lightweight local implementation by default in their module. The goal for all services should be to enable the use of different implementations (e.g. using cloud storage for image storage), but should not load any module dependencies unless that implementation has been used (i.e. don't import anything that won't be used, especially if it's expensive to import).
+
+## AI Core
+
+The AI Core is represented by the rest of the code base (i.e. the code outside of `/ldm/invoke/app/`).
diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md
new file mode 100644
index 0000000000..c8a97c19e4
--- /dev/null
+++ b/docs/contributing/INVOCATIONS.md
@@ -0,0 +1,105 @@
+# Invocations
+
+Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.
+
+## Creating a new invocation
+
+To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.
+
+An invocation looks like this:
+
+```py
+class UpscaleInvocation(BaseInvocation):
+    """Upscales an image."""
+    type: Literal['upscale'] = 'upscale'
+
+    # Inputs
+    image: Union[ImageField,None] = Field(description="The input image")
+    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
+    level: Literal[2,4] = Field(default=2, description = "The upscale level")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.services.images.get(self.image.image_type, self.image.image_name)
+        results = context.services.generate.upscale_and_reconstruct(
+            image_list = [[image, 0]],
+            upscale = (self.level, self.strength),
+            strength = 0.0, # GFPGAN strength
+            save_original = False,
+            image_callback = None,
+        )
+
+        # Results are image and seed, unwrap for now
+        # TODO: can this return multiple results?
+        image_type = ImageType.RESULT
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, results[0][0])
+        return ImageOutput(
+            image = ImageField(image_type = image_type, image_name = image_name)
+        )
+```
+
+Each portion is important to implement correctly.
+
+### Class definition and type
+```py
+class UpscaleInvocation(BaseInvocation):
+    """Upscales an image."""
+    type: Literal['upscale'] = 'upscale'
+```
+All invocations must derive from `BaseInvocation`. 
They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint. + +### Inputs +```py + # Inputs + image: Union[ImageField,None] = Field(description="The input image") + strength: float = Field(default=0.75, gt=0, le=1, description="The strength") + level: Literal[2,4] = Field(default=2, description="The upscale level") +``` +Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example: +| Part | Value | Description | +| ---- | ----- | ----------- | +| Name | `strength` | This field is referred to as `strength` | +| Type Hint | `float` | This field must be of type `float` | +| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. | + +Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation. + +The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links). + +Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching. + +### Invoke Function +```py + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + results = context.services.generate.upscale_and_reconstruct( + image_list = [[image, 0]], + upscale = (self.level, self.strength), + strength = 0.0, # GFPGAN strength + save_original = False, + image_callback = None, + ) + + # Results are image and seed, unwrap for now + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, results[0][0]) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) +``` +The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`. + +Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order). + +Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden. 
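+
+As a quick end-to-end sketch, here is a minimal, hypothetical invocation built from the pieces described above. The `invert` type and the use of `PIL.ImageOps` are illustrative assumptions (not part of the framework), and `ImageOps` is presumed imported from PIL as in the OpenCV inpaint module; only the context services shown earlier are used:
+
+```py
+class InvertInvocation(BaseInvocation):
+    """Inverts an image."""
+    type: Literal['invert'] = 'invert'
+
+    # Inputs
+    image: Union[ImageField,None] = Field(description="The input image")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        # Load the linked input image through the image storage service
+        image = context.services.images.get(self.image.image_type, self.image.image_name)
+
+        # Illustrative processing step; ImageOps.invert expects an RGB image
+        result = ImageOps.invert(image.convert('RGB'))
+
+        # Save the result and return an ImageOutput so later nodes can link to it
+        image_type = ImageType.RESULT
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, result)
+        return ImageOutput(
+            image = ImageField(image_type = image_type, image_name = image_name)
+        )
+```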
+ +### Outputs +```py +class ImageOutput(BaseInvocationOutput): + """Base class for invocations that output an image""" + type: Literal['image'] = 'image' + + image: ImageField = Field(default=None, description="The output image") +``` +Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking. diff --git a/ldm/generate.py b/ldm/generate.py index 413a1e25cb..256f214b25 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -1030,6 +1030,8 @@ class Generate: image_callback=None, prefix=None, ): + + results = [] for r in image_list: image, seed = r try: @@ -1083,6 +1085,10 @@ class Generate: else: r[0] = image + results.append([image, seed]) + + return results + def apply_textmask( self, image_path: str, prompt: str, callback, threshold: float = 0.5 ): diff --git a/ldm/invoke/app/api/dependencies.py b/ldm/invoke/app/api/dependencies.py new file mode 100644 index 0000000000..60dd522803 --- /dev/null +++ b/ldm/invoke/app/api/dependencies.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from argparse import Namespace +import os + +from ..services.processor import DefaultInvocationProcessor + +from ..services.graph import GraphExecutionState +from ..services.sqlite import SqliteItemStorage + +from ...globals import Globals + +from ..services.image_storage import DiskImageStorage +from ..services.invocation_queue import MemoryInvocationQueue +from ..services.invocation_services import InvocationServices +from ..services.invoker import Invoker, InvokerServices +from ..services.generate_initializer import get_generate +from .events import FastAPIEventService + + +# TODO: is there a better way to achieve this? +def check_internet()->bool: + ''' + Return true if the internet is reachable. + It does this by pinging huggingface.co. + ''' + import urllib.request + host = 'http://huggingface.co' + try: + urllib.request.urlopen(host,timeout=1) + return True + except: + return False + + +class ApiDependencies: + """Contains and initializes all dependencies for the API""" + invoker: Invoker = None + + @staticmethod + def initialize( + args, + config, + event_handler_id: int + ): + Globals.try_patchmatch = args.patchmatch + Globals.always_use_cpu = args.always_use_cpu + Globals.internet_available = args.internet_available and check_internet() + Globals.disable_xformers = not args.xformers + Globals.ckpt_convert = args.ckpt_convert + + # TODO: Use a logger + print(f'>> Internet connectivity is {Globals.internet_available}') + + generate = get_generate(args, config) + + events = FastAPIEventService(event_handler_id) + + output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs')) + + images = DiskImageStorage(output_folder) + + services = InvocationServices( + generate = generate, + events = events, + images = images + ) + + # TODO: build a file/path manager? 
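+        # The SQLite session database lives alongside the generated images in the outputs folder.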
+        db_location = os.path.join(output_folder, 'invokeai.db')
+
+        invoker_services = InvokerServices(
+            queue = MemoryInvocationQueue(),
+            graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
+            processor = DefaultInvocationProcessor()
+        )
+
+        ApiDependencies.invoker = Invoker(services, invoker_services)
+
+    @staticmethod
+    def shutdown():
+        if ApiDependencies.invoker:
+            ApiDependencies.invoker.stop()
diff --git a/ldm/invoke/app/api/events.py b/ldm/invoke/app/api/events.py
new file mode 100644
index 0000000000..701b48a316
--- /dev/null
+++ b/ldm/invoke/app/api/events.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+
+import asyncio
+from queue import Empty, Queue
+from typing import Any
+from fastapi_events.dispatcher import dispatch
+from ..services.events import EventServiceBase
+import threading
+
+class FastAPIEventService(EventServiceBase):
+    event_handler_id: int
+    __queue: Queue
+    __stop_event: threading.Event
+
+    def __init__(self, event_handler_id: int) -> None:
+        self.event_handler_id = event_handler_id
+        self.__queue = Queue()
+        self.__stop_event = threading.Event()
+        asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event))
+
+        super().__init__()
+
+
+    def stop(self, *args, **kwargs):
+        self.__stop_event.set()
+        self.__queue.put(None)
+
+
+    def dispatch(self, event_name: str, payload: Any) -> None:
+        self.__queue.put(dict(
+            event_name = event_name,
+            payload = payload
+        ))
+
+
+    async def __dispatch_from_queue(self, stop_event: threading.Event):
+        """Get events from the queue and dispatch them, from the correct thread"""
+        while not stop_event.is_set():
+            try:
+                event = self.__queue.get(block = False)
+                if not event: # Probably stopping
+                    continue
+
+                dispatch(
+                    event.get('event_name'),
+                    payload = event.get('payload'),
+                    middleware_id = self.event_handler_id)
+
+            except Empty:
+                await asyncio.sleep(0.001)
+                pass
+
+            except asyncio.CancelledError as e:
+                raise e # Raise a proper error
diff --git a/ldm/invoke/app/api/routers/images.py b/ldm/invoke/app/api/routers/images.py
new file mode 100644
index 0000000000..1ae116e49d
--- /dev/null
+++ b/ldm/invoke/app/api/routers/images.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+
+import io
+from datetime import datetime, timezone
+from fastapi import Path, UploadFile, Request
+from fastapi.routing import APIRouter
+from fastapi.responses import FileResponse, Response
+from PIL import Image
+from ...services.image_storage import ImageType
+from ..dependencies import ApiDependencies
+
+images_router = APIRouter(
+    prefix = '/v1/images',
+    tags = ['images']
+)
+
+
+@images_router.get('/{image_type}/{image_name}',
+    operation_id = 'get_image'
+    )
+async def get_image(
+    image_type: ImageType = Path(description = "The type of image to get"),
+    image_name: str = Path(description = "The name of the image to get")
+):
+    """Gets a result"""
+    # TODO: This is not really secure at all. 
At least make sure only output results are served
+    filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
+    return FileResponse(filename)
+
+@images_router.post('/uploads/',
+    operation_id = 'upload_image',
+    responses = {
+        201: {'description': 'The image was uploaded successfully'},
+        404: {'description': 'Session not found'}
+    })
+async def upload_image(
+    file: UploadFile,
+    request: Request
+):
+    if not file.content_type.startswith('image'):
+        return Response(status_code = 415)
+
+    contents = await file.read()
+    try:
+        # Image.open() expects a file-like object, so wrap the raw bytes
+        im = Image.open(io.BytesIO(contents))
+    except:
+        # Error opening the image
+        return Response(status_code = 415)
+
+    filename = f'{str(int(datetime.now(timezone.utc).timestamp()))}.png'
+    ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
+
+    return Response(
+        status_code=201,
+        headers = {
+            'Location': request.url_for('get_image', image_type=ImageType.UPLOAD, image_name=filename)
+        }
+    )
diff --git a/ldm/invoke/app/api/routers/sessions.py b/ldm/invoke/app/api/routers/sessions.py
new file mode 100644
index 0000000000..77008ad6e4
--- /dev/null
+++ b/ldm/invoke/app/api/routers/sessions.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+
+from typing import List, Optional, Union, Annotated
+from fastapi import Query, Path, Body
+from fastapi.routing import APIRouter
+from fastapi.responses import Response
+from pydantic.fields import Field
+
+from ...services.item_storage import PaginatedResults
+from ..dependencies import ApiDependencies
+from ...invocations.baseinvocation import BaseInvocation
+from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
+from ...invocations import *
+
+session_router = APIRouter(
+    prefix = '/v1/sessions',
+    tags = ['sessions']
+)
+
+
+@session_router.post('/',
+    operation_id = 'create_session',
+    responses = {
+        200: {"model": GraphExecutionState},
+        400: {'description': 'Invalid json'}
+    })
+async def create_session(
+    graph: Optional[Graph] = Body(default = None, description = "The graph to initialize the session with")
+) -> GraphExecutionState:
+    """Creates a new session, optionally initializing it with an invocation graph"""
+    session = ApiDependencies.invoker.create_execution_state(graph)
+    return session
+
+
+@session_router.get('/',
+    operation_id = 'list_sessions',
+    responses = {
+        200: {"model": PaginatedResults[GraphExecutionState]}
+    })
+async def list_sessions(
+    page: int = Query(default = 0, description = "The page of results to get"),
+    per_page: int = Query(default = 10, description = "The number of results per page"),
+    query: str = Query(default = '', description = "The query string to search for")
+) -> PaginatedResults[GraphExecutionState]:
+    """Gets a list of sessions, optionally searching"""
+    # Compare the query parameter, not the builtin filter()
+    if query == '':
+        result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page)
+    else:
+        result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page)
+    return result
+
+
+@session_router.get('/{session_id}',
+    operation_id = 'get_session',
+    responses = {
+        200: {"model": GraphExecutionState},
+        404: {'description': 'Session not found'}
+    })
+async def get_session(
+    session_id: str = Path(description = "The id of the session to get")
+) -> GraphExecutionState:
+    """Gets a session"""
+    session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
+    if session is None:
+        return 
Response(status_code = 404) + else: + return session + + +@session_router.post('/{session_id}/nodes', + operation_id = 'add_node', + responses = { + 200: {"model": str}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def add_node( + session_id: str = Path(description = "The id of the session"), + node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add") +) -> str: + """Adds a node to the graph""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + session.add_node(node) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + return session.id + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +@session_router.put('/{session_id}/nodes/{node_path}', + operation_id = 'update_node', + responses = { + 200: {"model": GraphExecutionState}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def update_node( + session_id: str = Path(description = "The id of the session"), + node_path: str = Path(description = "The path to the node in the graph"), + node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node") +) -> GraphExecutionState: + """Updates a node in the graph and removes all linked edges""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + session.update_node(node_path, node) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + return session + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +@session_router.delete('/{session_id}/nodes/{node_path}', + operation_id = 'delete_node', + responses = { + 200: {"model": GraphExecutionState}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def delete_node( + session_id: str = Path(description = "The id of the session"), + node_path: str = Path(description = "The path to the node to delete") +) -> GraphExecutionState: + """Deletes a node in the graph and removes all linked edges""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + session.delete_node(node_path) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? 
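+        # Hand the updated session state back to the caller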
+ return session + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +@session_router.post('/{session_id}/edges', + operation_id = 'add_edge', + responses = { + 200: {"model": GraphExecutionState}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def add_edge( + session_id: str = Path(description = "The id of the session"), + edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add") +) -> GraphExecutionState: + """Adds an edge to the graph""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + session.add_edge(edge) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + return session + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +# TODO: the edge being in the path here is really ugly, find a better solution +@session_router.delete('/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}', + operation_id = 'delete_edge', + responses = { + 200: {"model": GraphExecutionState}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def delete_edge( + session_id: str = Path(description = "The id of the session"), + from_node_id: str = Path(description = "The id of the node the edge is coming from"), + from_field: str = Path(description = "The field of the node the edge is coming from"), + to_node_id: str = Path(description = "The id of the node the edge is going to"), + to_field: str = Path(description = "The field of the node the edge is going to") +) -> GraphExecutionState: + """Deletes an edge from the graph""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field)) + session.delete_edge(edge) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? 
+ return session + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +@session_router.put('/{session_id}/invoke', + operation_id = 'invoke_session', + responses = { + 200: {"model": None}, + 202: {'description': 'The invocation is queued'}, + 400: {'description': 'The session has no invocations ready to invoke'}, + 404: {'description': 'Session not found'} + }) +async def invoke_session( + session_id: str = Path(description = "The id of the session to invoke"), + all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations") +) -> None: + """Invokes a session""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + if session.is_complete(): + return Response(status_code = 400) + + ApiDependencies.invoker.invoke(session, invoke_all = all) + return Response(status_code=202) diff --git a/ldm/invoke/app/api/sockets.py b/ldm/invoke/app/api/sockets.py new file mode 100644 index 0000000000..eb4d5403c0 --- /dev/null +++ b/ldm/invoke/app/api/sockets.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from fastapi import FastAPI +from fastapi_socketio import SocketManager +from fastapi_events.handlers.local import local_handler +from fastapi_events.typing import Event +from ..services.events import EventServiceBase + +class SocketIO: + __sio: SocketManager + + def __init__(self, app: FastAPI): + self.__sio = SocketManager(app = app) + self.__sio.on('subscribe', handler=self._handle_sub) + self.__sio.on('unsubscribe', handler=self._handle_unsub) + + local_handler.register( + event_name = EventServiceBase.session_event, + _func=self._handle_session_event + ) + + async def _handle_session_event(self, event: Event): + await self.__sio.emit( + event = event[1]['event'], + data = event[1]['data'], + room = event[1]['data']['graph_execution_state_id'] + ) + + async def _handle_sub(self, sid, data, *args, **kwargs): + if 'session' in data: + self.__sio.enter_room(sid, data['session']) + + # @app.sio.on('unsubscribe') + async def _handle_unsub(self, sid, data, *args, **kwargs): + if 'session' in data: + self.__sio.leave_room(sid, data['session']) diff --git a/ldm/invoke/app/api_app.py b/ldm/invoke/app/api_app.py new file mode 100644 index 0000000000..db79b0d7e8 --- /dev/null +++ b/ldm/invoke/app/api_app.py @@ -0,0 +1,164 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import asyncio +from inspect import signature +from fastapi import FastAPI +from fastapi.openapi.utils import get_openapi +from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html +from fastapi.staticfiles import StaticFiles +from fastapi_events.middleware import EventHandlerASGIMiddleware +from fastapi_events.handlers.local import local_handler +from fastapi.middleware.cors import CORSMiddleware +from pydantic.schema import schema +import uvicorn +from .api.sockets import SocketIO +from .invocations import * +from .invocations.baseinvocation import BaseInvocation +from .api.routers import images, sessions +from .api.dependencies import ApiDependencies +from ..args import Args + +# Create the app +# TODO: create this all in a method so configuration/etc. can be passed in? 
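+# NOTE: docs_url/redoc_url are disabled because custom /docs and /redoc routes (with a local favicon) are defined below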
+app = FastAPI( + title = "Invoke AI", + docs_url = None, + redoc_url = None +) + +# Add event handler +event_handler_id: int = id(app) +app.add_middleware( + EventHandlerASGIMiddleware, + handlers = [local_handler], # TODO: consider doing this in services to support different configurations + middleware_id = event_handler_id) + +# Add CORS +# TODO: use configuration for this +origins = [] +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +socket_io = SocketIO(app) + +config = {} + +# Add startup event to load dependencies +@app.on_event('startup') +async def startup_event(): + args = Args() + config = args.parse_args() + + ApiDependencies.initialize( + args = args, + config = config, + event_handler_id = event_handler_id + ) + +# Shut down threads +@app.on_event('shutdown') +async def shutdown_event(): + ApiDependencies.shutdown() + +# Include all routers +# TODO: REMOVE +# app.include_router( +# invocation.invocation_router, +# prefix = '/api') + +app.include_router( + sessions.session_router, + prefix = '/api' +) + +app.include_router( + images.images_router, + prefix = '/api' +) + +# Build a custom OpenAPI to include all outputs +# TODO: can outputs be included on metadata of invocation schemas somehow? +def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + openapi_schema = get_openapi( + title = app.title, + description = "An API for invoking AI image operations", + version = "1.0.0", + routes = app.routes + ) + + # Add all outputs + all_invocations = BaseInvocation.get_invocations() + output_types = set() + output_type_titles = dict() + for invoker in all_invocations: + output_type = signature(invoker.invoke).return_annotation + output_types.add(output_type) + + output_schemas = schema(output_types, ref_prefix="#/components/schemas/") + for schema_key, output_schema in output_schemas['definitions'].items(): + openapi_schema["components"]["schemas"][schema_key] = output_schema + + # TODO: note that we assume the schema_key here is the TYPE.__name__ + # This could break in some cases, figure out a better way to do it + output_type_titles[schema_key] = output_schema['title'] + + # Add a reference to the output type to additionalProperties of the invoker schema + for invoker in all_invocations: + invoker_name = invoker.__name__ + output_type = signature(invoker.invoke).return_annotation + output_type_title = output_type_titles[output_type.__name__] + invoker_schema = openapi_schema["components"]["schemas"][invoker_name] + outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' } + if 'additionalProperties' not in invoker_schema: + invoker_schema['additionalProperties'] = {} + + invoker_schema['additionalProperties']['outputs'] = outputs_ref + + app.openapi_schema = openapi_schema + return app.openapi_schema + +app.openapi = custom_openapi + +# Override API doc favicons +app.mount('/static', StaticFiles(directory='static/dream_web'), name='static') + +@app.get("/docs", include_in_schema=False) +def overridden_swagger(): + return get_swagger_ui_html( + openapi_url=app.openapi_url, + title=app.title, + swagger_favicon_url="/static/favicon.ico" + ) + +@app.get("/redoc", include_in_schema=False) +def overridden_redoc(): + return get_redoc_html( + openapi_url=app.openapi_url, + title=app.title, + redoc_favicon_url="/static/favicon.ico" + ) + +def invoke_api(): + # Start our own event loop for eventing usage + # TODO: determine if there's a better way to do this + loop = 
asyncio.new_event_loop() + config = uvicorn.Config( + app = app, + host = "0.0.0.0", + port = 9090, + loop = loop) + # Use access_log to turn off logging + + server = uvicorn.Server(config) + loop.run_until_complete(server.serve()) + + +if __name__ == "__main__": + invoke_api() diff --git a/ldm/invoke/app/cli_app.py b/ldm/invoke/app/cli_app.py new file mode 100644 index 0000000000..6071afabb2 --- /dev/null +++ b/ldm/invoke/app/cli_app.py @@ -0,0 +1,306 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import argparse +import shlex +import os +import time +from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints +from pydantic import BaseModel +from pydantic.fields import Field + +from .services.processor import DefaultInvocationProcessor + +from .services.graph import EdgeConnection, GraphExecutionState + +from .services.sqlite import SqliteItemStorage + +from .invocations.image import ImageField +from .services.generate_initializer import get_generate +from .services.image_storage import DiskImageStorage +from .services.invocation_queue import MemoryInvocationQueue +from .invocations.baseinvocation import BaseInvocation +from .services.invocation_services import InvocationServices +from .services.invoker import Invoker, InvokerServices +from .invocations import * +from ..args import Args +from .services.events import EventServiceBase + + +class InvocationCommand(BaseModel): + invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") + + +class InvalidArgs(Exception): + pass + + +def get_invocation_parser() -> argparse.ArgumentParser: + + # Create invocation parser + parser = argparse.ArgumentParser() + def exit(*args, **kwargs): + raise InvalidArgs + parser.exit = exit + + subparsers = parser.add_subparsers(dest='type') + invocation_parsers = dict() + + # Add history parser + history_parser = subparsers.add_parser('history', help="Shows the invocation history") + history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show") + + # Add default parser + default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name") + default_parser.add_argument('input', type=str, help="The input field") + default_parser.add_argument('value', help="The default value") + + default_parser = subparsers.add_parser('reset_default', help="Resets a default value") + default_parser.add_argument('input', type=str, help="The input field") + + # Create subparsers for each invocation + invocations = BaseInvocation.get_all_subclasses() + for invocation in invocations: + hints = get_type_hints(invocation) + cmd_name = get_args(hints['type'])[0] + command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__) + invocation_parsers[cmd_name] = command_parser + + # Add linking capability + command_parser.add_argument('--link', '-l', action='append', nargs=3, + help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)") + + command_parser.add_argument('--link_node', '-ln', action='append', + help="A link from all fields in the specified node. Node can be relative to history (e.g. 
-1)") + + # Convert all fields to arguments + fields = invocation.__fields__ + for name, field in fields.items(): + if name in ['id', 'type']: + continue + + if get_origin(field.type_) == Literal: + allowed_values = get_args(field.type_) + allowed_types = set() + for val in allowed_values: + allowed_types.add(type(val)) + allowed_types_list = list(allowed_types) + field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] + + command_parser.add_argument( + f"--{name}", + dest=name, + type=field_type, + default=field.default, + choices = allowed_values, + help=field.field_info.description + ) + else: + command_parser.add_argument( + f"--{name}", + dest=name, + type=field.type_, + default=field.default, + help=field.field_info.description + ) + + return parser + + +def get_invocation_command(invocation) -> str: + fields = invocation.__fields__.items() + type_hints = get_type_hints(type(invocation)) + command = [invocation.type] + for name,field in fields: + if name in ['id', 'type']: + continue + + # TODO: add links + + # Skip image fields when serializing command + type_hint = type_hints.get(name) or None + if type_hint is ImageField or ImageField in get_args(type_hint): + continue + + field_value = getattr(invocation, name) + field_default = field.default + if field_value != field_default: + if type_hint is str or str in get_args(type_hint): + command.append(f'--{name} "{field_value}"') + else: + command.append(f'--{name} {field_value}') + + return ' '.join(command) + + +def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]: + """Gets the history of fully-executed invocations for a graph execution""" + return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes) + + +def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]: + """Generates all possible edges between two invocations""" + atype = type(a) + btype = type(b) + + aoutputtype = atype.get_output_type() + + afields = get_type_hints(aoutputtype) + bfields = get_type_hints(btype) + + matching_fields = set(afields.keys()).intersection(bfields.keys()) + + # Remove invalid fields + invalid_fields = set(['type', 'id']) + matching_fields = matching_fields.difference(invalid_fields) + + edges = [(EdgeConnection(node_id = a.id, field = field), EdgeConnection(node_id = b.id, field = field)) for field in matching_fields] + return edges + + +def invoke_cli(): + args = Args() + config = args.parse_args() + + generate = get_generate(args, config) + + # NOTE: load model on first use, uncomment to load at startup + # TODO: Make this a config option? + #generate.load_model() + + events = EventServiceBase() + + output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs')) + + services = InvocationServices( + generate = generate, + events = events, + images = DiskImageStorage(output_folder) + ) + + # TODO: build a file/path manager? 
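+    # The CLI reuses the same on-disk SQLite session store as the API app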
+ db_location = os.path.join(output_folder, 'invokeai.db') + + invoker_services = InvokerServices( + queue = MemoryInvocationQueue(), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor() + ) + + invoker = Invoker(services, invoker_services) + session = invoker.create_execution_state() + + parser = get_invocation_parser() + + # Uncomment to print out previous sessions at startup + # print(invoker_services.session_manager.list()) + + # Defaults storage + defaults: Dict[str, Any] = dict() + + while True: + try: + cmd_input = input("> ") + except KeyboardInterrupt: + # Ctrl-c exits + break + + if cmd_input in ['exit','q']: + break; + + if cmd_input in ['--help','help','h','?']: + parser.print_help() + continue + + try: + # Refresh the state of the session + session = invoker.invoker_services.graph_execution_manager.get(session.id) + history = list(get_graph_execution_history(session)) + + # Split the command for piping + cmds = cmd_input.split('|') + start_id = len(history) + current_id = start_id + new_invocations = list() + for cmd in cmds: + # Parse args to create invocation + args = vars(parser.parse_args(shlex.split(cmd.strip()))) + + # Check for special commands + # TODO: These might be better as Pydantic models, similar to the invocations + if args['type'] == 'history': + history_count = args['count'] or 5 + for i in range(min(history_count, len(history))): + entry_id = history[-1 - i] + entry = session.graph.get_node(entry_id) + print(f'{entry_id}: {get_invocation_command(entry.invocation)}') + continue + + if args['type'] == 'reset_default': + if args['input'] in defaults: + del defaults[args['input']] + continue + + if args['type'] == 'default': + field = args['input'] + field_value = args['value'] + defaults[field] = field_value + continue + + # Override defaults + for field_name,field_default in defaults.items(): + if field_name in args: + args[field_name] = field_default + + # Parse invocation + args['id'] = current_id + command = InvocationCommand(invocation = args) + + # Pipe previous command output (if there was a previous command) + edges = [] + if len(history) > 0 or current_id != start_id: + from_id = history[0] if current_id == start_id else str(current_id - 1) + from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id) + matching_edges = generate_matching_edges(from_node, command.invocation) + edges.extend(matching_edges) + + # Parse provided links + if 'link_node' in args and args['link_node']: + for link in args['link_node']: + link_node = session.graph.get_node(link) + matching_edges = generate_matching_edges(link_node, command.invocation) + edges.extend(matching_edges) + + if 'link' in args and args['link']: + for link in args['link']: + edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2]))) + + new_invocations.append((command.invocation, edges)) + + current_id = current_id + 1 + + # Command line was parsed successfully + # Add the invocations to the session + for invocation in new_invocations: + session.add_node(invocation[0]) + for edge in invocation[1]: + session.add_edge(edge) + + # Execute all available invocations + invoker.invoke(session, invoke_all = True) + while not session.is_complete(): + # Wait some time + session = invoker.invoker_services.graph_execution_manager.get(session.id) + 
time.sleep(0.1) + + except InvalidArgs: + print('Invalid command, use "help" to list commands') + continue + + except SystemExit: + continue + + invoker.stop() + + +if __name__ == "__main__": + invoke_cli() diff --git a/ldm/invoke/app/invocations/__init__.py b/ldm/invoke/app/invocations/__init__.py new file mode 100644 index 0000000000..6407a1cdee --- /dev/null +++ b/ldm/invoke/app/invocations/__init__.py @@ -0,0 +1,8 @@ +import os + +__all__ = [] + +dirname = os.path.dirname(os.path.abspath(__file__)) +for f in os.listdir(dirname): + if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py": + __all__.append(f[:-3]) diff --git a/ldm/invoke/app/invocations/baseinvocation.py b/ldm/invoke/app/invocations/baseinvocation.py new file mode 100644 index 0000000000..1ad2d99112 --- /dev/null +++ b/ldm/invoke/app/invocations/baseinvocation.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from abc import ABC, abstractmethod +from inspect import signature +from typing import get_args, get_type_hints +from pydantic import BaseModel, Field +from ..services.invocation_services import InvocationServices + + +class InvocationContext: + services: InvocationServices + graph_execution_state_id: str + + def __init__(self, services: InvocationServices, graph_execution_state_id: str): + self.services = services + self.graph_execution_state_id = graph_execution_state_id + + +class BaseInvocationOutput(BaseModel): + """Base class for all invocation outputs""" + + # All outputs must include a type name like this: + # type: Literal['your_output_name'] + + @classmethod + def get_all_subclasses_tuple(cls): + subclasses = [] + toprocess = [cls] + while len(toprocess) > 0: + next = toprocess.pop(0) + next_subclasses = next.__subclasses__() + subclasses.extend(next_subclasses) + toprocess.extend(next_subclasses) + return tuple(subclasses) + + +class BaseInvocation(ABC, BaseModel): + """A node to process inputs and produce outputs. + May use dependency injection in __init__ to receive providers. + """ + + # All invocations must include a type name like this: + # type: Literal['your_output_name'] + + @classmethod + def get_all_subclasses(cls): + subclasses = [] + toprocess = [cls] + while len(toprocess) > 0: + next = toprocess.pop(0) + next_subclasses = next.__subclasses__() + subclasses.extend(next_subclasses) + toprocess.extend(next_subclasses) + return subclasses + + @classmethod + def get_invocations(cls): + return tuple(BaseInvocation.get_all_subclasses()) + + @classmethod + def get_invocations_map(cls): + # Get the type strings out of the literals and into a dictionary + return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) + + @classmethod + def get_output_type(cls): + return signature(cls.invoke).return_annotation + + @abstractmethod + def invoke(self, context: InvocationContext) -> BaseInvocationOutput: + """Invoke with provided context and return outputs.""" + pass + + id: str = Field(description="The id of this node. 
Must be unique among all nodes.") diff --git a/ldm/invoke/app/invocations/cv.py b/ldm/invoke/app/invocations/cv.py new file mode 100644 index 0000000000..f950669736 --- /dev/null +++ b/ldm/invoke/app/invocations/cv.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Literal +import numpy +from pydantic import Field +from PIL import Image, ImageOps +import cv2 as cv +from .image import ImageField, ImageOutput +from .baseinvocation import BaseInvocation, InvocationContext +from ..services.image_storage import ImageType + + +class CvInpaintInvocation(BaseInvocation): + """Simple inpaint using opencv.""" + type: Literal['cv_inpaint'] = 'cv_inpaint' + + # Inputs + image: ImageField = Field(default=None, description="The image to inpaint") + mask: ImageField = Field(default=None, description="The mask to use when inpainting") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + mask = context.services.images.get(self.mask.image_type, self.mask.image_name) + + # Convert to cv image/mask + # TODO: consider making these utility functions + cv_image = cv.cvtColor(numpy.array(image.convert('RGB')), cv.COLOR_RGB2BGR) + cv_mask = numpy.array(ImageOps.invert(mask)) + + # Inpaint + cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA) + + # Convert back to Pillow + # TODO: consider making a utility function + image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) + + image_type = ImageType.INTERMEDIATE + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, image_inpainted) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) diff --git a/ldm/invoke/app/invocations/generate.py b/ldm/invoke/app/invocations/generate.py new file mode 100644 index 0000000000..60b656bf0c --- /dev/null +++ b/ldm/invoke/app/invocations/generate.py @@ -0,0 +1,160 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from datetime import datetime, timezone +from typing import Any, Literal, Optional, Union +import numpy as np +from pydantic import Field +from PIL import Image +from skimage.exposure.histogram_matching import match_histograms +from .image import ImageField, ImageOutput +from .baseinvocation import BaseInvocation, InvocationContext +from ..services.image_storage import ImageType +from ..services.invocation_services import InvocationServices + + +SAMPLER_NAME_VALUES = Literal["ddim","plms","k_lms","k_dpm_2","k_dpm_2_a","k_euler","k_euler_a","k_heun"] + +# Text to image +class TextToImageInvocation(BaseInvocation): + """Generates an image using text2img.""" + type: Literal['txt2img'] = 'txt2img' + + # Inputs + # TODO: consider making prompt optional to enable providing prompt through a link + prompt: Optional[str] = Field(description="The prompt to generate an image from") + seed: int = Field(default=-1, ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)") + steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") + width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image") + height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image") + cfg_scale: float = Field(default=7.5, gt=0, description="The 
Classifier-Free Guidance scale; higher values make the result adhere more closely to the prompt")
+    sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use")
+    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams")
+    model: str = Field(default='', description="The model to use (currently ignored)")
+    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation")
+
+    # TODO: pass this an emitter method or something? or a session for dispatching?
+    def dispatch_progress(self, context: InvocationContext, sample: Any = None, step: int = 0) -> None:
+        context.services.events.emit_generator_progress(
+            context.graph_execution_state_id, self.id, step, float(step) / float(self.steps)
+        )
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+
+        def step_callback(sample, step = 0):
+            self.dispatch_progress(context, sample, step)
+
+        # Handle invalid model parameter
+        # TODO: figure out if this can be done via a validator that uses the model_cache
+        # TODO: How to get the default model name now?
+        if self.model is None or self.model == '':
+            self.model = context.services.generate.model_name
+
+        # Set the model (if already cached, this does nothing)
+        context.services.generate.set_model(self.model)
+
+        results = context.services.generate.prompt2image(
+            prompt = self.prompt,
+            step_callback = step_callback,
+            **self.dict(exclude = {'prompt'})  # Shorthand for passing all of the parameters above manually
+        )
+
+        # Results are image and seed, unwrap for now and ignore the seed
+        # TODO: pre-seed?
+        # TODO: can this return multiple results? Should it?
+        image_type = ImageType.RESULT
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, results[0][0])
+        return ImageOutput(
+            image = ImageField(image_type = image_type, image_name = image_name)
+        )
+
+
+class ImageToImageInvocation(TextToImageInvocation):
+    """Generates an image using img2img."""
+    type: Literal['img2img'] = 'img2img'
+
+    # Inputs
+    image: Union[ImageField, None] = Field(description="The input image")
+    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
+    fit: bool = Field(default=True, description="Whether or not the result should be fit to the aspect ratio of the input image")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name)
+        mask = None
+
+        def step_callback(sample, step = 0):
+            self.dispatch_progress(context, sample, step)
+
+        # Handle invalid model parameter
+        # TODO: figure out if this can be done via a validator that uses the model_cache
+        # TODO: How to get the default model name now?
+        if self.model is None or self.model == '':
+            self.model = context.services.generate.model_name
+
+        # Set the model (if already cached, this does nothing)
+        context.services.generate.set_model(self.model)
+
+        results = context.services.generate.prompt2image(
+            prompt = self.prompt,
+            init_img = image,
+            init_mask = mask,
+            step_callback = step_callback,
+            **self.dict(exclude = {'prompt', 'image', 'mask'})  # Shorthand for passing all of the parameters above manually
+        )
+
+        result_image = results[0][0]
+
+        # Results are image and seed, unwrap for now and ignore the seed
+        # TODO: pre-seed?
+        # TODO: can this return multiple results?
Should it? + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, result_image) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) + + +class InpaintInvocation(ImageToImageInvocation): + """Generates an image using inpaint.""" + type: Literal['inpaint'] = 'inpaint' + + # Inputs + mask: Union[ImageField,None] = Field(description="The mask") + inpaint_replace: float = Field(default=0.0, ge=0.0, le=1.0, description="The amount by which to replace masked areas with latent noise") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name) + mask = None if self.mask is None else context.services.images.get(self.mask.image_type, self.mask.image_name) + + def step_callback(sample, step = 0): + self.dispatch_progress(context, sample, step) + + # Handle invalid model parameter + # TODO: figure out if this can be done via a validator that uses the model_cache + # TODO: How to get the default model name now? + if self.model is None or self.model == '': + self.model = context.services.generate.model_name + + # Set the model (if already cached, this does nothing) + context.services.generate.set_model(self.model) + + results = context.services.generate.prompt2image( + prompt = self.prompt, + init_img = image, + init_mask = mask, + step_callback = step_callback, + **self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually + ) + + result_image = results[0][0] + + # Results are image and seed, unwrap for now and ignore the seed + # TODO: pre-seed? + # TODO: can this return multiple results? Should it? 
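+        # Note: prompt2image returns a list of (image, seed) tuples; as in the
+        # other generate invocations above, only the first image is unwrapped here.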
+ image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, result_image) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) diff --git a/ldm/invoke/app/invocations/image.py b/ldm/invoke/app/invocations/image.py new file mode 100644 index 0000000000..cb326b1bb7 --- /dev/null +++ b/ldm/invoke/app/invocations/image.py @@ -0,0 +1,219 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from datetime import datetime, timezone +from typing import Literal, Optional +import numpy +from pydantic import Field, BaseModel +from PIL import Image, ImageOps, ImageFilter +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ..services.image_storage import ImageType +from ..services.invocation_services import InvocationServices + + +class ImageField(BaseModel): + """An image field used for passing image objects between invocations""" + image_type: str = Field(default=ImageType.RESULT, description="The type of the image") + image_name: Optional[str] = Field(default=None, description="The name of the image") + + +class ImageOutput(BaseInvocationOutput): + """Base class for invocations that output an image""" + type: Literal['image'] = 'image' + + image: ImageField = Field(default=None, description="The output image") + + +class MaskOutput(BaseInvocationOutput): + """Base class for invocations that output a mask""" + type: Literal['mask'] = 'mask' + + mask: ImageField = Field(default=None, description="The output mask") + + +# TODO: this isn't really necessary anymore +class LoadImageInvocation(BaseInvocation): + """Load an image from a filename and provide it as output.""" + type: Literal['load_image'] = 'load_image' + + # Inputs + image_type: ImageType = Field(description="The type of the image") + image_name: str = Field(description="The name of the image") + + def invoke(self, context: InvocationContext) -> ImageOutput: + return ImageOutput( + image = ImageField(image_type = self.image_type, image_name = self.image_name) + ) + + +class ShowImageInvocation(BaseInvocation): + """Displays a provided image, and passes it forward in the pipeline.""" + type: Literal['show_image'] = 'show_image' + + # Inputs + image: ImageField = Field(default=None, description="The image to show") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + if image: + image.show() + + # TODO: how to handle failure? + + return ImageOutput( + image = ImageField(image_type = self.image.image_type, image_name = self.image.image_name) + ) + + +class CropImageInvocation(BaseInvocation): + """Crops an image to a specified box. 
The box can be outside of the image."""
+    type: Literal['crop'] = 'crop'
+
+    # Inputs
+    image: ImageField = Field(default=None, description="The image to crop")
+    x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
+    y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
+    width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
+    height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.services.images.get(self.image.image_type, self.image.image_name)
+
+        image_crop = Image.new(mode = 'RGBA', size = (self.width, self.height), color = (0, 0, 0, 0))
+        image_crop.paste(image, (-self.x, -self.y))
+
+        image_type = ImageType.INTERMEDIATE
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, image_crop)
+        return ImageOutput(
+            image = ImageField(image_type = image_type, image_name = image_name)
+        )
+
+
+class PasteImageInvocation(BaseInvocation):
+    """Pastes an image into another image."""
+    type: Literal['paste'] = 'paste'
+
+    # Inputs
+    base_image: ImageField = Field(default=None, description="The base image")
+    image: ImageField = Field(default=None, description="The image to paste")
+    mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
+    x: int = Field(default=0, description="The left x coordinate at which to paste the image")
+    y: int = Field(default=0, description="The top y coordinate at which to paste the image")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        base_image = context.services.images.get(self.base_image.image_type, self.base_image.image_name)
+        image = context.services.images.get(self.image.image_type, self.image.image_name)
+        mask = None if self.mask is None else ImageOps.invert(context.services.images.get(self.mask.image_type, self.mask.image_name))
+        # TODO: probably shouldn't invert mask here... should user be required to do it?
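+        # Geometry note: the output canvas spans the union of the base image and
+        # the pasted image's extents, so pasting at negative coordinates grows the
+        # canvas rather than clipping the pasted image.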
+
+        min_x = min(0, self.x)
+        min_y = min(0, self.y)
+        max_x = max(base_image.width, image.width + self.x)
+        max_y = max(base_image.height, image.height + self.y)
+
+        new_image = Image.new(mode = 'RGBA', size = (max_x - min_x, max_y - min_y), color = (0, 0, 0, 0))
+        new_image.paste(base_image, (abs(min_x), abs(min_y)))
+        new_image.paste(image, (max(0, self.x), max(0, self.y)), mask = mask)
+
+        image_type = ImageType.RESULT
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, new_image)
+        return ImageOutput(
+            image = ImageField(image_type = image_type, image_name = image_name)
+        )
+
+
+class MaskFromAlphaInvocation(BaseInvocation):
+    """Extracts the alpha channel of an image as a mask."""
+    type: Literal['tomask'] = 'tomask'
+
+    # Inputs
+    image: ImageField = Field(default=None, description="The image to create the mask from")
+    invert: bool = Field(default=False, description="Whether or not to invert the mask")
+
+    def invoke(self, context: InvocationContext) -> MaskOutput:
+        image = context.services.images.get(self.image.image_type, self.image.image_name)
+
+        image_mask = image.split()[-1]
+        if self.invert:
+            image_mask = ImageOps.invert(image_mask)
+
+        image_type = ImageType.INTERMEDIATE
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, image_mask)
+        return MaskOutput(
+            mask = ImageField(image_type = image_type, image_name = image_name)
+        )
+
+
+class BlurInvocation(BaseInvocation):
+    """Blurs an image"""
+    type: Literal['blur'] = 'blur'
+
+    # Inputs
+    image: ImageField = Field(default=None, description="The image to blur")
+    radius: float = Field(default=8.0, ge=0, description="The blur radius")
+    blur_type: Literal['gaussian', 'box'] = Field(default='gaussian', description="The type of blur")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.services.images.get(self.image.image_type, self.image.image_name)
+
+        blur = ImageFilter.GaussianBlur(self.radius) if self.blur_type == 'gaussian' else ImageFilter.BoxBlur(self.radius)
+        blur_image = image.filter(blur)
+
+        image_type = ImageType.INTERMEDIATE
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, blur_image)
+        return ImageOutput(
+            image = ImageField(image_type = image_type, image_name = image_name)
+        )
+
+
+class LerpInvocation(BaseInvocation):
+    """Linear interpolation of all pixels of an image"""
+    type: Literal['lerp'] = 'lerp'
+
+    # Inputs
+    image: ImageField = Field(default=None, description="The image to lerp")
+    min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
+    max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.services.images.get(self.image.image_type, self.image.image_name)
+
+        image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
+        image_arr = image_arr * (self.max - self.min) + self.min  # map [0, 1] onto [min, max]
+
+        lerp_image = Image.fromarray(numpy.uint8(image_arr))
+
+        image_type = ImageType.INTERMEDIATE
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, lerp_image)
+        return ImageOutput(
+            image = ImageField(image_type = image_type, image_name = image_name)
+        )
+
+
+class
InverseLerpInvocation(BaseInvocation): + """Inverse linear interpolation of all pixels of an image""" + type: Literal['ilerp'] = 'ilerp' + + # Inputs + image: ImageField = Field(default=None, description="The image to lerp") + min: int = Field(default=0, ge=0, le=255, description="The minimum input value") + max: int = Field(default=255, ge=0, le=255, description="The maximum input value") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + + image_arr = numpy.asarray(image, dtype=numpy.float32) + image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 + + ilerp_image = Image.fromarray(numpy.uint8(image_arr)) + + image_type = ImageType.INTERMEDIATE + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, ilerp_image) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) diff --git a/ldm/invoke/app/invocations/prompt.py b/ldm/invoke/app/invocations/prompt.py new file mode 100644 index 0000000000..029cad9660 --- /dev/null +++ b/ldm/invoke/app/invocations/prompt.py @@ -0,0 +1,9 @@ +from typing import Literal +from pydantic.fields import Field +from .baseinvocation import BaseInvocationOutput + +class PromptOutput(BaseInvocationOutput): + """Base class for invocations that output a prompt""" + type: Literal['prompt'] = 'prompt' + + prompt: str = Field(default=None, description="The output prompt") diff --git a/ldm/invoke/app/invocations/reconstruct.py b/ldm/invoke/app/invocations/reconstruct.py new file mode 100644 index 0000000000..98201ce837 --- /dev/null +++ b/ldm/invoke/app/invocations/reconstruct.py @@ -0,0 +1,36 @@ +from datetime import datetime, timezone +from typing import Literal, Union +from pydantic import Field +from .image import ImageField, ImageOutput +from .baseinvocation import BaseInvocation, InvocationContext +from ..services.image_storage import ImageType +from ..services.invocation_services import InvocationServices + + +class RestoreFaceInvocation(BaseInvocation): + """Restores faces in an image.""" + type: Literal['restore_face'] = 'restore_face' + + # Inputs + image: Union[ImageField,None] = Field(description="The input image") + strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration") + + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + results = context.services.generate.upscale_and_reconstruct( + image_list = [[image, 0]], + upscale = None, + strength = self.strength, # GFPGAN strength + save_original = False, + image_callback = None, + ) + + # Results are image and seed, unwrap for now + # TODO: can this return multiple results? 
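+        # upscale_and_reconstruct is invoked with upscale=None, so only the face
+        # restoration pass runs here; self.strength is forwarded as the GFPGAN strength.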
+        image_type = ImageType.RESULT
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, results[0][0])
+        return ImageOutput(
+            image = ImageField(image_type = image_type, image_name = image_name)
+        )
diff --git a/ldm/invoke/app/invocations/upscale.py b/ldm/invoke/app/invocations/upscale.py
new file mode 100644
index 0000000000..1df8c44ea8
--- /dev/null
+++ b/ldm/invoke/app/invocations/upscale.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+
+from datetime import datetime, timezone
+from typing import Literal, Union
+from pydantic import Field
+from .image import ImageField, ImageOutput
+from .baseinvocation import BaseInvocation, InvocationContext
+from ..services.image_storage import ImageType
+from ..services.invocation_services import InvocationServices
+
+
+class UpscaleInvocation(BaseInvocation):
+    """Upscales an image."""
+    type: Literal['upscale'] = 'upscale'
+
+    # Inputs
+    image: Union[ImageField, None] = Field(description="The input image", default=None)
+    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
+    level: Literal[2, 4] = Field(default=2, description="The upscale level")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.services.images.get(self.image.image_type, self.image.image_name)
+        results = context.services.generate.upscale_and_reconstruct(
+            image_list = [[image, 0]],
+            upscale = (self.level, self.strength),
+            strength = 0.0,  # GFPGAN strength
+            save_original = False,
+            image_callback = None,
+        )
+
+        # Results are image and seed, unwrap for now
+        # TODO: can this return multiple results?
+        image_type = ImageType.RESULT
+        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        context.services.images.save(image_type, image_name, results[0][0])
+        return ImageOutput(
+            image = ImageField(image_type = image_type, image_name = image_name)
+        )
diff --git a/ldm/invoke/app/services/__init__.py b/ldm/invoke/app/services/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/ldm/invoke/app/services/events.py b/ldm/invoke/app/services/events.py
new file mode 100644
index 0000000000..7b850b61ac
--- /dev/null
+++ b/ldm/invoke/app/services/events.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+
+from typing import Any, Dict
+
+
+class EventServiceBase:
+    """Basic event bus, to have an empty stand-in when not needed"""
+
+    session_event: str = 'session_event'
+
+    def dispatch(self, event_name: str, payload: Any) -> None:
+        pass
+
+    def __emit_session_event(self,
+        event_name: str,
+        payload: Dict) -> None:
+        self.dispatch(
+            event_name = EventServiceBase.session_event,
+            payload = dict(
+                event = event_name,
+                data = payload
+            )
+        )
+
+    # Define events here for every event in the system.
+    # This will make them easier to integrate until we find a schema generator.
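+    # For illustration (hypothetical payload values, not emitted by this base
+    # class itself): a subscriber to EventServiceBase.session_event would receive
+    #   {'event': 'generator_progress',
+    #    'data': {'graph_execution_state_id': '...', 'invocation_id': '1',
+    #             'step': 4, 'percent': 0.4}}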
+ def emit_generator_progress(self, + graph_execution_state_id: str, + invocation_id: str, + step: int, + percent: float + ) -> None: + """Emitted when there is generation progress""" + self.__emit_session_event( + event_name = 'generator_progress', + payload = dict( + graph_execution_state_id = graph_execution_state_id, + invocation_id = invocation_id, + step = step, + percent = percent + ) + ) + + def emit_invocation_complete(self, + graph_execution_state_id: str, + invocation_id: str, + result: Dict + ) -> None: + """Emitted when an invocation has completed""" + self.__emit_session_event( + event_name = 'invocation_complete', + payload = dict( + graph_execution_state_id = graph_execution_state_id, + invocation_id = invocation_id, + result = result + ) + ) + + def emit_invocation_started(self, + graph_execution_state_id: str, + invocation_id: str + ) -> None: + """Emitted when an invocation has started""" + self.__emit_session_event( + event_name = 'invocation_started', + payload = dict( + graph_execution_state_id = graph_execution_state_id, + invocation_id = invocation_id + ) + ) + + def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None: + """Emitted when a session has completed all invocations""" + self.__emit_session_event( + event_name = 'graph_execution_state_complete', + payload = dict( + graph_execution_state_id = graph_execution_state_id + ) + ) diff --git a/ldm/invoke/app/services/generate_initializer.py b/ldm/invoke/app/services/generate_initializer.py new file mode 100644 index 0000000000..39c0fe491e --- /dev/null +++ b/ldm/invoke/app/services/generate_initializer.py @@ -0,0 +1,233 @@ +from argparse import Namespace +import os +import sys +import traceback + +from ...model_manager import ModelManager + +from ...globals import Globals +from ....generate import Generate +import ldm.invoke + + +# TODO: most of this code should be split into individual services as the Generate.py code is deprecated +def get_generate(args, config) -> Generate: + if not args.conf: + config_file = os.path.join(Globals.root,'configs','models.yaml') + if not os.path.exists(config_file): + report_model_error(args, FileNotFoundError(f"The file {config_file} could not be found.")) + + print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}') + print(f'>> InvokeAI runtime directory is "{Globals.root}"') + + # these two lines prevent a horrible warning message from appearing + # when the frozen CLIP tokenizer is imported + import transformers # type: ignore + transformers.logging.set_verbosity_error() + import diffusers + diffusers.logging.set_verbosity_error() + + # Loading Face Restoration and ESRGAN Modules + gfpgan,codeformer,esrgan = load_face_restoration(args) + + # normalize the config directory relative to root + if not os.path.isabs(args.conf): + args.conf = os.path.normpath(os.path.join(Globals.root,args.conf)) + + if args.embeddings: + if not os.path.isabs(args.embedding_path): + embedding_path = os.path.normpath(os.path.join(Globals.root,args.embedding_path)) + else: + embedding_path = args.embedding_path + else: + embedding_path = None + + # migrate legacy models + ModelManager.migrate_models() + + # load the infile as a list of lines + if args.infile: + try: + if os.path.isfile(args.infile): + infile = open(args.infile, 'r', encoding='utf-8') + elif args.infile == '-': # stdin + infile = sys.stdin + else: + raise FileNotFoundError(f'{args.infile} not found.') + except (FileNotFoundError, IOError) as e: + print(f'{e}. 
Aborting.')
+            sys.exit(-1)
+
+    # creating a Generate object:
+    try:
+        gen = Generate(
+            conf = args.conf,
+            model = args.model,
+            sampler_name = args.sampler_name,
+            embedding_path = embedding_path,
+            full_precision = args.full_precision,
+            precision = args.precision,
+            gfpgan = gfpgan,
+            codeformer = codeformer,
+            esrgan = esrgan,
+            free_gpu_mem = args.free_gpu_mem,
+            safety_checker = args.safety_checker,
+            max_loaded_models = args.max_loaded_models,
+        )
+    except (FileNotFoundError, TypeError, AssertionError) as e:
+        report_model_error(args, e)
+    except (IOError, KeyError) as e:
+        print(f'{e}. Aborting.')
+        sys.exit(-1)
+
+    if args.seamless:
+        print(">> changed to seamless tiling mode")
+
+    # preload the model
+    try:
+        gen.load_model()
+    except KeyError:
+        pass
+    except Exception as e:
+        report_model_error(args, e)
+
+    # try to autoconvert new models
+    # autoimport new .ckpt files
+    if path := args.autoconvert:
+        gen.model_manager.autoconvert_weights(
+            conf_path=args.conf,
+            weights_directory=path,
+        )
+
+    return gen
+
+
+def load_face_restoration(opt):
+    try:
+        gfpgan, codeformer, esrgan = None, None, None
+        if opt.restore or opt.esrgan:
+            from ldm.invoke.restoration import Restoration
+            restoration = Restoration()
+            if opt.restore:
+                gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path)
+            else:
+                print('>> Face restoration disabled')
+            if opt.esrgan:
+                esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
+            else:
+                print('>> Upscaling disabled')
+        else:
+            print('>> Face restoration and upscaling disabled')
+    except (ModuleNotFoundError, ImportError):
+        print(traceback.format_exc(), file=sys.stderr)
+        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
+    return gfpgan, codeformer, esrgan
+
+
+def report_model_error(opt: Namespace, e: Exception):
+    print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
+    print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.')
+    yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
+    if yes_to_all:
+        print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE')
+    else:
+        response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ')
+        if response.startswith(('n', 'N')):
+            return
+
+    print('invokeai-configure is launching....\n')
+
+    # Match arguments that were set on the CLI
+    # only the arguments accepted by the configuration script are parsed
+    root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
+    config = ["--config", opt.conf] if opt.conf is not None else []
+    previous_args = sys.argv
+    sys.argv = ['invokeai-configure']
+    sys.argv.extend(root_dir)
+    sys.argv.extend(config)
+    if yes_to_all is not None:
+        for arg in yes_to_all.split():
+            sys.argv.append(arg)
+
+    from ldm.invoke.config import invokeai_configure
+    invokeai_configure.main()
+    # TODO: Figure out how to restart
+    # print('** InvokeAI will now restart')
+    # sys.argv = previous_args
+    # main() # would rather do a os.exec(), but doesn't exist?
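+    # (os.execv does exist and could likely restart the process here, e.g.
+    #  os.execv(sys.executable, [sys.executable] + previous_args); untested.)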
+    # sys.exit(0)
+
+
+# Temporary initializer for Generate until we migrate off of it
+def old_get_generate(args, config) -> Generate:
+    # TODO: Remove the need for globals
+    from ldm.invoke.globals import Globals
+
+    # alert - setting globals here
+    Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
+    Globals.try_patchmatch = args.patchmatch
+
+    print(f'>> InvokeAI runtime directory is "{Globals.root}"')
+
+    # these two lines prevent a horrible warning message from appearing
+    # when the frozen CLIP tokenizer is imported
+    import transformers
+    transformers.logging.set_verbosity_error()
+
+    # Loading Face Restoration and ESRGAN Modules
+    gfpgan, codeformer, esrgan = None, None, None
+    try:
+        if config.restore or config.esrgan:
+            from ldm.invoke.restoration import Restoration
+            restoration = Restoration()
+            if config.restore:
+                gfpgan, codeformer = restoration.load_face_restore_models(config.gfpgan_model_path)
+            else:
+                print('>> Face restoration disabled')
+            if config.esrgan:
+                esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
+            else:
+                print('>> Upscaling disabled')
+        else:
+            print('>> Face restoration and upscaling disabled')
+    except (ModuleNotFoundError, ImportError):
+        print(traceback.format_exc(), file=sys.stderr)
+        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
+
+    # normalize the config directory relative to root
+    if not os.path.isabs(config.conf):
+        config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
+
+    if config.embeddings:
+        if not os.path.isabs(config.embedding_path):
+            embedding_path = os.path.normpath(os.path.join(Globals.root, config.embedding_path))
+        else:
+            embedding_path = config.embedding_path
+    else:
+        embedding_path = None
+
+    # TODO: lazy-initialize this by wrapping it
+    try:
+        generate = Generate(
+            conf = config.conf,
+            model = config.model,
+            sampler_name = config.sampler_name,
+            embedding_path = embedding_path,
+            full_precision = config.full_precision,
+            precision = config.precision,
+            gfpgan = gfpgan,
+            codeformer = codeformer,
+            esrgan = esrgan,
+            free_gpu_mem = config.free_gpu_mem,
+            safety_checker = config.safety_checker,
+            max_loaded_models = config.max_loaded_models,
+        )
+    except (FileNotFoundError, TypeError, AssertionError):
+        # emergency_model_reconfigure()  # TODO?
+        sys.exit(-1)
+    except (IOError, KeyError) as e:
+        print(f'{e}.
Aborting.')
+        sys.exit(-1)
+
+    generate.free_gpu_mem = config.free_gpu_mem
+
+    return generate
diff --git a/ldm/invoke/app/services/graph.py b/ldm/invoke/app/services/graph.py
new file mode 100644
index 0000000000..8d1583fc8b
--- /dev/null
+++ b/ldm/invoke/app/services/graph.py
@@ -0,0 +1,797 @@
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+
+import copy
+import itertools
+import uuid
+import networkx as nx
+from pydantic import BaseModel, validator
+from pydantic.fields import Field
+from typing import Any, Literal, Optional, Union, get_args, get_origin, get_type_hints, Annotated
+
+try:
+    from types import NoneType
+except ImportError:
+    # types.NoneType only exists on Python >= 3.10; fall back for 3.9
+    NoneType = type(None)
+
+from .invocation_services import InvocationServices
+from ..invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
+from ..invocations import *
+
+
+class EdgeConnection(BaseModel):
+    node_id: str = Field(description="The id of the node for this edge connection")
+    field: str = Field(description="The field for this connection")
+
+    def __eq__(self, other):
+        return (isinstance(other, self.__class__) and
+                getattr(other, 'node_id', None) == self.node_id and
+                getattr(other, 'field', None) == self.field)
+
+    def __hash__(self):
+        return hash(f'{self.node_id}.{self.field}')
+
+
+def get_output_field(node: BaseInvocation, field: str) -> Any:
+    node_type = type(node)
+    node_outputs = get_type_hints(node_type.get_output_type())
+    node_output_field = node_outputs.get(field) or None
+    return node_output_field
+
+
+def get_input_field(node: BaseInvocation, field: str) -> Any:
+    node_type = type(node)
+    node_inputs = get_type_hints(node_type)
+    node_input_field = node_inputs.get(field) or None
+    return node_input_field
+
+
+def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
+    if not from_type:
+        return False
+    if not to_type:
+        return False
+
+    # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
+    if from_type and to_type:
+        # Ports are compatible
+        if (from_type == to_type or
+            from_type == Any or
+            to_type == Any or
+            Any in get_args(from_type) or
+            Any in get_args(to_type)):
+            return True
+
+        if from_type in get_args(to_type):
+            return True
+
+        if to_type in get_args(from_type):
+            return True
+
+        if not issubclass(from_type, to_type):
+            return False
+    else:
+        return False
+
+    return True
+
+
+def are_connections_compatible(
+    from_node: BaseInvocation,
+    from_field: str,
+    to_node: BaseInvocation,
+    to_field: str) -> bool:
+    """Determines if a connection between fields of two nodes is compatible."""
+
+    # TODO: handle iterators and collectors
+    from_node_field = get_output_field(from_node, from_field)
+    to_node_field = get_input_field(to_node, to_field)
+
+    return are_connection_types_compatible(from_node_field, to_node_field)
+
+
+class NodeAlreadyInGraphError(Exception):
+    pass
+
+
+class InvalidEdgeError(Exception):
+    pass
+
+
+class NodeNotFoundError(Exception):
+    pass
+
+
+class NodeAlreadyExecutedError(Exception):
+    pass
+
+
+# TODO: Create and use an Empty output?
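+# Illustrative behavior of are_connection_types_compatible above (assuming an
+# ImageField output wired to various inputs):
+#   are_connection_types_compatible(ImageField, ImageField)              -> True
+#   are_connection_types_compatible(ImageField, Union[ImageField, None]) -> True  (member of the union)
+#   are_connection_types_compatible(str, int)                            -> False (not a subclass)
+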
+class GraphInvocationOutput(BaseInvocationOutput): + type: Literal['graph_output'] = 'graph_output' + + +# TODO: Fill this out and move to invocations +class GraphInvocation(BaseInvocation): + type: Literal['graph'] = 'graph' + + # TODO: figure out how to create a default here + graph: 'Graph' = Field(description="The graph to run", default=None) + + def invoke(self, context: InvocationContext) -> GraphInvocationOutput: + """Invoke with provided services and return outputs.""" + return GraphInvocationOutput() + + +class IterateInvocationOutput(BaseInvocationOutput): + """Used to connect iteration outputs. Will be expanded to a specific output.""" + type: Literal['iterate_output'] = 'iterate_output' + + item: Any = Field(description="The item being iterated over") + + +# TODO: Fill this out and move to invocations +class IterateInvocation(BaseInvocation): + type: Literal['iterate'] = 'iterate' + + collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list) + index: int = Field(description="The index, will be provided on executed iterators", default=0) + + def invoke(self, context: InvocationContext) -> IterateInvocationOutput: + """Produces the outputs as values""" + return IterateInvocationOutput(item = self.collection[self.index]) + + +class CollectInvocationOutput(BaseInvocationOutput): + type: Literal['collect_output'] = 'collect_output' + + collection: list[Any] = Field(description="The collection of input items") + + +class CollectInvocation(BaseInvocation): + """Collects values into a collection""" + type: Literal['collect'] = 'collect' + + item: Any = Field(description="The item to collect (all inputs must be of the same type)", default=None) + collection: list[Any] = Field(description="The collection, will be provided on execution", default_factory=list) + + def invoke(self, context: InvocationContext) -> CollectInvocationOutput: + """Invoke with provided services and return outputs.""" + return CollectInvocationOutput(collection = copy.copy(self.collection)) + + +InvocationsUnion = Union[BaseInvocation.get_invocations()] +InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] + + +class Graph(BaseModel): + id: str = Field(description="The id of this graph", default_factory=uuid.uuid4) + # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me + nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(description="The nodes in this graph", default_factory=dict) + edges: list[tuple[EdgeConnection,EdgeConnection]] = Field(description="The connections between nodes and their fields in this graph", default_factory=list) + + def add_node(self, node: BaseInvocation) -> None: + """Adds a node to a graph + + :raises NodeAlreadyInGraphError: the node is already present in the graph. + """ + + if node.id in self.nodes: + raise NodeAlreadyInGraphError() + + self.nodes[node.id] = node + + + def _get_graph_and_node(self, node_path: str) -> tuple['Graph', str]: + """Returns the graph and node id for a node path.""" + # Materialized graphs may have nodes at the top level + if node_path in self.nodes: + return (self, node_path) + + node_id = node_path if '.' 
not in node_path else node_path[:node_path.index('.')]
+        if node_id not in self.nodes:
+            raise NodeNotFoundError(f'Node {node_path} not found in graph')
+
+        node = self.nodes[node_id]
+
+        if not isinstance(node, GraphInvocation):
+            # There's more node path left but this isn't a graph - failure
+            raise NodeNotFoundError('Node path terminated early at a non-graph node')
+
+        return node.graph._get_graph_and_node(node_path[node_path.index('.')+1:])
+
+
+    def delete_node(self, node_path: str) -> None:
+        """Deletes a node from a graph"""
+
+        try:
+            graph, node_id = self._get_graph_and_node(node_path)
+
+            # Delete edges for this node
+            input_edges = self._get_input_edges_and_graphs(node_path)
+            output_edges = self._get_output_edges_and_graphs(node_path)
+
+            for edge_graph, _, edge in input_edges:
+                edge_graph.delete_edge(edge)
+
+            for edge_graph, _, edge in output_edges:
+                edge_graph.delete_edge(edge)
+
+            del graph.nodes[node_id]
+
+        except NodeNotFoundError:
+            pass  # Ignore, node doesn't exist (should this throw?)
+
+
+    def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
+        """Adds an edge to a graph
+
+        :raises InvalidEdgeError: the provided edge is invalid.
+        """
+
+        if self._is_edge_valid(edge) and edge not in self.edges:
+            self.edges.append(edge)
+        else:
+            raise InvalidEdgeError()
+
+
+    def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
+        """Deletes an edge from a graph"""
+
+        try:
+            self.edges.remove(edge)
+        except ValueError:
+            pass  # list.remove raises ValueError if the edge is not present
+
+
+    def is_valid(self) -> bool:
+        """Validates the graph."""
+
+        # Validate all subgraphs
+        for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
+            if not gn.graph.is_valid():
+                return False
+
+        # Validate all edges reference nodes in the graph
+        node_ids = set([e[0].node_id for e in self.edges]+[e[1].node_id for e in self.edges])
+        if not all((self.has_node(node_id) for node_id in node_ids)):
+            return False
+
+        # Validate there are no cycles
+        g = self.nx_graph_flat()
+        if not nx.is_directed_acyclic_graph(g):
+            return False
+
+        # Validate all edge connections are valid
+        if not all((are_connections_compatible(
+            self.get_node(e[0].node_id), e[0].field,
+            self.get_node(e[1].node_id), e[1].field
+        ) for e in self.edges)):
+            return False
+
+        # Validate all iterators
+        # TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
+        if not all((self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))):
+            return False
+
+        # Validate all collectors
+        # TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
+        if not all((self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))):
+            return False
+
+        return True
+
+    def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
+        """Validates that a new edge doesn't create a cycle in the graph"""
+
+        # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
+        try:
+            from_node = self.get_node(edge[0].node_id)
+            to_node = self.get_node(edge[1].node_id)
+        except NodeNotFoundError:
+            return False
+
+        # Validate that an edge to this node+field doesn't already exist
+        input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
+        if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
+            return False
+
+        # Validate that no cycles would be created
+        g = self.nx_graph_flat()
+
g.add_edge(edge[0].node_id, edge[1].node_id)
+        if not nx.is_directed_acyclic_graph(g):
+            return False
+
+        # Validate that the field types are compatible
+        if not are_connections_compatible(from_node, edge[0].field, to_node, edge[1].field):
+            return False
+
+        # Validate if iterator output type matches iterator input type (if this edge results in both being set)
+        if isinstance(to_node, IterateInvocation) and edge[1].field == 'collection':
+            if not self._is_iterator_connection_valid(edge[1].node_id, new_input = edge[0]):
+                return False
+
+        # Validate if iterator input type matches output type (if this edge results in both being set)
+        if isinstance(from_node, IterateInvocation) and edge[0].field == 'item':
+            if not self._is_iterator_connection_valid(edge[0].node_id, new_output = edge[1]):
+                return False
+
+        # Validate if collector input type matches output type (if this edge results in both being set)
+        if isinstance(to_node, CollectInvocation) and edge[1].field == 'item':
+            if not self._is_collector_connection_valid(edge[1].node_id, new_input = edge[0]):
+                return False
+
+        # Validate if collector output type matches input type (if this edge results in both being set)
+        if isinstance(from_node, CollectInvocation) and edge[0].field == 'collection':
+            if not self._is_collector_connection_valid(edge[0].node_id, new_output = edge[1]):
+                return False
+
+        return True
+
+    def has_node(self, node_path: str) -> bool:
+        """Determines whether or not a node exists in the graph."""
+        try:
+            n = self.get_node(node_path)
+            if n is not None:
+                return True
+            else:
+                return False
+        except NodeNotFoundError:
+            return False
+
+    def get_node(self, node_path: str) -> InvocationsUnion:
+        """Gets a node from the graph using a node path."""
+        # Materialized graphs may have nodes at the top level
+        graph, node_id = self._get_graph_and_node(node_path)
+        return graph.nodes[node_id]
+
+
+    def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
+        return node_id if prefix is None or prefix == '' else f'{prefix}.{node_id}'
+
+
+    def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
+        """Updates a node in the graph."""
+        graph, node_id = self._get_graph_and_node(node_path)
+        node = graph.nodes[node_id]
+
+        # Ensure the node type matches the new node
+        if type(node) != type(new_node):
+            raise TypeError(f'Node {node_path} is type {type(node)} but new node is type {type(new_node)}')
+
+        # Ensure the new id is either the same or is not in the graph
+        prefix = None if '.' not in node_path else node_path[:node_path.rindex('.')]
+        new_path = self._get_node_path(new_node.id, prefix = prefix)
+        if new_node.id != node.id and self.has_node(new_path):
+            raise NodeAlreadyInGraphError(f'Node with id {new_node.id} already exists in graph')
+
+        # Set the new node in the graph
+        graph.nodes[new_node.id] = new_node
+        if new_node.id != node.id:
+            input_edges = self._get_input_edges_and_graphs(node_path)
+            output_edges = self._get_output_edges_and_graphs(node_path)
+
+            # Delete node and all edges
+            graph.delete_node(node_path)
+
+            # Create new edges for each input and output
+            for graph, _, edge in input_edges:
+                # Remove the graph prefix from the node path
+                new_graph_node_path = new_node.id if '.'
not in edge[1].node_id else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
+                graph.add_edge((edge[0], EdgeConnection(node_id = new_graph_node_path, field = edge[1].field)))
+
+            for graph, _, edge in output_edges:
+                # Remove the graph prefix from the node path
+                new_graph_node_path = new_node.id if '.' not in edge[0].node_id else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
+                graph.add_edge((EdgeConnection(node_id = new_graph_node_path, field = edge[0].field), edge[1]))
+
+
+    def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[tuple[EdgeConnection, EdgeConnection]]:
+        """Gets all input edges for a node"""
+        edges = self._get_input_edges_and_graphs(node_path)
+
+        # Filter to edges that match the field
+        filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
+
+        # Create full node paths for each edge
+        return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _, prefix, e in filtered_edges]
+
+
+    def _get_input_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection, EdgeConnection]]]:
+        """Gets all input edges for a node along with the graph they are in and the graph's path"""
+        edges = list()
+
+        # Return any input edges that appear in this graph
+        edges.extend([(self, prefix, e) for e in self.edges if e[1].node_id == node_path])
+
+        node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
+        node = self.nodes[node_id]
+
+        if isinstance(node, GraphInvocation):
+            graph = node.graph
+            graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix)
+            graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path)
+            edges.extend(graph_edges)
+
+        return edges
+
+
+    def _get_output_edges(self, node_path: str, field: str) -> list[tuple[EdgeConnection, EdgeConnection]]:
+        """Gets all output edges for a node"""
+        edges = self._get_output_edges_and_graphs(node_path)
+
+        # Filter to edges that match the field
+        filtered_edges = (e for e in edges if e[2][0].field == field)
+
+        # Create full node paths for each edge
+        return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _, prefix, e in filtered_edges]
+
+
+    def _get_output_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection, EdgeConnection]]]:
+        """Gets all output edges for a node along with the graph they are in and the graph's path"""
+        edges = list()
+
+        # Return any output edges that appear in this graph
+        edges.extend([(self, prefix, e) for e in self.edges if e[0].node_id == node_path])
+
+        node_id = node_path if '.'
not in node_path else node_path[:node_path.index('.')] + node = self.nodes[node_id] + + if isinstance(node, GraphInvocation): + graph = node.graph + graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix) + graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path) + edges.extend(graph_edges) + + return edges + + + def _is_iterator_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool: + inputs = list([e[0] for e in self._get_input_edges(node_path, 'collection')]) + outputs = list([e[1] for e in self._get_output_edges(node_path, 'item')]) + + if new_input is not None: + inputs.append(new_input) + if new_output is not None: + outputs.append(new_output) + + # Only one input is allowed for iterators + if len(inputs) > 1: + return False + + # Get input and output fields (the fields linked to the iterator's input/output) + input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field) + output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) + + # Input type must be a list + if get_origin(input_field) != list: + return False + + # Validate that all outputs match the input type + input_field_item_type = get_args(input_field)[0] + if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)): + return False + + return True + + def _is_collector_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool: + inputs = list([e[0] for e in self._get_input_edges(node_path, 'item')]) + outputs = list([e[1] for e in self._get_output_edges(node_path, 'collection')]) + + if new_input is not None: + inputs.append(new_input) + if new_output is not None: + outputs.append(new_output) + + # Get input and output fields (the fields linked to the iterator's input/output) + input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs]) + output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) + + # Validate that all inputs are derived from or match a single type + input_field_types = set([t for input_field in input_fields for t in ([input_field] if get_origin(input_field) == None else get_args(input_field)) if t != NoneType]) # Get unique types + type_tree = nx.DiGraph() + type_tree.add_nodes_from(input_field_types) + type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])]) + type_degrees = type_tree.in_degree(type_tree.nodes) + if sum((t[1] == 0 for t in type_degrees)) != 1: + return False # There is more than one root type + + # Get the input root type + input_root_type = next(t[0] for t in type_degrees if t[1] == 0) + + # Verify that all outputs are lists + if not all((get_origin(f) == list for f in output_fields)): + return False + + # Verify that all outputs match the input type (are a base class or the same class) + if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)): + return False + + return True + + def nx_graph(self) -> nx.DiGraph: + """Returns a NetworkX DiGraph representing the layout of this graph""" + # TODO: Cache this? 
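+        # Note: the edge list is passed through set() below, so parallel edges
+        # between the same pair of nodes collapse to a single DiGraph edge.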
+ g = nx.DiGraph() + g.add_nodes_from([n for n in self.nodes.keys()]) + g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges])) + return g + + def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph: + """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)""" + g = nx_graph or nx.DiGraph() + + # Add all nodes from this graph except graph/iteration nodes + g.add_nodes_from([self._get_node_path(n.id, prefix) for n in self.nodes.values() if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)]) + + # Expand graph nodes + for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)): + sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) + + # TODO: figure out if iteration nodes need to be expanded + + unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges]) + g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]) + return g + + +class GraphExecutionState(BaseModel): + """Tracks the state of a graph execution""" + id: str = Field(description="The id of the execution state", default_factory=uuid.uuid4) + + # TODO: Store a reference to the graph instead of the actual graph? + graph: Graph = Field(description="The graph being executed") + + # The graph of materialized nodes + execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes", default_factory=Graph) + + # Nodes that have been executed + executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set) + executed_history: list[str] = Field(description="The list of node ids that have been executed, in order of execution", default_factory=list) + + # The results of executed nodes + results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict) + + # Map of prepared/executed nodes to their original nodes + prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict) + + # Map of original nodes to prepared nodes + source_prepared_mapping: dict[str, set[str]] = Field(description="The map of original graph nodes to prepared nodes", default_factory=dict) + + def next(self) -> BaseInvocation | None: + """Gets the next node ready to execute.""" + + # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes + # possibly with a timeout? + + # If there are no prepared nodes, prepare some nodes + next_node = self._get_next_node() + if next_node is None: + prepared_id = self._prepare() + + # TODO: prepare multiple nodes at once? + # while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation): + # prepared_id = self._prepare() + + if prepared_id is not None: + next_node = self._get_next_node() + + # Get values from edges + if next_node is not None: + self._prepare_inputs(next_node) + + # If next is still none, there's no next node, return None + return next_node + + def complete(self, node_id: str, output: InvocationOutputsUnion): + """Marks a node as complete""" + + if node_id not in self.execution_graph.nodes: + return # TODO: log error? 
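+        # Completion is tracked at two levels: the prepared (execution-graph)
+        # node is marked below, and the original source node is marked once
+        # every prepared copy of it has finished.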
+ + # Mark node as executed + self.executed.add(node_id) + self.results[node_id] = output + + # Check if source node is complete (all prepared nodes are complete) + source_node = self.prepared_source_mapping[node_id] + prepared_nodes = self.source_prepared_mapping[source_node] + + if all([n in self.executed for n in prepared_nodes]): + self.executed.add(source_node) + self.executed_history.append(source_node) + + def is_complete(self) -> bool: + """Returns true if the graph is complete""" + return all((k in self.executed for k in self.graph.nodes)) + + def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: + """Prepares an iteration node and connects all edges, returning the new node id""" + + node = self.graph.get_node(node_path) + + self_iteration_count = -1 + + # If this is an iterator node, we must create a copy for each iteration + if isinstance(node, IterateInvocation): + # Get input collection edge (should error if there are no inputs) + input_collection_edge = next(iter(self.graph._get_input_edges(node_path, 'collection'))) + input_collection_prepared_node_id = next(n[1] for n in iteration_node_map if n[0] == input_collection_edge[0].node_id) + input_collection_prepared_node_output = self.results[input_collection_prepared_node_id] + input_collection = getattr(input_collection_prepared_node_output, input_collection_edge[0].field) + self_iteration_count = len(input_collection) + + new_nodes = list() + if self_iteration_count == 0: + # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid. + return new_nodes + + # Get all input edges + input_edges = self.graph._get_input_edges(node_path) + + # Create new edges for this iteration + # For collect nodes, this may contain multiple inputs to the same field + new_edges = list() + for edge in input_edges: + for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge[0].node_id): + new_edge = (EdgeConnection(node_id = input_node_id, field = edge[0].field), EdgeConnection(node_id = '', field = edge[1].field)) + new_edges.append(new_edge) + + # Create a new node (or one for each iteration of this iterator) + for i in (range(self_iteration_count) if self_iteration_count > 0 else [-1]): + # Create a new node + new_node = copy.deepcopy(node) + + # Create the node id (use a random uuid) + new_node.id = str(uuid.uuid4()) + + # Set the iteration index for iteration invocations + if isinstance(new_node, IterateInvocation): + new_node.index = i + + # Add to execution graph + self.execution_graph.add_node(new_node) + self.prepared_source_mapping[new_node.id] = node_path + if node_path not in self.source_prepared_mapping: + self.source_prepared_mapping[node_path] = set() + self.source_prepared_mapping[node_path].add(new_node.id) + + # Add new edges to execution graph + for edge in new_edges: + new_edge = (edge[0], EdgeConnection(node_id = new_node.id, field = edge[1].field)) + self.execution_graph.add_edge(new_edge) + + new_nodes.append(new_node.id) + + return new_nodes + + def _iterator_graph(self) -> nx.DiGraph: + """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" + g = self.graph.nx_graph() + collectors = (n for n in self.graph.nodes if isinstance(self.graph.nodes[n], CollectInvocation)) + for c in collectors: + g.remove_edges_from(list(g.in_edges(c))) + return g + + + def _get_node_iterators(self, node_id: str) -> list[str]: + """Gets iterators for a node""" + g = 
self._iterator_graph()
+        iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.nodes[n], IterateInvocation)]
+        return iterators
+
+
+    def _prepare(self) -> Optional[str]:
+        # Get flattened source graph
+        g = self.graph.nx_graph_flat()
+
+        # Find next unprepared node where all source nodes are executed
+        sorted_nodes = nx.topological_sort(g)
+        next_node_id = next((n for n in sorted_nodes if n not in self.source_prepared_mapping and all((e[0] in self.executed for e in g.in_edges(n)))), None)
+
+        if next_node_id is None:
+            return None
+
+        # Get all parents of the next node
+        next_node_parents = [e[0] for e in g.in_edges(next_node_id)]
+
+        # Create execution nodes
+        next_node = self.graph.get_node(next_node_id)
+        new_node_ids = list()
+        if isinstance(next_node, CollectInvocation):
+            # Collapse all iterator input mappings and create a single execution node for the collect invocation
+            all_iteration_mappings = list(itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents)))
+            create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
+            if create_results is not None:
+                new_node_ids.extend(create_results)
+        else: # Iterators or normal nodes
+            # Get all iterator combinations for this node
+            # Will produce a list of lists of prepared iterator nodes, from which results can be iterated
+            iterator_nodes = self._get_node_iterators(next_node_id)
+            iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
+            iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
+
+            # Select the correct prepared parents for each iteration
+            # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
+            # TODO: Handle a node mapping to none
+            eg = self.execution_graph.nx_graph_flat()
+            prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations]
+
+            # Create execution node for each iteration
+            for iteration_mappings in prepared_parent_mappings:
+                create_results = self._create_execution_node(next_node_id, iteration_mappings)
+                if create_results is not None:
+                    new_node_ids.extend(create_results)
+
+        return next(iter(new_node_ids), None)
+
+    def _get_iteration_node(self, source_node_path: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str]) -> Optional[str]:
+        """Gets the prepared version of the specified source node that matches every iteration specified"""
+        prepared_nodes = self.source_prepared_mapping[source_node_path]
+        if len(prepared_nodes) == 1:
+            return next(iter(prepared_nodes))
+
+        # Check if the requested node is an iterator
+        prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
+        if prepared_iterator is not None:
+            return prepared_iterator
+
+        # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
+        iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
+        parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
+
+        # The prepared node must descend from the matching prepared iteration of every parent iterator
+        return next((n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)), None)
+
+    def _get_next_node(self) -> Optional[BaseInvocation]:
+        g = self.execution_graph.nx_graph()
+        sorted_nodes = nx.topological_sort(g)
+        next_node = next((n for n in sorted_nodes if n not in self.executed), None)
+        if next_node is None:
+            return None
+
+        return self.execution_graph.nodes[next_node]
+
+    def _prepare_inputs(self, node: BaseInvocation):
+        input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
+        if isinstance(node, CollectInvocation):
+            output_collection = [getattr(self.results[edge[0].node_id], edge[0].field) for edge in input_edges if edge[1].field == 'item']
+            setattr(node, 'collection', output_collection)
+        else:
+            for edge in input_edges:
+                output_value = getattr(self.results[edge[0].node_id], edge[0].field)
+                setattr(node, edge[1].field, output_value)
+
+    # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
+    def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
+        # Validate the edge against the underlying graph first
+        if not self.graph._is_edge_valid(edge):
+            return False
+
+        # Invalid if destination has already been prepared or executed
+        if edge[1].node_id in self.source_prepared_mapping:
+            return False
+
+        # Otherwise, the edge is valid
+        return True
+
+    def _is_node_updatable(self, node_id: str) -> bool:
+        # The node is updatable as long as it hasn't been prepared or executed
+        return node_id not in self.source_prepared_mapping
+
+    def add_node(self, node: BaseInvocation) -> None:
+        self.graph.add_node(node)
+
+    def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
+        if not self._is_node_updatable(node_path):
+            raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be updated')
+        self.graph.update_node(node_path, new_node)
+
+    def delete_node(self, node_path: str) -> None:
+        if not self._is_node_updatable(node_path):
+            raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be deleted')
+        self.graph.delete_node(node_path)
+
+    def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
+        if not self._is_node_updatable(edge[1].node_id):
+            raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to')
+        self.graph.add_edge(edge)
+
+    def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
+        if not self._is_node_updatable(edge[1].node_id):
+            raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted')
+        self.graph.delete_edge(edge)
+
+GraphInvocation.update_forward_refs()
diff --git a/ldm/invoke/app/services/image_storage.py b/ldm/invoke/app/services/image_storage.py
new file mode 100644
index 0000000000..03227d870b
--- /dev/null
+++ b/ldm/invoke/app/services/image_storage.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+
+from abc import ABC, abstractmethod
+from enum import Enum
+import datetime
+import os
+from pathlib import Path
+from queue import Queue
+from typing import Dict
+from PIL.Image import Image
+from ...pngwriter import PngWriter
+
+
+class ImageType(str, Enum):
+    RESULT = 'results'
+    INTERMEDIATE = 'intermediates'
+    UPLOAD = 'uploads'
+
+
+class ImageStorageBase(ABC):
+    """Responsible for storing and retrieving images."""
+
+    @abstractmethod
+    def get(self, image_type: ImageType, image_name: str) -> Image:
+        pass
+
+    # TODO: make this a bit more flexible for e.g.
cloud storage + @abstractmethod + def get_path(self, image_type: ImageType, image_name: str) -> str: + pass + + @abstractmethod + def save(self, image_type: ImageType, image_name: str, image: Image) -> None: + pass + + @abstractmethod + def delete(self, image_type: ImageType, image_name: str) -> None: + pass + + def create_name(self, context_id: str, node_id: str) -> str: + return f'{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png' + + +class DiskImageStorage(ImageStorageBase): + """Stores images on disk""" + __output_folder: str + __pngWriter: PngWriter + __cache_ids: Queue # TODO: this is an incredibly naive cache + __cache: Dict[str, Image] + __max_cache_size: int + + def __init__(self, output_folder: str): + self.__output_folder = output_folder + self.__pngWriter = PngWriter(output_folder) + self.__cache = dict() + self.__cache_ids = Queue() + self.__max_cache_size = 10 # TODO: get this from config + + Path(output_folder).mkdir(parents=True, exist_ok=True) + + # TODO: don't hard-code. get/save/delete should maybe take subpath? + for image_type in ImageType: + Path(os.path.join(output_folder, image_type)).mkdir(parents=True, exist_ok=True) + + def get(self, image_type: ImageType, image_name: str) -> Image: + image_path = self.get_path(image_type, image_name) + cache_item = self.__get_cache(image_path) + if cache_item: + return cache_item + + image = Image.open(image_path) + self.__set_cache(image_path, image) + return image + + # TODO: make this a bit more flexible for e.g. cloud storage + def get_path(self, image_type: ImageType, image_name: str) -> str: + path = os.path.join(self.__output_folder, image_type, image_name) + return path + + def save(self, image_type: ImageType, image_name: str, image: Image) -> None: + image_subpath = os.path.join(image_type, image_name) + self.__pngWriter.save_image_and_prompt_to_png(image, "", image_subpath, None) # TODO: just pass full path to png writer + + image_path = self.get_path(image_type, image_name) + self.__set_cache(image_path, image) + + def delete(self, image_type: ImageType, image_name: str) -> None: + image_path = self.get_path(image_type, image_name) + if os.path.exists(image_path): + os.remove(image_path) + + if image_path in self.__cache: + del self.__cache[image_path] + + def __get_cache(self, image_name: str) -> Image: + return None if image_name not in self.__cache else self.__cache[image_name] + + def __set_cache(self, image_name: str, image: Image): + if not image_name in self.__cache: + self.__cache[image_name] = image + self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache + if len(self.__cache) > self.__max_cache_size: + cache_id = self.__cache_ids.get() + del self.__cache[cache_id] diff --git a/ldm/invoke/app/services/invocation_queue.py b/ldm/invoke/app/services/invocation_queue.py new file mode 100644 index 0000000000..0a5b5ae3bb --- /dev/null +++ b/ldm/invoke/app/services/invocation_queue.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from abc import ABC, abstractmethod +from queue import Queue + + +# TODO: make this serializable +class InvocationQueueItem: + #session_id: str + graph_execution_state_id: str + invocation_id: str + invoke_all: bool + + def __init__(self, + #session_id: str, + graph_execution_state_id: str, + invocation_id: str, + invoke_all: bool = False): + #self.session_id = session_id + self.graph_execution_state_id = graph_execution_state_id + self.invocation_id = invocation_id + 
self.invoke_all = invoke_all + + +class InvocationQueueABC(ABC): + """Abstract base class for all invocation queues""" + @abstractmethod + def get(self) -> InvocationQueueItem: + pass + + @abstractmethod + def put(self, item: InvocationQueueItem|None) -> None: + pass + + +class MemoryInvocationQueue(InvocationQueueABC): + __queue: Queue + + def __init__(self): + self.__queue = Queue() + + def get(self) -> InvocationQueueItem: + return self.__queue.get() + + def put(self, item: InvocationQueueItem|None) -> None: + self.__queue.put(item) diff --git a/ldm/invoke/app/services/invocation_services.py b/ldm/invoke/app/services/invocation_services.py new file mode 100644 index 0000000000..9eb5309d3d --- /dev/null +++ b/ldm/invoke/app/services/invocation_services.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from .image_storage import ImageStorageBase +from .events import EventServiceBase +from ....generate import Generate + + +class InvocationServices(): + """Services that can be used by invocations""" + generate: Generate # TODO: wrap Generate, or split it up from model? + events: EventServiceBase + images: ImageStorageBase + + def __init__(self, + generate: Generate, + events: EventServiceBase, + images: ImageStorageBase + ): + self.generate = generate + self.events = events + self.images = images diff --git a/ldm/invoke/app/services/invoker.py b/ldm/invoke/app/services/invoker.py new file mode 100644 index 0000000000..796f541781 --- /dev/null +++ b/ldm/invoke/app/services/invoker.py @@ -0,0 +1,109 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from abc import ABC +from threading import Event, Thread +from .graph import Graph, GraphExecutionState +from .item_storage import ItemStorageABC +from ..invocations.baseinvocation import InvocationContext +from .invocation_services import InvocationServices +from .invocation_queue import InvocationQueueABC, InvocationQueueItem + + +class InvokerServices: + """Services used by the Invoker for execution""" + + queue: InvocationQueueABC + graph_execution_manager: ItemStorageABC[GraphExecutionState] + processor: 'InvocationProcessorABC' + + def __init__(self, + queue: InvocationQueueABC, + graph_execution_manager: ItemStorageABC[GraphExecutionState], + processor: 'InvocationProcessorABC'): + self.queue = queue + self.graph_execution_manager = graph_execution_manager + self.processor = processor + + +class Invoker: + """The invoker, used to execute invocations""" + + services: InvocationServices + invoker_services: InvokerServices + + def __init__(self, + services: InvocationServices, # Services used by nodes to perform invocations + invoker_services: InvokerServices # Services used by the invoker for orchestration + ): + self.services = services + self.invoker_services = invoker_services + self._start() + + + def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None: + """Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute""" + + # Get the next invocation + invocation = graph_execution_state.next() + if not invocation: + return None + + # Save the execution state + self.invoker_services.graph_execution_manager.set(graph_execution_state) + + # Queue the invocation + print(f'queueing item {invocation.id}') + self.invoker_services.queue.put(InvocationQueueItem( + #session_id = session.id, + graph_execution_state_id = graph_execution_state.id, + invocation_id = invocation.id, + invoke_all 
= invoke_all + )) + + return invocation.id + + + def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState: + """Creates a new execution state for the given graph""" + new_state = GraphExecutionState(graph = Graph() if graph is None else graph) + self.invoker_services.graph_execution_manager.set(new_state) + return new_state + + + def __start_service(self, service) -> None: + # Call start() method on any services that have it + start_op = getattr(service, 'start', None) + if callable(start_op): + start_op(self) + + + def __stop_service(self, service) -> None: + # Call stop() method on any services that have it + stop_op = getattr(service, 'stop', None) + if callable(stop_op): + stop_op(self) + + + def _start(self) -> None: + """Starts the invoker. This is called automatically when the invoker is created.""" + for service in vars(self.invoker_services): + self.__start_service(getattr(self.invoker_services, service)) + + for service in vars(self.services): + self.__start_service(getattr(self.services, service)) + + + def stop(self) -> None: + """Stops the invoker. A new invoker will have to be created to execute further.""" + # First stop all services + for service in vars(self.services): + self.__stop_service(getattr(self.services, service)) + + for service in vars(self.invoker_services): + self.__stop_service(getattr(self.invoker_services, service)) + + self.invoker_services.queue.put(None) + + +class InvocationProcessorABC(ABC): + pass \ No newline at end of file diff --git a/ldm/invoke/app/services/item_storage.py b/ldm/invoke/app/services/item_storage.py new file mode 100644 index 0000000000..738f06cb7e --- /dev/null +++ b/ldm/invoke/app/services/item_storage.py @@ -0,0 +1,57 @@ + +from typing import Callable, TypeVar, Generic +from pydantic import BaseModel, Field +from pydantic.generics import GenericModel +from abc import ABC, abstractmethod + +T = TypeVar('T', bound=BaseModel) + +class PaginatedResults(GenericModel, Generic[T]): + """Paginated results""" + items: list[T] = Field(description = "Items") + page: int = Field(description = "Current Page") + pages: int = Field(description = "Total number of pages") + per_page: int = Field(description = "Number of items per page") + total: int = Field(description = "Total number of items in result") + + +class ItemStorageABC(ABC, Generic[T]): + _on_changed_callbacks: list[Callable[[T], None]] + _on_deleted_callbacks: list[Callable[[str], None]] + + def __init__(self) -> None: + self._on_changed_callbacks = list() + self._on_deleted_callbacks = list() + + """Base item storage class""" + @abstractmethod + def get(self, item_id: str) -> T: + pass + + @abstractmethod + def set(self, item: T) -> None: + pass + + @abstractmethod + def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + pass + + @abstractmethod + def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + pass + + def on_changed(self, on_changed: Callable[[T], None]) -> None: + """Register a callback for when an item is changed""" + self._on_changed_callbacks.append(on_changed) + + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an item is deleted""" + self._on_deleted_callbacks.append(on_deleted) + + def _on_changed(self, item: T) -> None: + for callback in self._on_changed_callbacks: + callback(item) + + def _on_deleted(self, item_id: str) -> None: + for callback in self._on_deleted_callbacks: + callback(item_id) diff --git 
a/ldm/invoke/app/services/processor.py b/ldm/invoke/app/services/processor.py new file mode 100644 index 0000000000..9b51a6bcbc --- /dev/null +++ b/ldm/invoke/app/services/processor.py @@ -0,0 +1,78 @@ +from threading import Event, Thread +from ..invocations.baseinvocation import InvocationContext +from .invocation_queue import InvocationQueueItem +from .invoker import InvocationProcessorABC, Invoker + + +class DefaultInvocationProcessor(InvocationProcessorABC): + __invoker_thread: Thread + __stop_event: Event + __invoker: Invoker + + def start(self, invoker) -> None: + self.__invoker = invoker + self.__stop_event = Event() + self.__invoker_thread = Thread( + name = "invoker_processor", + target = self.__process, + kwargs = dict(stop_event = self.__stop_event) + ) + self.__invoker_thread.daemon = True # TODO: probably better to just not use threads? + self.__invoker_thread.start() + + + def stop(self, *args, **kwargs) -> None: + self.__stop_event.set() + + + def __process(self, stop_event: Event): + try: + while not stop_event.is_set(): + queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get() + if not queue_item: # Probably stopping + continue + + graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id) + invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) + + # Send starting event + self.__invoker.services.events.emit_invocation_started( + graph_execution_state_id = graph_execution_state.id, + invocation_id = invocation.id + ) + + # Invoke + try: + outputs = invocation.invoke(InvocationContext( + services = self.__invoker.services, + graph_execution_state_id = graph_execution_state.id + )) + + # Save outputs and history + graph_execution_state.complete(invocation.id, outputs) + + # Save the state changes + self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state) + + # Send complete event + self.__invoker.services.events.emit_invocation_complete( + graph_execution_state_id = graph_execution_state.id, + invocation_id = invocation.id, + result = outputs.dict() + ) + + # Queue any further commands if invoking all + is_complete = graph_execution_state.is_complete() + if queue_item.invoke_all and not is_complete: + self.__invoker.invoke(graph_execution_state, invoke_all = True) + elif is_complete: + self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id) + except KeyboardInterrupt: + pass + except Exception as e: + # TODO: Log the error, mark the invocation as failed, and emit an event + print(f'Error invoking {invocation.id}: {e}') + pass + + except KeyboardInterrupt: + ... # Log something? 
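Taken together, the queue, processor, and invoker form a small producer/consumer loop: `Invoker.invoke` enqueues one node, and the processor thread drains the queue, runs the node, and (when `invoke_all` is set) re-enqueues the next one until the state is complete. A minimal wiring sketch follows; it is not part of the patch, it assumes the module layout introduced above, and it stubs the generate/events/images services with `None`, which is only safe while no node actually touches them (the unit tests later in this patch use the same trick):

```python
from ldm.invoke.app.services.graph import Graph, GraphExecutionState
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.invoker import Invoker, InvokerServices
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory

# Node-facing services (stubbed here; a real app passes concrete instances)
services = InvocationServices(generate=None, events=None, images=None)

# Orchestration services: in-memory queue, SQLite-backed state, default processor
invoker_services = InvokerServices(
    queue=MemoryInvocationQueue(),
    graph_execution_manager=SqliteItemStorage[GraphExecutionState](
        filename=sqlite_memory, table_name="graph_executions"
    ),
    processor=DefaultInvocationProcessor(),
)

invoker = Invoker(services, invoker_services)  # starts the processor thread
state = invoker.create_execution_state(Graph())

# Returns the id of the first queued node, or None for an empty graph
invoker.invoke(state, invoke_all=True)

invoker.stop()  # signals the processor thread and unblocks the queue
```

Passing `invoke_all=True` mirrors what the processor does internally on each completion: it keeps calling `invoke` until `is_complete()` is true.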
diff --git a/ldm/invoke/app/services/sqlite.py b/ldm/invoke/app/services/sqlite.py new file mode 100644 index 0000000000..8858bbd874 --- /dev/null +++ b/ldm/invoke/app/services/sqlite.py @@ -0,0 +1,119 @@ +import sqlite3 +from threading import Lock +from typing import Generic, TypeVar, Union, get_args +from pydantic import BaseModel, parse_raw_as +from .item_storage import ItemStorageABC, PaginatedResults + +T = TypeVar('T', bound=BaseModel) + +sqlite_memory = ':memory:' + +class SqliteItemStorage(ItemStorageABC, Generic[T]): + _filename: str + _table_name: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _id_field: str + _lock: Lock + + def __init__(self, filename: str, table_name: str, id_field: str = 'id'): + super().__init__() + + self._filename = filename + self._table_name = table_name + self._id_field = id_field # TODO: validate that T has this field + self._lock = Lock() + + self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution + self._cursor = self._conn.cursor() + + self._create_table() + + def _create_table(self): + try: + self._lock.acquire() + self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} ( + item TEXT, + id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''') + self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''') + finally: + self._lock.release() + + def _parse_item(self, item: str) -> T: + item_type = get_args(self.__orig_class__)[0] + return parse_raw_as(item_type, item) + + def set(self, item: T): + try: + self._lock.acquire() + self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),)) + finally: + self._lock.release() + self._on_changed(item) + + def get(self, id: str) -> Union[T, None]: + try: + self._lock.acquire() + self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),)) + result = self._cursor.fetchone() + finally: + self._lock.release() + + if not result: + return None + + return self._parse_item(result[0]) + + def delete(self, id: str): + try: + self._lock.acquire() + self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),)) + finally: + self._lock.release() + self._on_deleted(id) + + def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + try: + self._lock.acquire() + self._cursor.execute(f'''SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;''', (per_page, page * per_page)) + result = self._cursor.fetchall() + + items = list(map(lambda r: self._parse_item(r[0]), result)) + + self._cursor.execute(f'''SELECT count(*) FROM {self._table_name};''') + count = self._cursor.fetchone()[0] + finally: + self._lock.release() + + pageCount = int(count / per_page) + 1 + + return PaginatedResults[T]( + items = items, + page = page, + pages = pageCount, + per_page = per_page, + total = count + ) + + def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + try: + self._lock.acquire() + self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? 
OFFSET ?;''', (f'%{query}%', per_page, page * per_page)) + result = self._cursor.fetchall() + + items = list(map(lambda r: self._parse_item(r[0]), result)) + + self._cursor.execute(f'''SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;''', (f'%{query}%',)) + count = self._cursor.fetchone()[0] + finally: + self._lock.release() + + pageCount = int(count / per_page) + 1 + + return PaginatedResults[T]( + items = items, + page = page, + pages = pageCount, + per_page = per_page, + total = count + ) diff --git a/pyproject.toml b/pyproject.toml index b544b9eb9c..bfa36ff7d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,9 @@ dependencies = [ "einops", "eventlet", "facexlib", + "fastapi==0.85.0", + "fastapi-events==0.6.0", + "fastapi-socketio==0.0.9", "flask==2.1.3", "flask_cors==3.0.10", "flask_socketio==5.3.0", @@ -60,6 +63,7 @@ dependencies = [ "pudb", "pypatchmatch", "pyreadline3", + "python-multipart==0.0.5", "pytorch-lightning==1.7.7", "realesrgan", "requests==2.28.2", @@ -74,6 +78,7 @@ dependencies = [ "torchmetrics", "torchvision>=0.14.1", "transformers~=4.25", + "uvicorn[standard]==0.20.0", "windows-curses; sys_platform=='win32'", ] description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process" diff --git a/scripts/invoke-new.py b/scripts/invoke-new.py new file mode 100644 index 0000000000..2bc5330a5c --- /dev/null +++ b/scripts/invoke-new.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import os +import sys + +def main(): + # Change working directory to the repo root + os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + if '--web' in sys.argv: + from ldm.invoke.app.api_app import invoke_api + invoke_api() + else: + # TODO: Parse some top-level args here. + from ldm.invoke.app.cli_app import invoke_cli + invoke_cli() + + +if __name__ == '__main__': + main() diff --git a/static/dream_web/test.html b/static/dream_web/test.html new file mode 100644 index 0000000000..e99abb3703 --- /dev/null +++ b/static/dream_web/test.html @@ -0,0 +1,206 @@ + + + + InvokeAI Test + + + + + + + + + + + + + + + +
+ +
+ + + + + \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/nodes/__init__.py b/tests/nodes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py new file mode 100644 index 0000000000..0a5dcc7734 --- /dev/null +++ b/tests/nodes/test_graph_execution_state.py @@ -0,0 +1,114 @@ +from .test_invoker import create_edge +from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation +from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ldm.invoke.app.services.invocation_services import InvocationServices +from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation +import pytest + + +@pytest.fixture +def simple_graph(): + g = Graph() + g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) + g.add_node(ImageTestInvocation(id = "2")) + g.add_edge(create_edge("1", "prompt", "2", "prompt")) + return g + +@pytest.fixture +def mock_services(): + # NOTE: none of these are actually called by the test invocations + return InvocationServices(generate = None, events = None, images = None) + +def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: + n = g.next() + if n is None: + return (None, None) + + print(f'invoking {n.id}: {type(n)}') + o = n.invoke(InvocationContext(services, "1")) + g.complete(n.id, o) + + return (n, o) + +def test_graph_state_executes_in_order(simple_graph, mock_services): + g = GraphExecutionState(graph = simple_graph) + + n1 = invoke_next(g, mock_services) + n2 = invoke_next(g, mock_services) + n3 = g.next() + + assert g.prepared_source_mapping[n1[0].id] == "1" + assert g.prepared_source_mapping[n2[0].id] == "2" + assert n3 is None + assert g.results[n1[0].id].prompt == n1[0].prompt + assert n2[0].prompt == n1[0].prompt + +def test_graph_is_complete(simple_graph, mock_services): + g = GraphExecutionState(graph = simple_graph) + n1 = invoke_next(g, mock_services) + n2 = invoke_next(g, mock_services) + n3 = g.next() + + assert g.is_complete() + +def test_graph_is_not_complete(simple_graph, mock_services): + g = GraphExecutionState(graph = simple_graph) + n1 = invoke_next(g, mock_services) + n2 = g.next() + + assert not g.is_complete() + +# TODO: test completion with iterators/subgraphs + +def test_graph_state_expands_iterator(mock_services): + graph = Graph() + test_prompts = ["Banana sushi", "Cat sushi"] + graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) + graph.add_node(IterateInvocation(id = "2")) + graph.add_node(ImageTestInvocation(id = "3")) + graph.add_edge(create_edge("1", "collection", "2", "collection")) + graph.add_edge(create_edge("2", "item", "3", "prompt")) + + g = GraphExecutionState(graph = graph) + n1 = invoke_next(g, mock_services) + n2 = invoke_next(g, mock_services) + n3 = invoke_next(g, mock_services) + n4 = invoke_next(g, mock_services) + n5 = invoke_next(g, mock_services) + + assert 
g.prepared_source_mapping[n1[0].id] == "1" + assert g.prepared_source_mapping[n2[0].id] == "2" + assert g.prepared_source_mapping[n3[0].id] == "2" + assert g.prepared_source_mapping[n4[0].id] == "3" + assert g.prepared_source_mapping[n5[0].id] == "3" + + assert isinstance(n4[0], ImageTestInvocation) + assert isinstance(n5[0], ImageTestInvocation) + + prompts = [n4[0].prompt, n5[0].prompt] + assert sorted(prompts) == sorted(test_prompts) + +def test_graph_state_collects(mock_services): + graph = Graph() + test_prompts = ["Banana sushi", "Cat sushi"] + graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) + graph.add_node(IterateInvocation(id = "2")) + graph.add_node(PromptTestInvocation(id = "3")) + graph.add_node(CollectInvocation(id = "4")) + graph.add_edge(create_edge("1", "collection", "2", "collection")) + graph.add_edge(create_edge("2", "item", "3", "prompt")) + graph.add_edge(create_edge("3", "prompt", "4", "item")) + + g = GraphExecutionState(graph = graph) + n1 = invoke_next(g, mock_services) + n2 = invoke_next(g, mock_services) + n3 = invoke_next(g, mock_services) + n4 = invoke_next(g, mock_services) + n5 = invoke_next(g, mock_services) + n6 = invoke_next(g, mock_services) + + assert isinstance(n6[0], CollectInvocation) + + assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py new file mode 100644 index 0000000000..a6d96f61c0 --- /dev/null +++ b/tests/nodes/test_invoker.py @@ -0,0 +1,85 @@ +from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until +from ldm.invoke.app.services.processor import DefaultInvocationProcessor +from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory +from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue +from ldm.invoke.app.services.invoker import Invoker, InvokerServices +from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ldm.invoke.app.services.invocation_services import InvocationServices +from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation +import pytest + + +@pytest.fixture +def simple_graph(): + g = Graph() + g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) + g.add_node(ImageTestInvocation(id = "2")) + g.add_edge(create_edge("1", "prompt", "2", "prompt")) + return g + +@pytest.fixture +def mock_services() -> InvocationServices: + # NOTE: none of these are actually called by the test invocations + return InvocationServices(generate = None, events = TestEventService(), images = None) + +@pytest.fixture() +def mock_invoker_services() -> InvokerServices: + return InvokerServices( + queue = MemoryInvocationQueue(), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor() + ) + +@pytest.fixture() +def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker: + return Invoker( + services = mock_services, + 
invoker_services = mock_invoker_services + ) + +def test_can_create_graph_state(mock_invoker: Invoker): + g = mock_invoker.create_execution_state() + mock_invoker.stop() + + assert g is not None + assert isinstance(g, GraphExecutionState) + +def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): + g = mock_invoker.create_execution_state(graph = simple_graph) + mock_invoker.stop() + + assert g is not None + assert isinstance(g, GraphExecutionState) + assert g.graph == simple_graph + +def test_can_invoke(mock_invoker: Invoker, simple_graph): + g = mock_invoker.create_execution_state(graph = simple_graph) + invocation_id = mock_invoker.invoke(g) + assert invocation_id is not None + + def has_executed_any(g: GraphExecutionState): + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + return len(g.executed) > 0 + + wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) + mock_invoker.stop() + + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + assert len(g.executed) > 0 + +def test_can_invoke_all(mock_invoker: Invoker, simple_graph): + g = mock_invoker.create_execution_state(graph = simple_graph) + invocation_id = mock_invoker.invoke(g, invoke_all = True) + assert invocation_id is not None + + def has_executed_all(g: GraphExecutionState): + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + return g.is_complete() + + wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) + mock_invoker.stop() + + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + assert g.is_complete() diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py new file mode 100644 index 0000000000..1b5b341192 --- /dev/null +++ b/tests/nodes/test_node_graph.py @@ -0,0 +1,501 @@ +from ldm.invoke.app.invocations.image import * + +from .test_nodes import ListPassThroughInvocation, PromptTestInvocation +from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation +import pytest + + +# Helpers +def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: + return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) + +# Tests +def test_connections_are_compatible(): + from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_field = "image" + to_node = UpscaleInvocation(id = "2") + to_field = "image" + + result = are_connections_compatible(from_node, from_field, to_node, to_field) + + assert result == True + +def test_connections_are_incompatible(): + from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_field = "image" + to_node = UpscaleInvocation(id = "2") + to_field = "strength" + + result = are_connections_compatible(from_node, from_field, to_node, to_field) + + assert result == False + +def test_connections_incompatible_with_invalid_fields(): + from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_field = "invalid_field" + to_node = UpscaleInvocation(id = "2") + to_field = "image" + + # From field is invalid + result = are_connections_compatible(from_node, from_field, to_node, to_field) + assert result == False + + # 
To field is invalid + from_field = "image" + to_field = "invalid_field" + + result = are_connections_compatible(from_node, from_field, to_node, to_field) + assert result == False + +def test_graph_can_add_node(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + + assert n.id in g.nodes + +def test_graph_fails_to_add_node_with_duplicate_id(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second") + + with pytest.raises(NodeAlreadyInGraphError): + g.add_node(n2) + +def test_graph_updates_node(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") + g.add_node(n2) + + nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated") + + g.update_node("1", nu) + + assert g.nodes["1"].prompt == "Banana sushi updated" + +def test_graph_fails_to_update_node_if_type_changes(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = UpscaleInvocation(id = "2") + g.add_node(n2) + + nu = UpscaleInvocation(id = "1") + + with pytest.raises(TypeError): + g.update_node("1", nu) + +def test_graph_allows_non_conflicting_id_change(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = UpscaleInvocation(id = "2") + g.add_node(n2) + e1 = create_edge(n.id,"image",n2.id,"image") + g.add_edge(e1) + + nu = TextToImageInvocation(id = "3", prompt = "Banana sushi") + g.update_node("1", nu) + + with pytest.raises(NodeNotFoundError): + g.get_node("1") + + assert g.get_node("3").prompt == "Banana sushi" + + assert len(g.edges) == 1 + assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges + +def test_graph_fails_to_update_node_id_if_conflict(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") + g.add_node(n2) + + nu = TextToImageInvocation(id = "2", prompt = "Banana sushi") + with pytest.raises(NodeAlreadyInGraphError): + g.update_node("1", nu) + +def test_graph_adds_edge(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id,"image",n2.id,"image") + + g.add_edge(e) + + assert e in g.edges + +def test_graph_fails_to_add_edge_with_cycle(): + g = Graph() + n1 = UpscaleInvocation(id = "1") + g.add_node(n1) + e = create_edge(n1.id,"image",n1.id,"image") + with pytest.raises(InvalidEdgeError): + g.add_edge(e) + +def test_graph_fails_to_add_edge_with_long_cycle(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + n3 = UpscaleInvocation(id = "3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + e1 = create_edge(n1.id,"image",n2.id,"image") + e2 = create_edge(n2.id,"image",n3.id,"image") + e3 = create_edge(n3.id,"image",n2.id,"image") + g.add_edge(e1) + g.add_edge(e2) + with pytest.raises(InvalidEdgeError): + g.add_edge(e3) + +def test_graph_fails_to_add_edge_with_missing_node_id(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e1 = create_edge("1","image","3","image") + e2 = 
create_edge("3","image","1","image") + with pytest.raises(InvalidEdgeError): + g.add_edge(e1) + with pytest.raises(InvalidEdgeError): + g.add_edge(e2) + +def test_graph_fails_to_add_edge_when_destination_exists(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + n3 = UpscaleInvocation(id = "3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + e1 = create_edge(n1.id,"image",n2.id,"image") + e2 = create_edge(n1.id,"image",n3.id,"image") + e3 = create_edge(n2.id,"image",n3.id,"image") + g.add_edge(e1) + g.add_edge(e2) + with pytest.raises(InvalidEdgeError): + g.add_edge(e3) + + +def test_graph_fails_to_add_edge_with_mismatched_types(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e1 = create_edge("1","image","2","strength") + with pytest.raises(InvalidEdgeError): + g.add_edge(e1) + +def test_graph_connects_collector(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2") + n3 = CollectInvocation(id = "3") + n4 = ListPassThroughInvocation(id = "4") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + + e1 = create_edge("1","image","3","item") + e2 = create_edge("2","image","3","item") + e3 = create_edge("3","collection","4","collection") + g.add_edge(e1) + g.add_edge(e2) + g.add_edge(e3) + +# TODO: test that derived types mixed with base types are compatible + +def test_graph_collector_invalid_with_varying_input_types(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2") + n3 = CollectInvocation(id = "3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + + e1 = create_edge("1","image","3","item") + e2 = create_edge("2","prompt","3","item") + g.add_edge(e1) + + with pytest.raises(InvalidEdgeError): + g.add_edge(e2) + +def test_graph_collector_invalid_with_varying_input_output(): + g = Graph() + n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi") + n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2") + n3 = CollectInvocation(id = "3") + n4 = ListPassThroughInvocation(id = "4") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + + e1 = create_edge("1","prompt","3","item") + e2 = create_edge("2","prompt","3","item") + e3 = create_edge("3","collection","4","collection") + g.add_edge(e1) + g.add_edge(e2) + + with pytest.raises(InvalidEdgeError): + g.add_edge(e3) + +def test_graph_collector_invalid_with_non_list_output(): + g = Graph() + n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi") + n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2") + n3 = CollectInvocation(id = "3") + n4 = PromptTestInvocation(id = "4") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + + e1 = create_edge("1","prompt","3","item") + e2 = create_edge("2","prompt","3","item") + e3 = create_edge("3","collection","4","prompt") + g.add_edge(e1) + g.add_edge(e2) + + with pytest.raises(InvalidEdgeError): + g.add_edge(e3) + +def test_graph_connects_iterator(): + g = Graph() + n1 = ListPassThroughInvocation(id = "1") + n2 = IterateInvocation(id = "2") + n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + + e1 = create_edge("1","collection","2","collection") + e2 = create_edge("2","item","3","image") + g.add_edge(e1) + 
g.add_edge(e2)
+
+# TODO: TEST INVALID ITERATOR SCENARIOS
+
+def test_graph_iterator_invalid_if_multiple_inputs():
+    g = Graph()
+    n1 = ListPassThroughInvocation(id = "1")
+    n2 = IterateInvocation(id = "2")
+    n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
+    n4 = ListPassThroughInvocation(id = "4")
+    g.add_node(n1)
+    g.add_node(n2)
+    g.add_node(n3)
+    g.add_node(n4)
+
+    e1 = create_edge("1","collection","2","collection")
+    e2 = create_edge("2","item","3","image")
+    e3 = create_edge("4","collection","2","collection")
+    g.add_edge(e1)
+    g.add_edge(e2)
+
+    with pytest.raises(InvalidEdgeError):
+        g.add_edge(e3)
+
+def test_graph_iterator_invalid_if_input_not_list():
+    g = Graph()
+    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n2 = IterateInvocation(id = "2")
+    g.add_node(n1)
+    g.add_node(n2)
+
+    e1 = create_edge("1","collection","2","collection")
+
+    with pytest.raises(InvalidEdgeError):
+        g.add_edge(e1)
+
+def test_graph_iterator_invalid_if_output_and_input_types_different():
+    g = Graph()
+    n1 = ListPassThroughInvocation(id = "1")
+    n2 = IterateInvocation(id = "2")
+    n3 = PromptTestInvocation(id = "3", prompt = "Banana sushi")
+    g.add_node(n1)
+    g.add_node(n2)
+    g.add_node(n3)
+
+    e1 = create_edge("1","collection","2","collection")
+    e2 = create_edge("2","item","3","prompt")
+    g.add_edge(e1)
+
+    with pytest.raises(InvalidEdgeError):
+        g.add_edge(e2)
+
+def test_graph_validates():
+    g = Graph()
+    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n2 = UpscaleInvocation(id = "2")
+    g.add_node(n1)
+    g.add_node(n2)
+    e1 = create_edge("1","image","2","image")
+    g.add_edge(e1)
+
+    assert g.is_valid() == True
+
+def test_graph_invalid_if_edges_reference_missing_nodes():
+    g = Graph()
+    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    g.nodes[n1.id] = n1
+    e1 = create_edge("1","image","2","image")
+    g.edges.append(e1)
+
+    assert g.is_valid() == False
+
+def test_graph_invalid_if_subgraph_invalid():
+    g = Graph()
+    n1 = GraphInvocation(id = "1")
+    n1.graph = Graph()
+
+    n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi")
+    n1.graph.nodes[n1_1.id] = n1_1
+    e1 = create_edge("1","image","2","image")
+    n1.graph.edges.append(e1)
+
+    g.nodes[n1.id] = n1
+
+    assert g.is_valid() == False
+
+def test_graph_invalid_if_has_cycle():
+    g = Graph()
+    n1 = UpscaleInvocation(id = "1")
+    n2 = UpscaleInvocation(id = "2")
+    g.nodes[n1.id] = n1
+    g.nodes[n2.id] = n2
+    e1 = create_edge("1","image","2","image")
+    e2 = create_edge("2","image","1","image")
+    g.edges.append(e1)
+    g.edges.append(e2)
+
+    assert g.is_valid() == False
+
+def test_graph_invalid_with_invalid_connection():
+    g = Graph()
+    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n2 = UpscaleInvocation(id = "2")
+    g.nodes[n1.id] = n1
+    g.nodes[n2.id] = n2
+    e1 = create_edge("1","image","2","strength")
+    g.edges.append(e1)
+
+    assert g.is_valid() == False
+
+
+# TODO: Subgraph operations
+def test_graph_gets_subgraph_node():
+    g = Graph()
+    n1 = GraphInvocation(id = "1")
+    n1.graph = Graph()
+
+    n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1.graph.add_node(n1_1)
+
+    g.add_node(n1)
+
+    result = g.get_node('1.1')
+
+    assert result is not None
+    assert result.id == '1'
+    assert result == n1_1
+
+def test_graph_fails_to_get_missing_subgraph_node():
+    g = Graph()
+    n1 = GraphInvocation(id = "1")
+    n1.graph = Graph()
+
+    n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1.graph.add_node(n1_1)
+
+    
g.add_node(n1) + + with pytest.raises(NodeNotFoundError): + result = g.get_node('1.2') + +def test_graph_fails_to_enumerate_non_subgraph_node(): + g = Graph() + n1 = GraphInvocation(id = "1") + n1.graph = Graph() + n1.graph.add_node + + n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1.graph.add_node(n1_1) + + g.add_node(n1) + + n2 = UpscaleInvocation(id = "2") + g.add_node(n2) + + with pytest.raises(NodeNotFoundError): + result = g.get_node('2.1') + +def test_graph_gets_networkx_graph(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id,"image",n2.id,"image") + g.add_edge(e) + + nxg = g.nx_graph() + + assert '1' in nxg.nodes + assert '2' in nxg.nodes + assert ('1','2') in nxg.edges + + +# TODO: Graph serializes and deserializes +def test_graph_can_serialize(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id,"image",n2.id,"image") + g.add_edge(e) + + # Not throwing on this line is sufficient + json = g.json() + +def test_graph_can_deserialize(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id,"image",n2.id,"image") + g.add_edge(e) + + json = g.json() + g2 = Graph.parse_raw(json) + + assert g2 is not None + assert g2.nodes['1'] is not None + assert g2.nodes['2'] is not None + assert len(g2.edges) == 1 + assert g2.edges[0][0].node_id == '1' + assert g2.edges[0][0].field == 'image' + assert g2.edges[0][1].node_id == '2' + assert g2.edges[0][1].field == 'image' + +def test_graph_can_generate_schema(): + # Not throwing on this line is sufficient + # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation + schema = Graph.schema_json(indent = 2) diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py new file mode 100644 index 0000000000..fea2e75e95 --- /dev/null +++ b/tests/nodes/test_nodes.py @@ -0,0 +1,92 @@ +from typing import Any, Callable, Literal +from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ldm.invoke.app.invocations.image import ImageField +from ldm.invoke.app.services.invocation_services import InvocationServices +from pydantic import Field +import pytest + +# Define test invocations before importing anything that uses invocations +class ListPassThroughInvocationOutput(BaseInvocationOutput): + type: Literal['test_list_output'] = 'test_list_output' + + collection: list[ImageField] = Field(default_factory=list) + +class ListPassThroughInvocation(BaseInvocation): + type: Literal['test_list'] = 'test_list' + + collection: list[ImageField] = Field(default_factory=list) + + def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: + return ListPassThroughInvocationOutput(collection = self.collection) + +class PromptTestInvocationOutput(BaseInvocationOutput): + type: Literal['test_prompt_output'] = 'test_prompt_output' + + prompt: str = Field(default = "") + +class PromptTestInvocation(BaseInvocation): + type: Literal['test_prompt'] = 'test_prompt' + + prompt: str = Field(default = "") + + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + return PromptTestInvocationOutput(prompt = self.prompt) + +class 
ImageTestInvocationOutput(BaseInvocationOutput): + type: Literal['test_image_output'] = 'test_image_output' + + image: ImageField = Field() + +class ImageTestInvocation(BaseInvocation): + type: Literal['test_image'] = 'test_image' + + prompt: str = Field(default = "") + + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) + +class PromptCollectionTestInvocationOutput(BaseInvocationOutput): + type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output' + collection: list[str] = Field(default_factory=list) + +class PromptCollectionTestInvocation(BaseInvocation): + type: Literal['test_prompt_collection'] = 'test_prompt_collection' + collection: list[str] = Field() + + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) + + +from ldm.invoke.app.services.events import EventServiceBase +from ldm.invoke.app.services.graph import EdgeConnection + +def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: + return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) + + +class TestEvent: + event_name: str + payload: Any + + def __init__(self, event_name: str, payload: Any): + self.event_name = event_name + self.payload = payload + +class TestEventService(EventServiceBase): + events: list + + def __init__(self): + super().__init__() + self.events = list() + + def dispatch(self, event_name: str, payload: Any) -> None: + pass + +def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float = 0.1) -> None: + import time + start_time = time.time() + while time.time() - start_time < timeout: + if condition(): + return + time.sleep(interval) + raise TimeoutError("Condition not met") \ No newline at end of file diff --git a/tests/nodes/test_sqlite.py b/tests/nodes/test_sqlite.py new file mode 100644 index 0000000000..e499bbce12 --- /dev/null +++ b/tests/nodes/test_sqlite.py @@ -0,0 +1,112 @@ +from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory +from pydantic import BaseModel, Field + + +class TestModel(BaseModel): + id: str = Field(description = "ID") + name: str = Field(description = "Name") + + +def test_sqlite_service_can_create_and_get(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + assert db.get('1') == TestModel(id = '1', name = 'Test') + +def test_sqlite_service_can_list(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.list() + assert results.page == 0 + assert results.pages == 1 + assert results.per_page == 10 + assert results.total == 3 + assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')] + +def test_sqlite_service_can_delete(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.delete('1') + assert db.get('1') is None + +def test_sqlite_service_calls_set_callback(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + called = False + def on_changed(item: TestModel): + nonlocal called + called = True + db.on_changed(on_changed) 
+ db.set(TestModel(id = '1', name = 'Test')) + assert called + +def test_sqlite_service_calls_delete_callback(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + called = False + def on_deleted(item_id: str): + nonlocal called + called = True + db.on_deleted(on_deleted) + db.set(TestModel(id = '1', name = 'Test')) + db.delete('1') + assert called + +def test_sqlite_service_can_list_with_pagination(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.list(page = 0, per_page = 2) + assert results.page == 0 + assert results.pages == 2 + assert results.per_page == 2 + assert results.total == 3 + assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')] + +def test_sqlite_service_can_list_with_pagination_and_offset(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.list(page = 1, per_page = 2) + assert results.page == 1 + assert results.pages == 2 + assert results.per_page == 2 + assert results.total == 3 + assert results.items == [TestModel(id = '3', name = 'Test')] + +def test_sqlite_service_can_search(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.search(query = 'Test') + assert results.page == 0 + assert results.pages == 1 + assert results.per_page == 10 + assert results.total == 3 + assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')] + +def test_sqlite_service_can_search_with_pagination(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.search(query = 'Test', page = 0, per_page = 2) + assert results.page == 0 + assert results.pages == 2 + assert results.per_page == 2 + assert results.total == 3 + assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')] + +def test_sqlite_service_can_search_with_pagination_and_offset(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.search(query = 'Test', page = 1, per_page = 2) + assert results.page == 1 + assert results.pages == 2 + assert results.per_page == 2 + assert results.total == 3 + assert results.items == [TestModel(id = '3', name = 'Test')] From 81fd2ee8c19f69d9ce0693810bfd52ec53a3f4ec Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Fri, 24 Feb 2023 20:11:28 -0800 Subject: [PATCH 50/57] [nodes] Removed InvokerServices, simplying service model --- ldm/invoke/app/api/dependencies.py | 19 ++++------ ldm/invoke/app/api/routers/sessions.py | 28 +++++++------- ldm/invoke/app/cli_app.py | 25 ++++++------- .../app/services/invocation_services.py | 15 +++++++- ldm/invoke/app/services/invoker.py | 37 +++++-------------- ldm/invoke/app/services/processor.py | 6 +-- tests/nodes/test_graph_execution_state.py | 14 +++++-- tests/nodes/test_invoker.py | 26 
++++++------- 8 files changed, 81 insertions(+), 89 deletions(-) diff --git a/ldm/invoke/app/api/dependencies.py b/ldm/invoke/app/api/dependencies.py index 60dd522803..08f362133e 100644 --- a/ldm/invoke/app/api/dependencies.py +++ b/ldm/invoke/app/api/dependencies.py @@ -13,7 +13,7 @@ from ...globals import Globals from ..services.image_storage import DiskImageStorage from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_services import InvocationServices -from ..services.invoker import Invoker, InvokerServices +from ..services.invoker import Invoker from ..services.generate_initializer import get_generate from .events import FastAPIEventService @@ -60,22 +60,19 @@ class ApiDependencies: images = DiskImageStorage(output_folder) - services = InvocationServices( - generate = generate, - events = events, - images = images - ) - # TODO: build a file/path manager? db_location = os.path.join(output_folder, 'invokeai.db') - invoker_services = InvokerServices( - queue = MemoryInvocationQueue(), + services = InvocationServices( + generate = generate, + events = events, + images = images, + queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor() ) - ApiDependencies.invoker = Invoker(services, invoker_services) + ApiDependencies.invoker = Invoker(services) @staticmethod def shutdown(): diff --git a/ldm/invoke/app/api/routers/sessions.py b/ldm/invoke/app/api/routers/sessions.py index 77008ad6e4..beb13736c6 100644 --- a/ldm/invoke/app/api/routers/sessions.py +++ b/ldm/invoke/app/api/routers/sessions.py @@ -44,9 +44,9 @@ async def list_sessions( ) -> PaginatedResults[GraphExecutionState]: """Gets a list of sessions, optionally searching""" if filter == '': - result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page) + result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) else: - result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page) + result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) return result @@ -60,7 +60,7 @@ async def get_session( session_id: str = Path(description = "The id of the session to get") ) -> GraphExecutionState: """Gets a session""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) else: @@ -80,13 +80,13 @@ async def add_node( node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add") ) -> str: """Adds a node to the graph""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.add_node(node) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? 
return session.id except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -108,13 +108,13 @@ async def update_node( node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node") ) -> GraphExecutionState: """Updates a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.update_node(node_path, node) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -135,13 +135,13 @@ async def delete_node( node_path: str = Path(description = "The path to the node to delete") ) -> GraphExecutionState: """Deletes a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.delete_node(node_path) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -162,13 +162,13 @@ async def add_edge( edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add") ) -> GraphExecutionState: """Adds an edge to the graph""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.add_edge(edge) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -193,14 +193,14 @@ async def delete_edge( to_field: str = Path(description = "The field of the node the edge is going to") ) -> GraphExecutionState: """Deletes an edge from the graph""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field)) session.delete_edge(edge) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? 
return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -221,7 +221,7 @@ async def invoke_session( all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations") ) -> None: """Invokes a session""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) diff --git a/ldm/invoke/app/cli_app.py b/ldm/invoke/app/cli_app.py index 6071afabb2..9081f3b083 100644 --- a/ldm/invoke/app/cli_app.py +++ b/ldm/invoke/app/cli_app.py @@ -20,7 +20,7 @@ from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue from .invocations.baseinvocation import BaseInvocation from .services.invocation_services import InvocationServices -from .services.invoker import Invoker, InvokerServices +from .services.invoker import Invoker from .invocations import * from ..args import Args from .services.events import EventServiceBase @@ -171,28 +171,25 @@ def invoke_cli(): output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs')) - services = InvocationServices( - generate = generate, - events = events, - images = DiskImageStorage(output_folder) - ) - # TODO: build a file/path manager? db_location = os.path.join(output_folder, 'invokeai.db') - invoker_services = InvokerServices( - queue = MemoryInvocationQueue(), + services = InvocationServices( + generate = generate, + events = events, + images = DiskImageStorage(output_folder), + queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor() ) - invoker = Invoker(services, invoker_services) + invoker = Invoker(services) session = invoker.create_execution_state() parser = get_invocation_parser() # Uncomment to print out previous sessions at startup - # print(invoker_services.session_manager.list()) + # print(services.session_manager.list()) # Defaults storage defaults: Dict[str, Any] = dict() @@ -213,7 +210,7 @@ def invoke_cli(): try: # Refresh the state of the session - session = invoker.invoker_services.graph_execution_manager.get(session.id) + session = invoker.services.graph_execution_manager.get(session.id) history = list(get_graph_execution_history(session)) # Split the command for piping @@ -289,7 +286,7 @@ def invoke_cli(): invoker.invoke(session, invoke_all = True) while not session.is_complete(): # Wait some time - session = invoker.invoker_services.graph_execution_manager.get(session.id) + session = invoker.services.graph_execution_manager.get(session.id) time.sleep(0.1) except InvalidArgs: diff --git a/ldm/invoke/app/services/invocation_services.py b/ldm/invoke/app/services/invocation_services.py index 9eb5309d3d..40a64e64e5 100644 --- a/ldm/invoke/app/services/invocation_services.py +++ b/ldm/invoke/app/services/invocation_services.py @@ -1,4 +1,6 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from .invocation_queue import InvocationQueueABC +from .item_storage import ItemStorageABC from .image_storage import ImageStorageBase from .events import EventServiceBase from ....generate import Generate @@ -9,12 +11,23 @@ class InvocationServices(): generate: Generate # TODO: wrap Generate, or split it up from model? 
events: EventServiceBase images: ImageStorageBase + queue: InvocationQueueABC + + # NOTE: we must forward-declare any types that include invocations, since invocations can use services + graph_execution_manager: ItemStorageABC['GraphExecutionState'] + processor: 'InvocationProcessorABC' def __init__(self, generate: Generate, events: EventServiceBase, - images: ImageStorageBase + images: ImageStorageBase, + queue: InvocationQueueABC, + graph_execution_manager: ItemStorageABC['GraphExecutionState'], + processor: 'InvocationProcessorABC' ): self.generate = generate self.events = events self.images = images + self.queue = queue + self.graph_execution_manager = graph_execution_manager + self.processor = processor diff --git a/ldm/invoke/app/services/invoker.py b/ldm/invoke/app/services/invoker.py index 796f541781..4397a75021 100644 --- a/ldm/invoke/app/services/invoker.py +++ b/ldm/invoke/app/services/invoker.py @@ -9,34 +9,15 @@ from .invocation_services import InvocationServices from .invocation_queue import InvocationQueueABC, InvocationQueueItem -class InvokerServices: - """Services used by the Invoker for execution""" - - queue: InvocationQueueABC - graph_execution_manager: ItemStorageABC[GraphExecutionState] - processor: 'InvocationProcessorABC' - - def __init__(self, - queue: InvocationQueueABC, - graph_execution_manager: ItemStorageABC[GraphExecutionState], - processor: 'InvocationProcessorABC'): - self.queue = queue - self.graph_execution_manager = graph_execution_manager - self.processor = processor - - class Invoker: """The invoker, used to execute invocations""" services: InvocationServices - invoker_services: InvokerServices def __init__(self, - services: InvocationServices, # Services used by nodes to perform invocations - invoker_services: InvokerServices # Services used by the invoker for orchestration + services: InvocationServices ): self.services = services - self.invoker_services = invoker_services self._start() @@ -49,11 +30,11 @@ class Invoker: return None # Save the execution state - self.invoker_services.graph_execution_manager.set(graph_execution_state) + self.services.graph_execution_manager.set(graph_execution_state) # Queue the invocation print(f'queueing item {invocation.id}') - self.invoker_services.queue.put(InvocationQueueItem( + self.services.queue.put(InvocationQueueItem( #session_id = session.id, graph_execution_state_id = graph_execution_state.id, invocation_id = invocation.id, @@ -66,7 +47,7 @@ class Invoker: def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState: """Creates a new execution state for the given graph""" new_state = GraphExecutionState(graph = Graph() if graph is None else graph) - self.invoker_services.graph_execution_manager.set(new_state) + self.services.graph_execution_manager.set(new_state) return new_state @@ -86,8 +67,8 @@ class Invoker: def _start(self) -> None: """Starts the invoker. 
This is called automatically when the invoker is created.""" - for service in vars(self.invoker_services): - self.__start_service(getattr(self.invoker_services, service)) + for service in vars(self.services): + self.__start_service(getattr(self.services, service)) for service in vars(self.services): self.__start_service(getattr(self.services, service)) @@ -99,10 +80,10 @@ class Invoker: for service in vars(self.services): self.__stop_service(getattr(self.services, service)) - for service in vars(self.invoker_services): - self.__stop_service(getattr(self.invoker_services, service)) + for service in vars(self.services): + self.__stop_service(getattr(self.services, service)) - self.invoker_services.queue.put(None) + self.services.queue.put(None) class InvocationProcessorABC(ABC): diff --git a/ldm/invoke/app/services/processor.py b/ldm/invoke/app/services/processor.py index 9b51a6bcbc..9ea4349bbf 100644 --- a/ldm/invoke/app/services/processor.py +++ b/ldm/invoke/app/services/processor.py @@ -28,11 +28,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC): def __process(self, stop_event: Event): try: while not stop_event.is_set(): - queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get() + queue_item: InvocationQueueItem = self.__invoker.services.queue.get() if not queue_item: # Probably stopping continue - graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id) + graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id) invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) # Send starting event @@ -52,7 +52,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): graph_execution_state.complete(invocation.id, outputs) # Save the state changes - self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state) + self.__invoker.services.graph_execution_manager.set(graph_execution_state) # Send complete event self.__invoker.services.events.emit_invocation_complete( diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 0a5dcc7734..980c262501 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -1,10 +1,11 @@ from .test_invoker import create_edge from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ldm.invoke.app.services.processor import DefaultInvocationProcessor +from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory +from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState -from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation -from ldm.invoke.app.invocations.upscale import UpscaleInvocation import pytest @@ -19,7 +20,14 @@ def simple_graph(): @pytest.fixture def mock_services(): # NOTE: none of these are actually called by the test invocations - return InvocationServices(generate = None, events = None, images = None) + return 
InvocationServices( + generate = None, + events = None, + images = None, + queue = MemoryInvocationQueue(), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor() + ) def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: n = g.next() diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index a6d96f61c0..e9109728d5 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -2,12 +2,10 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe from ldm.invoke.app.services.processor import DefaultInvocationProcessor from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue -from ldm.invoke.app.services.invoker import Invoker, InvokerServices +from ldm.invoke.app.services.invoker import Invoker from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState -from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation -from ldm.invoke.app.invocations.upscale import UpscaleInvocation import pytest @@ -22,21 +20,19 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations - return InvocationServices(generate = None, events = TestEventService(), images = None) - -@pytest.fixture() -def mock_invoker_services() -> InvokerServices: - return InvokerServices( + return InvocationServices( + generate = None, + events = TestEventService(), + images = None, queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor() ) @pytest.fixture() -def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker: +def mock_invoker(mock_services: InvocationServices) -> Invoker: return Invoker( - services = mock_services, - invoker_services = mock_invoker_services + services = mock_services ) def test_can_create_graph_state(mock_invoker: Invoker): @@ -60,13 +56,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph): assert invocation_id is not None def has_executed_any(g: GraphExecutionState): - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + g = mock_invoker.services.graph_execution_manager.get(g.id) return len(g.executed) > 0 wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) mock_invoker.stop() - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + g = mock_invoker.services.graph_execution_manager.get(g.id) assert len(g.executed) > 0 def test_can_invoke_all(mock_invoker: Invoker, simple_graph): @@ -75,11 +71,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph): assert invocation_id is not None def has_executed_all(g: GraphExecutionState): - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + g = mock_invoker.services.graph_execution_manager.get(g.id) return 
g.is_complete() wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) mock_invoker.stop() - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + g = mock_invoker.services.graph_execution_manager.get(g.id) assert g.is_complete() From 1e7a6dc676c43c713a977a7d47b0d4748c61ea45 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 25 Feb 2023 20:21:47 -0800 Subject: [PATCH 51/57] doc(invoke_ai_web_server): put docstrings inside their functions Documentation strings are the first thing inside the function body. https://docs.python.org/3/tutorial/controlflow.html#defining-functions --- invokeai/backend/invoke_ai_web_server.py | 63 ++++++++++-------------- 1 file changed, 25 insertions(+), 38 deletions(-) diff --git a/invokeai/backend/invoke_ai_web_server.py b/invokeai/backend/invoke_ai_web_server.py index 8ee93c68f4..90e228b92b 100644 --- a/invokeai/backend/invoke_ai_web_server.py +++ b/invokeai/backend/invoke_ai_web_server.py @@ -7,13 +7,15 @@ import mimetypes import os import shutil import traceback +from pathlib import Path from threading import Event from uuid import uuid4 import eventlet -from pathlib import Path +import invokeai.frontend.dist as frontend from PIL import Image from PIL.Image import Image as ImageType +from compel.prompt_parser import Blend from flask import Flask, redirect, send_from_directory, request, make_response from flask_socketio import SocketIO from werkzeug.utils import secure_filename @@ -22,18 +24,15 @@ from invokeai.backend.modules.get_canvas_generation_mode import ( get_canvas_generation_mode, ) from invokeai.backend.modules.parameters import parameters_to_command -import invokeai.frontend.dist as frontend from ldm.generate import Generate from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash -from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, split_weighted_subprompts, \ - get_tokenizer +from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, get_tokenizer from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState from ldm.invoke.generator.inpaint import infill_methods from ldm.invoke.globals import Globals, global_converted_ckpts_dir -from ldm.invoke.pngwriter import PngWriter, retrieve_metadata -from compel.prompt_parser import Blend from ldm.invoke.globals import global_models_dir from ldm.invoke.merge_diffusers import merge_diffusion_models +from ldm.invoke.pngwriter import PngWriter, retrieve_metadata # Loading Arguments opt = Args() @@ -1685,27 +1684,23 @@ class CanceledException(Exception): pass -""" -Returns a copy an image, cropped to a bounding box. -""" - - def copy_image_from_bounding_box( image: ImageType, x: int, y: int, width: int, height: int ) -> ImageType: + """ + Returns a copy an image, cropped to a bounding box. + """ with image as im: bounds = (x, y, x + width, y + height) im_cropped = im.crop(bounds) return im_cropped -""" -Converts a base64 image dataURL into an image. -The dataURL is split on the first commma. -""" - - def dataURL_to_image(dataURL: str) -> ImageType: + """ + Converts a base64 image dataURL into an image. + The dataURL is split on the first comma. + """ image = Image.open( io.BytesIO( base64.decodebytes( @@ -1719,12 +1714,10 @@ def dataURL_to_image(dataURL: str) -> ImageType: return image -""" -Converts an image into a base64 image dataURL. 
-""" - - def image_to_dataURL(image: ImageType) -> str: + """ + Converts an image into a base64 image dataURL. + """ buffered = io.BytesIO() image.save(buffered, format="PNG") image_base64 = "data:image/png;base64," + base64.b64encode( @@ -1733,13 +1726,11 @@ def image_to_dataURL(image: ImageType) -> str: return image_base64 -""" -Converts a base64 image dataURL into bytes. -The dataURL is split on the first commma. -""" - - def dataURL_to_bytes(dataURL: str) -> bytes: + """ + Converts a base64 image dataURL into bytes. + The dataURL is split on the first comma. + """ return base64.decodebytes( bytes( dataURL.split(",", 1)[1], @@ -1748,11 +1739,6 @@ def dataURL_to_bytes(dataURL: str) -> bytes: ) -""" -Pastes an image onto another with a bounding box. -""" - - def paste_image_into_bounding_box( recipient_image: ImageType, donor_image: ImageType, @@ -1761,23 +1747,24 @@ def paste_image_into_bounding_box( width: int, height: int, ) -> ImageType: + """ + Pastes an image onto another with a bounding box. + """ with recipient_image as im: bounds = (x, y, x + width, y + height) im.paste(donor_image, bounds) return recipient_image -""" -Saves a thumbnail of an image, returning its path. -""" - - def save_thumbnail( image: ImageType, filename: str, path: str, size: int = 256, ) -> str: + """ + Saves a thumbnail of an image, returning its path. + """ base_filename = os.path.splitext(filename)[0] thumbnail_path = os.path.join(path, base_filename + ".webp") From 3aab5e7e20132d739f44fc6c73899e440d95a6f0 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 26 Feb 2023 06:38:04 +0100 Subject: [PATCH 52/57] update .editorconfig - set `max_line_length = 88` for .py --- .editorconfig | 1 + 1 file changed, 1 insertion(+) diff --git a/.editorconfig b/.editorconfig index fe9b4a61d1..28e2100bab 100644 --- a/.editorconfig +++ b/.editorconfig @@ -13,6 +13,7 @@ trim_trailing_whitespace = true # Python [*.py] indent_size = 4 +max_line_length = 88 # css [*.css] From 47c1be332237666ee4da25d29c6e77d166a79546 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 26 Feb 2023 21:53:38 +0100 Subject: [PATCH 53/57] Revert "doc(invoke_ai_web_server): put docstrings inside their functions" This reverts commit 1e7a6dc676c43c713a977a7d47b0d4748c61ea45. 
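The behavior at stake in this revert is a language-level detail: a string only becomes a function's documentation when it is the first statement inside the function body. A minimal sketch, with illustrative function names only:

```py
# A string literal placed above a def is just an expression statement (at the
# very top of a module it becomes the module docstring); it never attaches to
# the function. Only the first statement inside the body becomes __doc__.
"""
Converts an image into a base64 image dataURL.
"""


def module_string_form(image):
    ...


def docstring_form(image):
    """Converts an image into a base64 image dataURL."""
    ...


assert module_string_form.__doc__ is None
assert docstring_form.__doc__ == "Converts an image into a base64 image dataURL."
```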
--- invokeai/backend/invoke_ai_web_server.py | 63 ++++++++++++++---------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/invokeai/backend/invoke_ai_web_server.py b/invokeai/backend/invoke_ai_web_server.py index 90e228b92b..8ee93c68f4 100644 --- a/invokeai/backend/invoke_ai_web_server.py +++ b/invokeai/backend/invoke_ai_web_server.py @@ -7,15 +7,13 @@ import mimetypes import os import shutil import traceback -from pathlib import Path from threading import Event from uuid import uuid4 import eventlet -import invokeai.frontend.dist as frontend +from pathlib import Path from PIL import Image from PIL.Image import Image as ImageType -from compel.prompt_parser import Blend from flask import Flask, redirect, send_from_directory, request, make_response from flask_socketio import SocketIO from werkzeug.utils import secure_filename @@ -24,15 +22,18 @@ from invokeai.backend.modules.get_canvas_generation_mode import ( get_canvas_generation_mode, ) from invokeai.backend.modules.parameters import parameters_to_command +import invokeai.frontend.dist as frontend from ldm.generate import Generate from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash -from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, get_tokenizer +from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, split_weighted_subprompts, \ + get_tokenizer from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState from ldm.invoke.generator.inpaint import infill_methods from ldm.invoke.globals import Globals, global_converted_ckpts_dir +from ldm.invoke.pngwriter import PngWriter, retrieve_metadata +from compel.prompt_parser import Blend from ldm.invoke.globals import global_models_dir from ldm.invoke.merge_diffusers import merge_diffusion_models -from ldm.invoke.pngwriter import PngWriter, retrieve_metadata # Loading Arguments opt = Args() @@ -1684,23 +1685,27 @@ class CanceledException(Exception): pass +""" +Returns a copy an image, cropped to a bounding box. +""" + + def copy_image_from_bounding_box( image: ImageType, x: int, y: int, width: int, height: int ) -> ImageType: - """ - Returns a copy an image, cropped to a bounding box. - """ with image as im: bounds = (x, y, x + width, y + height) im_cropped = im.crop(bounds) return im_cropped +""" +Converts a base64 image dataURL into an image. +The dataURL is split on the first commma. +""" + + def dataURL_to_image(dataURL: str) -> ImageType: - """ - Converts a base64 image dataURL into an image. - The dataURL is split on the first comma. - """ image = Image.open( io.BytesIO( base64.decodebytes( @@ -1714,10 +1719,12 @@ def dataURL_to_image(dataURL: str) -> ImageType: return image +""" +Converts an image into a base64 image dataURL. +""" + + def image_to_dataURL(image: ImageType) -> str: - """ - Converts an image into a base64 image dataURL. - """ buffered = io.BytesIO() image.save(buffered, format="PNG") image_base64 = "data:image/png;base64," + base64.b64encode( @@ -1726,11 +1733,13 @@ def image_to_dataURL(image: ImageType) -> str: return image_base64 +""" +Converts a base64 image dataURL into bytes. +The dataURL is split on the first commma. +""" + + def dataURL_to_bytes(dataURL: str) -> bytes: - """ - Converts a base64 image dataURL into bytes. - The dataURL is split on the first comma. 
-    """
     return base64.decodebytes(
         bytes(
             dataURL.split(",", 1)[1],
             "utf-8",
         )
     )
 
 
+"""
+Pastes an image onto another with a bounding box.
+"""
+
+
 def paste_image_into_bounding_box(
     recipient_image: ImageType,
     donor_image: ImageType,
     x: int,
     y: int,
     width: int,
     height: int,
 ) -> ImageType:
-    """
-    Pastes an image onto another with a bounding box.
-    """
     with recipient_image as im:
         bounds = (x, y, x + width, y + height)
         im.paste(donor_image, bounds)
 
     return recipient_image
 
 
+"""
+Saves a thumbnail of an image, returning its path.
+"""
+
+
 def save_thumbnail(
     image: ImageType,
     filename: str,
     path: str,
     size: int = 256,
 ) -> str:
-    """
-    Saves a thumbnail of an image, returning its path.
-    """
     base_filename = os.path.splitext(filename)[0]
     thumbnail_path = os.path.join(path, base_filename + ".webp")

From 2394f6458fc49a72cfd3c78e9048b3838a70a831 Mon Sep 17 00:00:00 2001
From: mauwii
Date: Sun, 26 Feb 2023 21:54:06 +0100
Subject: [PATCH 54/57] Revert "[nodes] Removed InvokerServices, simplifying service model"

This reverts commit 81fd2ee8c19f69d9ce0693810bfd52ec53a3f4ec.

---
 ldm/invoke/app/api/dependencies.py | 19 ++++----
 ldm/invoke/app/api/routers/sessions.py | 28 +++++++-------
 ldm/invoke/app/cli_app.py | 25 +++++++------
 .../app/services/invocation_services.py | 15 +-------
 ldm/invoke/app/services/invoker.py | 37 ++++++++++++-----
 ldm/invoke/app/services/processor.py | 6 +--
 tests/nodes/test_graph_execution_state.py | 14 ++-----
 tests/nodes/test_invoker.py | 26 +++++++------
 8 files changed, 89 insertions(+), 81 deletions(-)

diff --git a/ldm/invoke/app/api/dependencies.py b/ldm/invoke/app/api/dependencies.py
index 08f362133e..60dd522803 100644
--- a/ldm/invoke/app/api/dependencies.py
+++ b/ldm/invoke/app/api/dependencies.py
@@ -13,7 +13,7 @@ from ...globals import Globals
 from ..services.image_storage import DiskImageStorage
 from ..services.invocation_queue import MemoryInvocationQueue
 from ..services.invocation_services import InvocationServices
-from ..services.invoker import Invoker
+from ..services.invoker import Invoker, InvokerServices
 from ..services.generate_initializer import get_generate
 from .events import FastAPIEventService
 
@@ -60,19 +60,22 @@ class ApiDependencies:
 
         images = DiskImageStorage(output_folder)
 
-        # TODO: build a file/path manager?
-        db_location = os.path.join(output_folder, 'invokeai.db')
-
         services = InvocationServices(
             generate = generate,
             events = events,
-            images = images,
-            queue = MemoryInvocationQueue(),
+            images = images
+        )
+
+        # TODO: build a file/path manager?
+ db_location = os.path.join(output_folder, 'invokeai.db') + + invoker_services = InvokerServices( + queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor() ) - ApiDependencies.invoker = Invoker(services) + ApiDependencies.invoker = Invoker(services, invoker_services) @staticmethod def shutdown(): diff --git a/ldm/invoke/app/api/routers/sessions.py b/ldm/invoke/app/api/routers/sessions.py index beb13736c6..77008ad6e4 100644 --- a/ldm/invoke/app/api/routers/sessions.py +++ b/ldm/invoke/app/api/routers/sessions.py @@ -44,9 +44,9 @@ async def list_sessions( ) -> PaginatedResults[GraphExecutionState]: """Gets a list of sessions, optionally searching""" if filter == '': - result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) + result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page) else: - result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) + result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page) return result @@ -60,7 +60,7 @@ async def get_session( session_id: str = Path(description = "The id of the session to get") ) -> GraphExecutionState: """Gets a session""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) else: @@ -80,13 +80,13 @@ async def add_node( node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add") ) -> str: """Adds a node to the graph""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.add_node(node) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session.id except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -108,13 +108,13 @@ async def update_node( node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node") ) -> GraphExecutionState: """Updates a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.update_node(node_path, node) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? 
return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -135,13 +135,13 @@ async def delete_node( node_path: str = Path(description = "The path to the node to delete") ) -> GraphExecutionState: """Deletes a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.delete_node(node_path) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -162,13 +162,13 @@ async def add_edge( edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add") ) -> GraphExecutionState: """Adds an edge to the graph""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.add_edge(edge) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -193,14 +193,14 @@ async def delete_edge( to_field: str = Path(description = "The field of the node the edge is going to") ) -> GraphExecutionState: """Deletes an edge from the graph""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field)) session.delete_edge(edge) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? 
return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -221,7 +221,7 @@ async def invoke_session( all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations") ) -> None: """Invokes a session""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) diff --git a/ldm/invoke/app/cli_app.py b/ldm/invoke/app/cli_app.py index 9081f3b083..6071afabb2 100644 --- a/ldm/invoke/app/cli_app.py +++ b/ldm/invoke/app/cli_app.py @@ -20,7 +20,7 @@ from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue from .invocations.baseinvocation import BaseInvocation from .services.invocation_services import InvocationServices -from .services.invoker import Invoker +from .services.invoker import Invoker, InvokerServices from .invocations import * from ..args import Args from .services.events import EventServiceBase @@ -171,25 +171,28 @@ def invoke_cli(): output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs')) + services = InvocationServices( + generate = generate, + events = events, + images = DiskImageStorage(output_folder) + ) + # TODO: build a file/path manager? db_location = os.path.join(output_folder, 'invokeai.db') - services = InvocationServices( - generate = generate, - events = events, - images = DiskImageStorage(output_folder), - queue = MemoryInvocationQueue(), + invoker_services = InvokerServices( + queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor() ) - invoker = Invoker(services) + invoker = Invoker(services, invoker_services) session = invoker.create_execution_state() parser = get_invocation_parser() # Uncomment to print out previous sessions at startup - # print(services.session_manager.list()) + # print(invoker_services.session_manager.list()) # Defaults storage defaults: Dict[str, Any] = dict() @@ -210,7 +213,7 @@ def invoke_cli(): try: # Refresh the state of the session - session = invoker.services.graph_execution_manager.get(session.id) + session = invoker.invoker_services.graph_execution_manager.get(session.id) history = list(get_graph_execution_history(session)) # Split the command for piping @@ -286,7 +289,7 @@ def invoke_cli(): invoker.invoke(session, invoke_all = True) while not session.is_complete(): # Wait some time - session = invoker.services.graph_execution_manager.get(session.id) + session = invoker.invoker_services.graph_execution_manager.get(session.id) time.sleep(0.1) except InvalidArgs: diff --git a/ldm/invoke/app/services/invocation_services.py b/ldm/invoke/app/services/invocation_services.py index 40a64e64e5..9eb5309d3d 100644 --- a/ldm/invoke/app/services/invocation_services.py +++ b/ldm/invoke/app/services/invocation_services.py @@ -1,6 +1,4 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from .invocation_queue import InvocationQueueABC -from .item_storage import ItemStorageABC from .image_storage import ImageStorageBase from .events import EventServiceBase from ....generate import Generate @@ -11,23 +9,12 @@ class InvocationServices(): generate: Generate # TODO: wrap Generate, or split it up from model? 
events: EventServiceBase images: ImageStorageBase - queue: InvocationQueueABC - - # NOTE: we must forward-declare any types that include invocations, since invocations can use services - graph_execution_manager: ItemStorageABC['GraphExecutionState'] - processor: 'InvocationProcessorABC' def __init__(self, generate: Generate, events: EventServiceBase, - images: ImageStorageBase, - queue: InvocationQueueABC, - graph_execution_manager: ItemStorageABC['GraphExecutionState'], - processor: 'InvocationProcessorABC' + images: ImageStorageBase ): self.generate = generate self.events = events self.images = images - self.queue = queue - self.graph_execution_manager = graph_execution_manager - self.processor = processor diff --git a/ldm/invoke/app/services/invoker.py b/ldm/invoke/app/services/invoker.py index 4397a75021..796f541781 100644 --- a/ldm/invoke/app/services/invoker.py +++ b/ldm/invoke/app/services/invoker.py @@ -9,15 +9,34 @@ from .invocation_services import InvocationServices from .invocation_queue import InvocationQueueABC, InvocationQueueItem +class InvokerServices: + """Services used by the Invoker for execution""" + + queue: InvocationQueueABC + graph_execution_manager: ItemStorageABC[GraphExecutionState] + processor: 'InvocationProcessorABC' + + def __init__(self, + queue: InvocationQueueABC, + graph_execution_manager: ItemStorageABC[GraphExecutionState], + processor: 'InvocationProcessorABC'): + self.queue = queue + self.graph_execution_manager = graph_execution_manager + self.processor = processor + + class Invoker: """The invoker, used to execute invocations""" services: InvocationServices + invoker_services: InvokerServices def __init__(self, - services: InvocationServices + services: InvocationServices, # Services used by nodes to perform invocations + invoker_services: InvokerServices # Services used by the invoker for orchestration ): self.services = services + self.invoker_services = invoker_services self._start() @@ -30,11 +49,11 @@ class Invoker: return None # Save the execution state - self.services.graph_execution_manager.set(graph_execution_state) + self.invoker_services.graph_execution_manager.set(graph_execution_state) # Queue the invocation print(f'queueing item {invocation.id}') - self.services.queue.put(InvocationQueueItem( + self.invoker_services.queue.put(InvocationQueueItem( #session_id = session.id, graph_execution_state_id = graph_execution_state.id, invocation_id = invocation.id, @@ -47,7 +66,7 @@ class Invoker: def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState: """Creates a new execution state for the given graph""" new_state = GraphExecutionState(graph = Graph() if graph is None else graph) - self.services.graph_execution_manager.set(new_state) + self.invoker_services.graph_execution_manager.set(new_state) return new_state @@ -67,8 +86,8 @@ class Invoker: def _start(self) -> None: """Starts the invoker. 
This is called automatically when the invoker is created.""" - for service in vars(self.services): - self.__start_service(getattr(self.services, service)) + for service in vars(self.invoker_services): + self.__start_service(getattr(self.invoker_services, service)) for service in vars(self.services): self.__start_service(getattr(self.services, service)) @@ -80,10 +99,10 @@ class Invoker: for service in vars(self.services): self.__stop_service(getattr(self.services, service)) - for service in vars(self.services): - self.__stop_service(getattr(self.services, service)) + for service in vars(self.invoker_services): + self.__stop_service(getattr(self.invoker_services, service)) - self.services.queue.put(None) + self.invoker_services.queue.put(None) class InvocationProcessorABC(ABC): diff --git a/ldm/invoke/app/services/processor.py b/ldm/invoke/app/services/processor.py index 9ea4349bbf..9b51a6bcbc 100644 --- a/ldm/invoke/app/services/processor.py +++ b/ldm/invoke/app/services/processor.py @@ -28,11 +28,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC): def __process(self, stop_event: Event): try: while not stop_event.is_set(): - queue_item: InvocationQueueItem = self.__invoker.services.queue.get() + queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get() if not queue_item: # Probably stopping continue - graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id) + graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id) invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) # Send starting event @@ -52,7 +52,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): graph_execution_state.complete(invocation.id, outputs) # Save the state changes - self.__invoker.services.graph_execution_manager.set(graph_execution_state) + self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state) # Send complete event self.__invoker.services.events.emit_invocation_complete( diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 980c262501..0a5dcc7734 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -1,11 +1,10 @@ from .test_invoker import create_edge from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from ldm.invoke.app.services.processor import DefaultInvocationProcessor -from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory -from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation import pytest @@ -20,14 +19,7 @@ def simple_graph(): @pytest.fixture def mock_services(): # NOTE: none of these are actually called by the test invocations - return InvocationServices( - generate = None, - events = None, - images = None, - 
queue = MemoryInvocationQueue(), - graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() - ) + return InvocationServices(generate = None, events = None, images = None) def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: n = g.next() diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index e9109728d5..a6d96f61c0 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -2,10 +2,12 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe from ldm.invoke.app.services.processor import DefaultInvocationProcessor from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue -from ldm.invoke.app.services.invoker import Invoker +from ldm.invoke.app.services.invoker import Invoker, InvokerServices from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation import pytest @@ -20,19 +22,21 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations - return InvocationServices( - generate = None, - events = TestEventService(), - images = None, + return InvocationServices(generate = None, events = TestEventService(), images = None) + +@pytest.fixture() +def mock_invoker_services() -> InvokerServices: + return InvokerServices( queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor() ) @pytest.fixture() -def mock_invoker(mock_services: InvocationServices) -> Invoker: +def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker: return Invoker( - services = mock_services + services = mock_services, + invoker_services = mock_invoker_services ) def test_can_create_graph_state(mock_invoker: Invoker): @@ -56,13 +60,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph): assert invocation_id is not None def has_executed_any(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) return len(g.executed) > 0 wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) mock_invoker.stop() - g = mock_invoker.services.graph_execution_manager.get(g.id) + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) assert len(g.executed) > 0 def test_can_invoke_all(mock_invoker: Invoker, simple_graph): @@ -71,11 +75,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph): assert invocation_id is not None def has_executed_all(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) return 
g.is_complete() wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) mock_invoker.stop() - g = mock_invoker.services.graph_execution_manager.get(g.id) + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) assert g.is_complete() From 282ba201d2eaba7e37a162ab2ae07db8676af79b Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 26 Feb 2023 21:54:13 +0100 Subject: [PATCH 55/57] Revert "parent 9eed1919c2071f9199996df747c8638c4a75e8fb" This reverts commit 357601e2d673786d42e920bb24d4a1cf55c66540. --- .coveragerc | 6 - .gitignore | 1 - .pytest.ini | 5 - docs/contributing/ARCHITECTURE.md | 93 -- docs/contributing/INVOCATIONS.md | 105 --- ldm/generate.py | 6 - ldm/invoke/app/api/dependencies.py | 83 -- ldm/invoke/app/api/events.py | 54 -- ldm/invoke/app/api/routers/images.py | 57 -- ldm/invoke/app/api/routers/sessions.py | 232 ----- ldm/invoke/app/api/sockets.py | 36 - ldm/invoke/app/api_app.py | 164 ---- ldm/invoke/app/cli_app.py | 306 ------- ldm/invoke/app/invocations/__init__.py | 8 - ldm/invoke/app/invocations/baseinvocation.py | 74 -- ldm/invoke/app/invocations/cv.py | 42 - ldm/invoke/app/invocations/generate.py | 160 ---- ldm/invoke/app/invocations/image.py | 219 ----- ldm/invoke/app/invocations/prompt.py | 9 - ldm/invoke/app/invocations/reconstruct.py | 36 - ldm/invoke/app/invocations/upscale.py | 38 - ldm/invoke/app/services/__init__.py | 0 ldm/invoke/app/services/events.py | 78 -- .../app/services/generate_initializer.py | 233 ----- ldm/invoke/app/services/graph.py | 797 ------------------ ldm/invoke/app/services/image_storage.py | 104 --- ldm/invoke/app/services/invocation_queue.py | 46 - .../app/services/invocation_services.py | 20 - ldm/invoke/app/services/invoker.py | 109 --- ldm/invoke/app/services/item_storage.py | 57 -- ldm/invoke/app/services/processor.py | 78 -- ldm/invoke/app/services/sqlite.py | 119 --- pyproject.toml | 5 - scripts/invoke-new.py | 20 - static/dream_web/test.html | 206 ----- tests/__init__.py | 0 tests/nodes/__init__.py | 0 tests/nodes/test_graph_execution_state.py | 114 --- tests/nodes/test_invoker.py | 85 -- tests/nodes/test_node_graph.py | 501 ----------- tests/nodes/test_nodes.py | 92 -- tests/nodes/test_sqlite.py | 112 --- 42 files changed, 4510 deletions(-) delete mode 100644 .coveragerc delete mode 100644 .pytest.ini delete mode 100644 docs/contributing/ARCHITECTURE.md delete mode 100644 docs/contributing/INVOCATIONS.md delete mode 100644 ldm/invoke/app/api/dependencies.py delete mode 100644 ldm/invoke/app/api/events.py delete mode 100644 ldm/invoke/app/api/routers/images.py delete mode 100644 ldm/invoke/app/api/routers/sessions.py delete mode 100644 ldm/invoke/app/api/sockets.py delete mode 100644 ldm/invoke/app/api_app.py delete mode 100644 ldm/invoke/app/cli_app.py delete mode 100644 ldm/invoke/app/invocations/__init__.py delete mode 100644 ldm/invoke/app/invocations/baseinvocation.py delete mode 100644 ldm/invoke/app/invocations/cv.py delete mode 100644 ldm/invoke/app/invocations/generate.py delete mode 100644 ldm/invoke/app/invocations/image.py delete mode 100644 ldm/invoke/app/invocations/prompt.py delete mode 100644 ldm/invoke/app/invocations/reconstruct.py delete mode 100644 ldm/invoke/app/invocations/upscale.py delete mode 100644 ldm/invoke/app/services/__init__.py delete mode 100644 ldm/invoke/app/services/events.py delete mode 100644 ldm/invoke/app/services/generate_initializer.py delete mode 100644 ldm/invoke/app/services/graph.py delete mode 100644 ldm/invoke/app/services/image_storage.py delete mode 100644 
ldm/invoke/app/services/invocation_queue.py delete mode 100644 ldm/invoke/app/services/invocation_services.py delete mode 100644 ldm/invoke/app/services/invoker.py delete mode 100644 ldm/invoke/app/services/item_storage.py delete mode 100644 ldm/invoke/app/services/processor.py delete mode 100644 ldm/invoke/app/services/sqlite.py delete mode 100644 scripts/invoke-new.py delete mode 100644 static/dream_web/test.html delete mode 100644 tests/__init__.py delete mode 100644 tests/nodes/__init__.py delete mode 100644 tests/nodes/test_graph_execution_state.py delete mode 100644 tests/nodes/test_invoker.py delete mode 100644 tests/nodes/test_node_graph.py delete mode 100644 tests/nodes/test_nodes.py delete mode 100644 tests/nodes/test_sqlite.py diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 8232fc4b93..0000000000 --- a/.coveragerc +++ /dev/null @@ -1,6 +0,0 @@ -[run] -omit='.env/*' -source='.' - -[report] -show_missing = true diff --git a/.gitignore b/.gitignore index 9b33e07164..9adb0be85a 100644 --- a/.gitignore +++ b/.gitignore @@ -68,7 +68,6 @@ htmlcov/ .cache nosetests.xml coverage.xml -cov.xml *.cover *.py,cover .hypothesis/ diff --git a/.pytest.ini b/.pytest.ini deleted file mode 100644 index 16ccfafe80..0000000000 --- a/.pytest.ini +++ /dev/null @@ -1,5 +0,0 @@ -[pytest] -DJANGO_SETTINGS_MODULE = webtas.settings -; python_files = tests.py test_*.py *_tests.py - -addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml diff --git a/docs/contributing/ARCHITECTURE.md b/docs/contributing/ARCHITECTURE.md deleted file mode 100644 index d74df94492..0000000000 --- a/docs/contributing/ARCHITECTURE.md +++ /dev/null @@ -1,93 +0,0 @@ -# Invoke.AI Architecture - -```mermaid -flowchart TB - - subgraph apps[Applications] - webui[WebUI] - cli[CLI] - - subgraph webapi[Web API] - api[HTTP API] - sio[Socket.IO] - end - - end - - subgraph invoke[Invoke] - direction LR - invoker - services - sessions - invocations - end - - subgraph core[AI Core] - Generate - end - - webui --> webapi - webapi --> invoke - cli --> invoke - - invoker --> services & sessions - invocations --> services - sessions --> invocations - - services --> core - - %% Styles - classDef sg fill:#5028C8,font-weight:bold,stroke-width:2,color:#fff,stroke:#14141A - classDef default stroke-width:2px,stroke:#F6B314,color:#fff,fill:#14141A - - class apps,webapi,invoke,core sg - -``` - -## Applications - -Applications are built on top of the invoke framework. They should construct `invoker` and then interact through it. They should avoid interacting directly with core code in order to support a variety of configurations. - -### Web UI - -The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.tiangolo.com/) and [Socket.IO](https://socket.io/). The frontend code is found in `/frontend` and the backend code is found in `/ldm/invoke/app/api_app.py` and `/ldm/invoke/app/api/`. 
The code is further organized as such:
-
-| Component | Description |
-| --- | --- |
-| api_app.py | Sets up the API app, annotates the OpenAPI spec with additional data, and runs the API |
-| dependencies | Creates all invoker services and the invoker, and provides them to the API |
-| events | An eventing system that could in the future be adapted to support horizontal scale-out |
-| sockets | The Socket.IO interface - handles listening to and emitting session events (events are defined in the events service module) |
-| routers | API definitions for different areas of API functionality |
-
-### CLI
-
-The CLI is built automatically from invocation metadata, and also supports invocation piping and auto-linking. Code is available in `/ldm/invoke/app/cli_app.py`.
-
-## Invoke
-
-The Invoke framework provides the interface to the underlying AI systems and is built with flexibility and extensibility in mind. There are four major concepts: invoker, sessions, invocations, and services.
-
-### Invoker
-
-The invoker (`/ldm/invoke/app/services/invoker.py`) is the primary interface through which applications interact with the framework. Its primary purpose is to create, manage, and invoke sessions. It also maintains two sets of services:
-- **invocation services**, which are used by invocations to interact with core functionality.
-- **invoker services**, which are used by the invoker to manage sessions and the invocation queue.
-
-### Sessions
-
-Invocations and links between them form a graph, which is maintained in a session. Sessions can be queued for invocation, which will execute their graph (either the next ready invocation, or all invocations). Sessions also maintain execution history for the graph (including storage of any outputs). An invocation may be added to a session at any time, and there is capability to add an entire graph at once, as well as to automatically link new invocations to previous invocations. Invocations cannot be deleted or modified once added.
-
-The session graph does not support looping. This is left as an application problem to prevent additional complexity in the graph.
-
-### Invocations
-
-Invocations represent individual units of execution, with inputs and outputs. All invocations are located in `/ldm/invoke/app/invocations`, and are automatically discovered and made available in the applications. These are the primary way to expose new functionality in Invoke.AI, and the [implementation guide](INVOCATIONS.md) explains how to add new invocations.
-
-### Services
-
-Services provide invocations access to AI Core functionality and other necessary functionality (e.g. image storage). These are available in `/ldm/invoke/app/services`. As a general rule, new services should provide an interface as an abstract base class, and may provide a lightweight local implementation by default in their module. The goal for all services should be to enable the usage of different implementations (e.g. using cloud storage for image storage), but should not load any module dependencies unless that implementation is actually used (i.e. don't import anything that won't be used, especially if it's expensive to import).
-
-## AI Core
-
-The AI Core is represented by the rest of the code base (i.e. the code outside of `/ldm/invoke/app/`).
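For orientation before the next file: the wiring described above is exactly what `dependencies.py` and `cli_app.py` (both deleted below) perform. A minimal sketch of that wiring, assuming the absolute import paths implied by the reverted layout; `generate`, `events`, `images`, and `db_location` stand in for the application-provided objects (see `get_generate`, `EventServiceBase`, and `DiskImageStorage` later in this patch):

```py
# Minimal application wiring for the framework described above, mirroring
# the pattern in cli_app.py/dependencies.py. Import paths are assumed from
# the reverted file layout; generate/events/images/db_location are
# application-provided objects.
from ldm.invoke.app.services.graph import GraphExecutionState
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.invoker import Invoker, InvokerServices
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage

services = InvocationServices(generate=generate, events=events, images=images)
invoker_services = InvokerServices(
    queue=MemoryInvocationQueue(),
    graph_execution_manager=SqliteItemStorage[GraphExecutionState](
        filename=db_location, table_name='graph_executions'
    ),
    processor=DefaultInvocationProcessor(),
)

invoker = Invoker(services, invoker_services)
session = invoker.create_execution_state()  # session holds the graph + history
invoker.invoke(session, invoke_all=True)    # execute all ready invocations
```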
diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md
deleted file mode 100644
index c8a97c19e4..0000000000
--- a/docs/contributing/INVOCATIONS.md
+++ /dev/null
@@ -1,105 +0,0 @@
-# Invocations
-
-Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.
-
-## Creating a new invocation
-
-To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.
-
-An invocation looks like this:
-
-```py
-class UpscaleInvocation(BaseInvocation):
-    """Upscales an image."""
-    type: Literal['upscale'] = 'upscale'
-
-    # Inputs
-    image: Union[ImageField,None] = Field(description="The input image")
-    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
-    level: Literal[2,4] = Field(default=2, description="The upscale level")
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get(self.image.image_type, self.image.image_name)
-        results = context.services.generate.upscale_and_reconstruct(
-            image_list = [[image, 0]],
-            upscale = (self.level, self.strength),
-            strength = 0.0, # GFPGAN strength
-            save_original = False,
-            image_callback = None,
-        )
-
-        # Results are image and seed, unwrap for now
-        # TODO: can this return multiple results?
-        image_type = ImageType.RESULT
-        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
-        context.services.images.save(image_type, image_name, results[0][0])
-        return ImageOutput(
-            image = ImageField(image_type = image_type, image_name = image_name)
-        )
-```
-
-Each portion is important to implement correctly.
-
-### Class definition and type
-
-```py
-class UpscaleInvocation(BaseInvocation):
-    """Upscales an image."""
-    type: Literal['upscale'] = 'upscale'
-```
-
-All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint.
-
-### Inputs
-
-```py
-    # Inputs
-    image: Union[ImageField,None] = Field(description="The input image")
-    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
-    level: Literal[2,4] = Field(default=2, description="The upscale level")
-```
-
-Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example:
-
-| Part | Value | Description |
-| ---- | ----- | ----------- |
-| Name | `strength` | This field is referred to as `strength` |
-| Type Hint | `float` | This field must be of type `float` |
-| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
-
-Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation.
-
-The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links).
-
-Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a matching field of a previous invocation.
-
-### Invoke Function
-
-```py
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get(self.image.image_type, self.image.image_name)
-        results = context.services.generate.upscale_and_reconstruct(
-            image_list = [[image, 0]],
-            upscale = (self.level, self.strength),
-            strength = 0.0, # GFPGAN strength
-            save_original = False,
-            image_callback = None,
-        )
-
-        # Results are image and seed, unwrap for now
-        image_type = ImageType.RESULT
-        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
-        context.services.images.save(image_type, image_name, results[0][0])
-        return ImageOutput(
-            image = ImageField(image_type = image_type, image_name = image_name)
-        )
-```
-
-The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`.
-
-Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order).
-
-Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden.
-
-### Outputs
-
-```py
-class ImageOutput(BaseInvocationOutput):
-    """Base class for invocations that output an image"""
-    type: Literal['image'] = 'image'
-
-    image: ImageField = Field(default=None, description="The output image")
-```
-
-Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking.
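To make the pattern documented above concrete, here is a hypothetical minimal invocation (not part of this patch); it reuses only the types and service calls that appear verbatim in the deleted code (`ImageField`, `ImageOutput`, `ImageType`, and the `images` service):

```py
# Hypothetical 'invert' invocation following the documented pattern.
# InvertInvocation is illustrative only and does not exist in the codebase.
from typing import Literal, Union

from PIL import ImageOps
from pydantic import Field


class InvertInvocation(BaseInvocation):
    """Inverts an image."""
    type: Literal['invert'] = 'invert'

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        inverted = ImageOps.invert(image.convert('RGB'))

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, inverted)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
```

Dropping a module with a class like this into `/ldm/invoke/app/invocations` would be enough for the discovery mechanism to expose it as an `invert` command in the CLI and API.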
diff --git a/ldm/generate.py b/ldm/generate.py index 256f214b25..413a1e25cb 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -1030,8 +1030,6 @@ class Generate: image_callback=None, prefix=None, ): - - results = [] for r in image_list: image, seed = r try: @@ -1085,10 +1083,6 @@ class Generate: else: r[0] = image - results.append([image, seed]) - - return results - def apply_textmask( self, image_path: str, prompt: str, callback, threshold: float = 0.5 ): diff --git a/ldm/invoke/app/api/dependencies.py b/ldm/invoke/app/api/dependencies.py deleted file mode 100644 index 60dd522803..0000000000 --- a/ldm/invoke/app/api/dependencies.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from argparse import Namespace -import os - -from ..services.processor import DefaultInvocationProcessor - -from ..services.graph import GraphExecutionState -from ..services.sqlite import SqliteItemStorage - -from ...globals import Globals - -from ..services.image_storage import DiskImageStorage -from ..services.invocation_queue import MemoryInvocationQueue -from ..services.invocation_services import InvocationServices -from ..services.invoker import Invoker, InvokerServices -from ..services.generate_initializer import get_generate -from .events import FastAPIEventService - - -# TODO: is there a better way to achieve this? -def check_internet()->bool: - ''' - Return true if the internet is reachable. - It does this by pinging huggingface.co. - ''' - import urllib.request - host = 'http://huggingface.co' - try: - urllib.request.urlopen(host,timeout=1) - return True - except: - return False - - -class ApiDependencies: - """Contains and initializes all dependencies for the API""" - invoker: Invoker = None - - @staticmethod - def initialize( - args, - config, - event_handler_id: int - ): - Globals.try_patchmatch = args.patchmatch - Globals.always_use_cpu = args.always_use_cpu - Globals.internet_available = args.internet_available and check_internet() - Globals.disable_xformers = not args.xformers - Globals.ckpt_convert = args.ckpt_convert - - # TODO: Use a logger - print(f'>> Internet connectivity is {Globals.internet_available}') - - generate = get_generate(args, config) - - events = FastAPIEventService(event_handler_id) - - output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs')) - - images = DiskImageStorage(output_folder) - - services = InvocationServices( - generate = generate, - events = events, - images = images - ) - - # TODO: build a file/path manager? 
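The TODO that closes the hunk above suggests one concrete refactor. A hypothetical sketch of such a path manager (no class like this exists in the patch) that would centralize the inline `os.path.join` calls used for the outputs folder and the sqlite database below:

```py
# Hypothetical file/path manager answering the TODO above; illustrative only.
import os


class OutputPaths:
    """Resolves well-known paths under the outputs folder."""

    def __init__(self, output_folder: str):
        self._root = os.path.abspath(output_folder)

    @property
    def database(self) -> str:
        # e.g. <outputs>/invokeai.db, as constructed inline below
        return os.path.join(self._root, 'invokeai.db')

    def image_dir(self, image_type: str) -> str:
        return os.path.join(self._root, image_type)
```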
- db_location = os.path.join(output_folder, 'invokeai.db') - - invoker_services = InvokerServices( - queue = MemoryInvocationQueue(), - graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() - ) - - ApiDependencies.invoker = Invoker(services, invoker_services) - - @staticmethod - def shutdown(): - if ApiDependencies.invoker: - ApiDependencies.invoker.stop() diff --git a/ldm/invoke/app/api/events.py b/ldm/invoke/app/api/events.py deleted file mode 100644 index 701b48a316..0000000000 --- a/ldm/invoke/app/api/events.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import asyncio -from queue import Empty, Queue -from typing import Any -from fastapi_events.dispatcher import dispatch -from ..services.events import EventServiceBase -import threading - -class FastAPIEventService(EventServiceBase): - event_handler_id: int - __queue: Queue - __stop_event: threading.Event - - def __init__(self, event_handler_id: int) -> None: - self.event_handler_id = event_handler_id - self.__queue = Queue() - self.__stop_event = threading.Event() - asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event)) - - super().__init__() - - - def stop(self, *args, **kwargs): - self.__stop_event.set() - self.__queue.put(None) - - - def dispatch(self, event_name: str, payload: Any) -> None: - self.__queue.put(dict( - event_name = event_name, - payload = payload - )) - - - async def __dispatch_from_queue(self, stop_event: threading.Event): - """Get events on from the queue and dispatch them, from the correct thread""" - while not stop_event.is_set(): - try: - event = self.__queue.get(block = False) - if not event: # Probably stopping - continue - - dispatch( - event.get('event_name'), - payload = event.get('payload'), - middleware_id = self.event_handler_id) - - except Empty: - await asyncio.sleep(0.001) - pass - - except asyncio.CancelledError as e: - raise e # Raise a proper error diff --git a/ldm/invoke/app/api/routers/images.py b/ldm/invoke/app/api/routers/images.py deleted file mode 100644 index 1ae116e49d..0000000000 --- a/ldm/invoke/app/api/routers/images.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from datetime import datetime, timezone -from fastapi import Path, UploadFile, Request -from fastapi.routing import APIRouter -from fastapi.responses import FileResponse, Response -from PIL import Image -from ...services.image_storage import ImageType -from ..dependencies import ApiDependencies - -images_router = APIRouter( - prefix = '/v1/images', - tags = ['images'] -) - - -@images_router.get('/{image_type}/{image_name}', - operation_id = 'get_image' - ) -async def get_image( - image_type: ImageType = Path(description = "The type of image to get"), - image_name: str = Path(description = "The name of the image to get") -): - """Gets a result""" - # TODO: This is not really secure at all. 
At least make sure only output results are served - filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name) - return FileResponse(filename) - -@images_router.post('/uploads/', - operation_id = 'upload_image', - responses = { - 201: {'description': 'The image was uploaded successfully'}, - 404: {'description': 'Session not found'} - }) -async def upload_image( - file: UploadFile, - request: Request -): - if not file.content_type.startswith('image'): - return Response(status_code = 415) - - contents = await file.read() - try: - im = Image.open(contents) - except: - # Error opening the image - return Response(status_code = 415) - - filename = f'{str(int(datetime.now(timezone.utc).timestamp()))}.png' - ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im) - - return Response( - status_code=201, - headers = { - 'Location': request.url_for('get_image', image_type=ImageType.UPLOAD, image_name=filename) - } - ) diff --git a/ldm/invoke/app/api/routers/sessions.py b/ldm/invoke/app/api/routers/sessions.py deleted file mode 100644 index 77008ad6e4..0000000000 --- a/ldm/invoke/app/api/routers/sessions.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from typing import List, Optional, Union, Annotated -from fastapi import Query, Path, Body -from fastapi.routing import APIRouter -from fastapi.responses import Response -from pydantic.fields import Field - -from ...services.item_storage import PaginatedResults -from ..dependencies import ApiDependencies -from ...invocations.baseinvocation import BaseInvocation -from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError -from ...invocations import * - -session_router = APIRouter( - prefix = '/v1/sessions', - tags = ['sessions'] -) - - -@session_router.post('/', - operation_id = 'create_session', - responses = { - 200: {"model": GraphExecutionState}, - 400: {'description': 'Invalid json'} - }) -async def create_session( - graph: Optional[Graph] = Body(default = None, description = "The graph to initialize the session with") -) -> GraphExecutionState: - """Creates a new session, optionally initializing it with an invocation graph""" - session = ApiDependencies.invoker.create_execution_state(graph) - return session - - -@session_router.get('/', - operation_id = 'list_sessions', - responses = { - 200: {"model": PaginatedResults[GraphExecutionState]} - }) -async def list_sessions( - page: int = Query(default = 0, description = "The page of results to get"), - per_page: int = Query(default = 10, description = "The number of results per page"), - query: str = Query(default = '', description = "The query string to search for") -) -> PaginatedResults[GraphExecutionState]: - """Gets a list of sessions, optionally searching""" - if filter == '': - result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page) - else: - result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page) - return result - - -@session_router.get('/{session_id}', - operation_id = 'get_session', - responses = { - 200: {"model": GraphExecutionState}, - 404: {'description': 'Session not found'} - }) -async def get_session( - session_id: str = Path(description = "The id of the session to get") -) -> GraphExecutionState: - """Gets a session""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) - if session is None: - return 
Response(status_code = 404) - else: - return session - - -@session_router.post('/{session_id}/nodes', - operation_id = 'add_node', - responses = { - 200: {"model": str}, - 400: {'description': 'Invalid node or link'}, - 404: {'description': 'Session not found'} - } -) -async def add_node( - session_id: str = Path(description = "The id of the session"), - node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add") -) -> str: - """Adds a node to the graph""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) - if session is None: - return Response(status_code = 404) - - try: - session.add_node(node) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? - return session.id - except NodeAlreadyExecutedError: - return Response(status_code = 400) - except IndexError: - return Response(status_code = 400) - - -@session_router.put('/{session_id}/nodes/{node_path}', - operation_id = 'update_node', - responses = { - 200: {"model": GraphExecutionState}, - 400: {'description': 'Invalid node or link'}, - 404: {'description': 'Session not found'} - } -) -async def update_node( - session_id: str = Path(description = "The id of the session"), - node_path: str = Path(description = "The path to the node in the graph"), - node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node") -) -> GraphExecutionState: - """Updates a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) - if session is None: - return Response(status_code = 404) - - try: - session.update_node(node_path, node) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? - return session - except NodeAlreadyExecutedError: - return Response(status_code = 400) - except IndexError: - return Response(status_code = 400) - - -@session_router.delete('/{session_id}/nodes/{node_path}', - operation_id = 'delete_node', - responses = { - 200: {"model": GraphExecutionState}, - 400: {'description': 'Invalid node or link'}, - 404: {'description': 'Session not found'} - } -) -async def delete_node( - session_id: str = Path(description = "The id of the session"), - node_path: str = Path(description = "The path to the node to delete") -) -> GraphExecutionState: - """Deletes a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) - if session is None: - return Response(status_code = 404) - - try: - session.delete_node(node_path) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? 
- return session - except NodeAlreadyExecutedError: - return Response(status_code = 400) - except IndexError: - return Response(status_code = 400) - - -@session_router.post('/{session_id}/edges', - operation_id = 'add_edge', - responses = { - 200: {"model": GraphExecutionState}, - 400: {'description': 'Invalid node or link'}, - 404: {'description': 'Session not found'} - } -) -async def add_edge( - session_id: str = Path(description = "The id of the session"), - edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add") -) -> GraphExecutionState: - """Adds an edge to the graph""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) - if session is None: - return Response(status_code = 404) - - try: - session.add_edge(edge) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? - return session - except NodeAlreadyExecutedError: - return Response(status_code = 400) - except IndexError: - return Response(status_code = 400) - - -# TODO: the edge being in the path here is really ugly, find a better solution -@session_router.delete('/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}', - operation_id = 'delete_edge', - responses = { - 200: {"model": GraphExecutionState}, - 400: {'description': 'Invalid node or link'}, - 404: {'description': 'Session not found'} - } -) -async def delete_edge( - session_id: str = Path(description = "The id of the session"), - from_node_id: str = Path(description = "The id of the node the edge is coming from"), - from_field: str = Path(description = "The field of the node the edge is coming from"), - to_node_id: str = Path(description = "The id of the node the edge is going to"), - to_field: str = Path(description = "The field of the node the edge is going to") -) -> GraphExecutionState: - """Deletes an edge from the graph""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) - if session is None: - return Response(status_code = 404) - - try: - edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field)) - session.delete_edge(edge) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? 
- return session - except NodeAlreadyExecutedError: - return Response(status_code = 400) - except IndexError: - return Response(status_code = 400) - - -@session_router.put('/{session_id}/invoke', - operation_id = 'invoke_session', - responses = { - 200: {"model": None}, - 202: {'description': 'The invocation is queued'}, - 400: {'description': 'The session has no invocations ready to invoke'}, - 404: {'description': 'Session not found'} - }) -async def invoke_session( - session_id: str = Path(description = "The id of the session to invoke"), - all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations") -) -> None: - """Invokes a session""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) - if session is None: - return Response(status_code = 404) - - if session.is_complete(): - return Response(status_code = 400) - - ApiDependencies.invoker.invoke(session, invoke_all = all) - return Response(status_code=202) diff --git a/ldm/invoke/app/api/sockets.py b/ldm/invoke/app/api/sockets.py deleted file mode 100644 index eb4d5403c0..0000000000 --- a/ldm/invoke/app/api/sockets.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from fastapi import FastAPI -from fastapi_socketio import SocketManager -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event -from ..services.events import EventServiceBase - -class SocketIO: - __sio: SocketManager - - def __init__(self, app: FastAPI): - self.__sio = SocketManager(app = app) - self.__sio.on('subscribe', handler=self._handle_sub) - self.__sio.on('unsubscribe', handler=self._handle_unsub) - - local_handler.register( - event_name = EventServiceBase.session_event, - _func=self._handle_session_event - ) - - async def _handle_session_event(self, event: Event): - await self.__sio.emit( - event = event[1]['event'], - data = event[1]['data'], - room = event[1]['data']['graph_execution_state_id'] - ) - - async def _handle_sub(self, sid, data, *args, **kwargs): - if 'session' in data: - self.__sio.enter_room(sid, data['session']) - - # @app.sio.on('unsubscribe') - async def _handle_unsub(self, sid, data, *args, **kwargs): - if 'session' in data: - self.__sio.leave_room(sid, data['session']) diff --git a/ldm/invoke/app/api_app.py b/ldm/invoke/app/api_app.py deleted file mode 100644 index db79b0d7e8..0000000000 --- a/ldm/invoke/app/api_app.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import asyncio -from inspect import signature -from fastapi import FastAPI -from fastapi.openapi.utils import get_openapi -from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html -from fastapi.staticfiles import StaticFiles -from fastapi_events.middleware import EventHandlerASGIMiddleware -from fastapi_events.handlers.local import local_handler -from fastapi.middleware.cors import CORSMiddleware -from pydantic.schema import schema -import uvicorn -from .api.sockets import SocketIO -from .invocations import * -from .invocations.baseinvocation import BaseInvocation -from .api.routers import images, sessions -from .api.dependencies import ApiDependencies -from ..args import Args - -# Create the app -# TODO: create this all in a method so configuration/etc. can be passed in? 
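The TODO above asks for the app to be constructed in a method so configuration can be passed in. A hypothetical factory sketch (using only the FastAPI and fastapi-events pieces already imported in this file) might look like:

```py
# Hypothetical app factory answering the TODO above; illustrative only.
from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware


def create_app(title: str = "Invoke AI") -> FastAPI:
    app = FastAPI(title=title, docs_url=None, redoc_url=None)
    app.add_middleware(
        EventHandlerASGIMiddleware,
        handlers=[local_handler],
        middleware_id=id(app),
    )
    return app
```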
-app = FastAPI( - title = "Invoke AI", - docs_url = None, - redoc_url = None -) - -# Add event handler -event_handler_id: int = id(app) -app.add_middleware( - EventHandlerASGIMiddleware, - handlers = [local_handler], # TODO: consider doing this in services to support different configurations - middleware_id = event_handler_id) - -# Add CORS -# TODO: use configuration for this -origins = [] -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -socket_io = SocketIO(app) - -config = {} - -# Add startup event to load dependencies -@app.on_event('startup') -async def startup_event(): - args = Args() - config = args.parse_args() - - ApiDependencies.initialize( - args = args, - config = config, - event_handler_id = event_handler_id - ) - -# Shut down threads -@app.on_event('shutdown') -async def shutdown_event(): - ApiDependencies.shutdown() - -# Include all routers -# TODO: REMOVE -# app.include_router( -# invocation.invocation_router, -# prefix = '/api') - -app.include_router( - sessions.session_router, - prefix = '/api' -) - -app.include_router( - images.images_router, - prefix = '/api' -) - -# Build a custom OpenAPI to include all outputs -# TODO: can outputs be included on metadata of invocation schemas somehow? -def custom_openapi(): - if app.openapi_schema: - return app.openapi_schema - openapi_schema = get_openapi( - title = app.title, - description = "An API for invoking AI image operations", - version = "1.0.0", - routes = app.routes - ) - - # Add all outputs - all_invocations = BaseInvocation.get_invocations() - output_types = set() - output_type_titles = dict() - for invoker in all_invocations: - output_type = signature(invoker.invoke).return_annotation - output_types.add(output_type) - - output_schemas = schema(output_types, ref_prefix="#/components/schemas/") - for schema_key, output_schema in output_schemas['definitions'].items(): - openapi_schema["components"]["schemas"][schema_key] = output_schema - - # TODO: note that we assume the schema_key here is the TYPE.__name__ - # This could break in some cases, figure out a better way to do it - output_type_titles[schema_key] = output_schema['title'] - - # Add a reference to the output type to additionalProperties of the invoker schema - for invoker in all_invocations: - invoker_name = invoker.__name__ - output_type = signature(invoker.invoke).return_annotation - output_type_title = output_type_titles[output_type.__name__] - invoker_schema = openapi_schema["components"]["schemas"][invoker_name] - outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' } - if 'additionalProperties' not in invoker_schema: - invoker_schema['additionalProperties'] = {} - - invoker_schema['additionalProperties']['outputs'] = outputs_ref - - app.openapi_schema = openapi_schema - return app.openapi_schema - -app.openapi = custom_openapi - -# Override API doc favicons -app.mount('/static', StaticFiles(directory='static/dream_web'), name='static') - -@app.get("/docs", include_in_schema=False) -def overridden_swagger(): - return get_swagger_ui_html( - openapi_url=app.openapi_url, - title=app.title, - swagger_favicon_url="/static/favicon.ico" - ) - -@app.get("/redoc", include_in_schema=False) -def overridden_redoc(): - return get_redoc_html( - openapi_url=app.openapi_url, - title=app.title, - redoc_favicon_url="/static/favicon.ico" - ) - -def invoke_api(): - # Start our own event loop for eventing usage - # TODO: determine if there's a better way to do this - loop = 
asyncio.new_event_loop() - config = uvicorn.Config( - app = app, - host = "0.0.0.0", - port = 9090, - loop = loop) - # Use access_log to turn off logging - - server = uvicorn.Server(config) - loop.run_until_complete(server.serve()) - - -if __name__ == "__main__": - invoke_api() diff --git a/ldm/invoke/app/cli_app.py b/ldm/invoke/app/cli_app.py deleted file mode 100644 index 6071afabb2..0000000000 --- a/ldm/invoke/app/cli_app.py +++ /dev/null @@ -1,306 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import argparse -import shlex -import os -import time -from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints -from pydantic import BaseModel -from pydantic.fields import Field - -from .services.processor import DefaultInvocationProcessor - -from .services.graph import EdgeConnection, GraphExecutionState - -from .services.sqlite import SqliteItemStorage - -from .invocations.image import ImageField -from .services.generate_initializer import get_generate -from .services.image_storage import DiskImageStorage -from .services.invocation_queue import MemoryInvocationQueue -from .invocations.baseinvocation import BaseInvocation -from .services.invocation_services import InvocationServices -from .services.invoker import Invoker, InvokerServices -from .invocations import * -from ..args import Args -from .services.events import EventServiceBase - - -class InvocationCommand(BaseModel): - invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") - - -class InvalidArgs(Exception): - pass - - -def get_invocation_parser() -> argparse.ArgumentParser: - - # Create invocation parser - parser = argparse.ArgumentParser() - def exit(*args, **kwargs): - raise InvalidArgs - parser.exit = exit - - subparsers = parser.add_subparsers(dest='type') - invocation_parsers = dict() - - # Add history parser - history_parser = subparsers.add_parser('history', help="Shows the invocation history") - history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show") - - # Add default parser - default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name") - default_parser.add_argument('input', type=str, help="The input field") - default_parser.add_argument('value', help="The default value") - - default_parser = subparsers.add_parser('reset_default', help="Resets a default value") - default_parser.add_argument('input', type=str, help="The input field") - - # Create subparsers for each invocation - invocations = BaseInvocation.get_all_subclasses() - for invocation in invocations: - hints = get_type_hints(invocation) - cmd_name = get_args(hints['type'])[0] - command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__) - invocation_parsers[cmd_name] = command_parser - - # Add linking capability - command_parser.add_argument('--link', '-l', action='append', nargs=3, - help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)") - - command_parser.add_argument('--link_node', '-ln', action='append', - help="A link from all fields in the specified node. Node can be relative to history (e.g. 
-1)") - - # Convert all fields to arguments - fields = invocation.__fields__ - for name, field in fields.items(): - if name in ['id', 'type']: - continue - - if get_origin(field.type_) == Literal: - allowed_values = get_args(field.type_) - allowed_types = set() - for val in allowed_values: - allowed_types.add(type(val)) - allowed_types_list = list(allowed_types) - field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] - - command_parser.add_argument( - f"--{name}", - dest=name, - type=field_type, - default=field.default, - choices = allowed_values, - help=field.field_info.description - ) - else: - command_parser.add_argument( - f"--{name}", - dest=name, - type=field.type_, - default=field.default, - help=field.field_info.description - ) - - return parser - - -def get_invocation_command(invocation) -> str: - fields = invocation.__fields__.items() - type_hints = get_type_hints(type(invocation)) - command = [invocation.type] - for name,field in fields: - if name in ['id', 'type']: - continue - - # TODO: add links - - # Skip image fields when serializing command - type_hint = type_hints.get(name) or None - if type_hint is ImageField or ImageField in get_args(type_hint): - continue - - field_value = getattr(invocation, name) - field_default = field.default - if field_value != field_default: - if type_hint is str or str in get_args(type_hint): - command.append(f'--{name} "{field_value}"') - else: - command.append(f'--{name} {field_value}') - - return ' '.join(command) - - -def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]: - """Gets the history of fully-executed invocations for a graph execution""" - return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes) - - -def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]: - """Generates all possible edges between two invocations""" - atype = type(a) - btype = type(b) - - aoutputtype = atype.get_output_type() - - afields = get_type_hints(aoutputtype) - bfields = get_type_hints(btype) - - matching_fields = set(afields.keys()).intersection(bfields.keys()) - - # Remove invalid fields - invalid_fields = set(['type', 'id']) - matching_fields = matching_fields.difference(invalid_fields) - - edges = [(EdgeConnection(node_id = a.id, field = field), EdgeConnection(node_id = b.id, field = field)) for field in matching_fields] - return edges - - -def invoke_cli(): - args = Args() - config = args.parse_args() - - generate = get_generate(args, config) - - # NOTE: load model on first use, uncomment to load at startup - # TODO: Make this a config option? - #generate.load_model() - - events = EventServiceBase() - - output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs')) - - services = InvocationServices( - generate = generate, - events = events, - images = DiskImageStorage(output_folder) - ) - - # TODO: build a file/path manager? 
- db_location = os.path.join(output_folder, 'invokeai.db') - - invoker_services = InvokerServices( - queue = MemoryInvocationQueue(), - graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() - ) - - invoker = Invoker(services, invoker_services) - session = invoker.create_execution_state() - - parser = get_invocation_parser() - - # Uncomment to print out previous sessions at startup - # print(invoker_services.session_manager.list()) - - # Defaults storage - defaults: Dict[str, Any] = dict() - - while True: - try: - cmd_input = input("> ") - except KeyboardInterrupt: - # Ctrl-c exits - break - - if cmd_input in ['exit','q']: - break; - - if cmd_input in ['--help','help','h','?']: - parser.print_help() - continue - - try: - # Refresh the state of the session - session = invoker.invoker_services.graph_execution_manager.get(session.id) - history = list(get_graph_execution_history(session)) - - # Split the command for piping - cmds = cmd_input.split('|') - start_id = len(history) - current_id = start_id - new_invocations = list() - for cmd in cmds: - # Parse args to create invocation - args = vars(parser.parse_args(shlex.split(cmd.strip()))) - - # Check for special commands - # TODO: These might be better as Pydantic models, similar to the invocations - if args['type'] == 'history': - history_count = args['count'] or 5 - for i in range(min(history_count, len(history))): - entry_id = history[-1 - i] - entry = session.graph.get_node(entry_id) - print(f'{entry_id}: {get_invocation_command(entry.invocation)}') - continue - - if args['type'] == 'reset_default': - if args['input'] in defaults: - del defaults[args['input']] - continue - - if args['type'] == 'default': - field = args['input'] - field_value = args['value'] - defaults[field] = field_value - continue - - # Override defaults - for field_name,field_default in defaults.items(): - if field_name in args: - args[field_name] = field_default - - # Parse invocation - args['id'] = current_id - command = InvocationCommand(invocation = args) - - # Pipe previous command output (if there was a previous command) - edges = [] - if len(history) > 0 or current_id != start_id: - from_id = history[0] if current_id == start_id else str(current_id - 1) - from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id) - matching_edges = generate_matching_edges(from_node, command.invocation) - edges.extend(matching_edges) - - # Parse provided links - if 'link_node' in args and args['link_node']: - for link in args['link_node']: - link_node = session.graph.get_node(link) - matching_edges = generate_matching_edges(link_node, command.invocation) - edges.extend(matching_edges) - - if 'link' in args and args['link']: - for link in args['link']: - edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2]))) - - new_invocations.append((command.invocation, edges)) - - current_id = current_id + 1 - - # Command line was parsed successfully - # Add the invocations to the session - for invocation in new_invocations: - session.add_node(invocation[0]) - for edge in invocation[1]: - session.add_edge(edge) - - # Execute all available invocations - invoker.invoke(session, invoke_all = True) - while not session.is_complete(): - # Wait some time - session = invoker.invoker_services.graph_execution_manager.get(session.id) - 
time.sleep(0.1) - - except InvalidArgs: - print('Invalid command, use "help" to list commands') - continue - - except SystemExit: - continue - - invoker.stop() - - -if __name__ == "__main__": - invoke_cli() diff --git a/ldm/invoke/app/invocations/__init__.py b/ldm/invoke/app/invocations/__init__.py deleted file mode 100644 index 6407a1cdee..0000000000 --- a/ldm/invoke/app/invocations/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -import os - -__all__ = [] - -dirname = os.path.dirname(os.path.abspath(__file__)) -for f in os.listdir(dirname): - if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py": - __all__.append(f[:-3]) diff --git a/ldm/invoke/app/invocations/baseinvocation.py b/ldm/invoke/app/invocations/baseinvocation.py deleted file mode 100644 index 1ad2d99112..0000000000 --- a/ldm/invoke/app/invocations/baseinvocation.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC, abstractmethod -from inspect import signature -from typing import get_args, get_type_hints -from pydantic import BaseModel, Field -from ..services.invocation_services import InvocationServices - - -class InvocationContext: - services: InvocationServices - graph_execution_state_id: str - - def __init__(self, services: InvocationServices, graph_execution_state_id: str): - self.services = services - self.graph_execution_state_id = graph_execution_state_id - - -class BaseInvocationOutput(BaseModel): - """Base class for all invocation outputs""" - - # All outputs must include a type name like this: - # type: Literal['your_output_name'] - - @classmethod - def get_all_subclasses_tuple(cls): - subclasses = [] - toprocess = [cls] - while len(toprocess) > 0: - next = toprocess.pop(0) - next_subclasses = next.__subclasses__() - subclasses.extend(next_subclasses) - toprocess.extend(next_subclasses) - return tuple(subclasses) - - -class BaseInvocation(ABC, BaseModel): - """A node to process inputs and produce outputs. - May use dependency injection in __init__ to receive providers. - """ - - # All invocations must include a type name like this: - # type: Literal['your_output_name'] - - @classmethod - def get_all_subclasses(cls): - subclasses = [] - toprocess = [cls] - while len(toprocess) > 0: - next = toprocess.pop(0) - next_subclasses = next.__subclasses__() - subclasses.extend(next_subclasses) - toprocess.extend(next_subclasses) - return subclasses - - @classmethod - def get_invocations(cls): - return tuple(BaseInvocation.get_all_subclasses()) - - @classmethod - def get_invocations_map(cls): - # Get the type strings out of the literals and into a dictionary - return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) - - @classmethod - def get_output_type(cls): - return signature(cls.invoke).return_annotation - - @abstractmethod - def invoke(self, context: InvocationContext) -> BaseInvocationOutput: - """Invoke with provided context and return outputs.""" - pass - - id: str = Field(description="The id of this node. 
Must be unique among all nodes.") diff --git a/ldm/invoke/app/invocations/cv.py b/ldm/invoke/app/invocations/cv.py deleted file mode 100644 index f950669736..0000000000 --- a/ldm/invoke/app/invocations/cv.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from typing import Literal -import numpy -from pydantic import Field -from PIL import Image, ImageOps -import cv2 as cv -from .image import ImageField, ImageOutput -from .baseinvocation import BaseInvocation, InvocationContext -from ..services.image_storage import ImageType - - -class CvInpaintInvocation(BaseInvocation): - """Simple inpaint using opencv.""" - type: Literal['cv_inpaint'] = 'cv_inpaint' - - # Inputs - image: ImageField = Field(default=None, description="The image to inpaint") - mask: ImageField = Field(default=None, description="The mask to use when inpainting") - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get(self.image.image_type, self.image.image_name) - mask = context.services.images.get(self.mask.image_type, self.mask.image_name) - - # Convert to cv image/mask - # TODO: consider making these utility functions - cv_image = cv.cvtColor(numpy.array(image.convert('RGB')), cv.COLOR_RGB2BGR) - cv_mask = numpy.array(ImageOps.invert(mask)) - - # Inpaint - cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA) - - # Convert back to Pillow - # TODO: consider making a utility function - image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) - - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) - context.services.images.save(image_type, image_name, image_inpainted) - return ImageOutput( - image = ImageField(image_type = image_type, image_name = image_name) - ) diff --git a/ldm/invoke/app/invocations/generate.py b/ldm/invoke/app/invocations/generate.py deleted file mode 100644 index 60b656bf0c..0000000000 --- a/ldm/invoke/app/invocations/generate.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from datetime import datetime, timezone -from typing import Any, Literal, Optional, Union -import numpy as np -from pydantic import Field -from PIL import Image -from skimage.exposure.histogram_matching import match_histograms -from .image import ImageField, ImageOutput -from .baseinvocation import BaseInvocation, InvocationContext -from ..services.image_storage import ImageType -from ..services.invocation_services import InvocationServices - - -SAMPLER_NAME_VALUES = Literal["ddim","plms","k_lms","k_dpm_2","k_dpm_2_a","k_euler","k_euler_a","k_heun"] - -# Text to image -class TextToImageInvocation(BaseInvocation): - """Generates an image using text2img.""" - type: Literal['txt2img'] = 'txt2img' - - # Inputs - # TODO: consider making prompt optional to enable providing prompt through a link - prompt: Optional[str] = Field(description="The prompt to generate an image from") - seed: int = Field(default=-1, ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)") - steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") - width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image") - height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image") - cfg_scale: float = Field(default=7.5, gt=0, description="The 
Classifier-Free Guidance, higher values may result in a result closer to the prompt") - sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use") - seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams") - model: str = Field(default='', description="The model to use (currently ignored)") - progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation") - - # TODO: pass this an emitter method or something? or a session for dispatching? - def dispatch_progress(self, context: InvocationContext, sample: Any = None, step: int = 0) -> None: - context.services.events.emit_generator_progress( - context.graph_execution_state_id, self.id, step, float(step) / float(self.steps) - ) - - def invoke(self, context: InvocationContext) -> ImageOutput: - - def step_callback(sample, step = 0): - self.dispatch_progress(context, sample, step) - - # Handle invalid model parameter - # TODO: figure out if this can be done via a validator that uses the model_cache - # TODO: How to get the default model name now? - if self.model is None or self.model == '': - self.model = context.services.generate.model_name - - # Set the model (if already cached, this does nothing) - context.services.generate.set_model(self.model) - - results = context.services.generate.prompt2image( - prompt = self.prompt, - step_callback = step_callback, - **self.dict(exclude = {'prompt'}) # Shorthand for passing all of the parameters above manually - ) - - # Results are image and seed, unwrap for now and ignore the seed - # TODO: pre-seed? - # TODO: can this return multiple results? Should it? - image_type = ImageType.RESULT - image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) - context.services.images.save(image_type, image_name, results[0][0]) - return ImageOutput( - image = ImageField(image_type = image_type, image_name = image_name) - ) - - -class ImageToImageInvocation(TextToImageInvocation): - """Generates an image using img2img.""" - type: Literal['img2img'] = 'img2img' - - # Inputs - image: Union[ImageField,None] = Field(description="The input image") - strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image") - fit: bool = Field(default=True, description="Whether or not the result should be fit to the aspect ratio of the input image") - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name) - mask = None - - def step_callback(sample, step = 0): - self.dispatch_progress(context, sample, step) - - # Handle invalid model parameter - # TODO: figure out if this can be done via a validator that uses the model_cache - # TODO: How to get the default model name now? - if self.model is None or self.model == '': - self.model = context.services.generate.model_name - - # Set the model (if already cached, this does nothing) - context.services.generate.set_model(self.model) - - results = context.services.generate.prompt2image( - prompt = self.prompt, - init_img = image, - init_mask = mask, - step_callback = step_callback, - **self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually - ) - - result_image = results[0][0] - - # Results are image and seed, unwrap for now and ignore the seed - # TODO: pre-seed? - # TODO: can this return multiple results? 
Should it? - image_type = ImageType.RESULT - image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) - context.services.images.save(image_type, image_name, result_image) - return ImageOutput( - image = ImageField(image_type = image_type, image_name = image_name) - ) - - -class InpaintInvocation(ImageToImageInvocation): - """Generates an image using inpaint.""" - type: Literal['inpaint'] = 'inpaint' - - # Inputs - mask: Union[ImageField,None] = Field(description="The mask") - inpaint_replace: float = Field(default=0.0, ge=0.0, le=1.0, description="The amount by which to replace masked areas with latent noise") - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name) - mask = None if self.mask is None else context.services.images.get(self.mask.image_type, self.mask.image_name) - - def step_callback(sample, step = 0): - self.dispatch_progress(context, sample, step) - - # Handle invalid model parameter - # TODO: figure out if this can be done via a validator that uses the model_cache - # TODO: How to get the default model name now? - if self.model is None or self.model == '': - self.model = context.services.generate.model_name - - # Set the model (if already cached, this does nothing) - context.services.generate.set_model(self.model) - - results = context.services.generate.prompt2image( - prompt = self.prompt, - init_img = image, - init_mask = mask, - step_callback = step_callback, - **self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually - ) - - result_image = results[0][0] - - # Results are image and seed, unwrap for now and ignore the seed - # TODO: pre-seed? - # TODO: can this return multiple results? Should it? 
- image_type = ImageType.RESULT - image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) - context.services.images.save(image_type, image_name, result_image) - return ImageOutput( - image = ImageField(image_type = image_type, image_name = image_name) - ) diff --git a/ldm/invoke/app/invocations/image.py b/ldm/invoke/app/invocations/image.py deleted file mode 100644 index cb326b1bb7..0000000000 --- a/ldm/invoke/app/invocations/image.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from datetime import datetime, timezone -from typing import Literal, Optional -import numpy -from pydantic import Field, BaseModel -from PIL import Image, ImageOps, ImageFilter -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from ..services.image_storage import ImageType -from ..services.invocation_services import InvocationServices - - -class ImageField(BaseModel): - """An image field used for passing image objects between invocations""" - image_type: str = Field(default=ImageType.RESULT, description="The type of the image") - image_name: Optional[str] = Field(default=None, description="The name of the image") - - -class ImageOutput(BaseInvocationOutput): - """Base class for invocations that output an image""" - type: Literal['image'] = 'image' - - image: ImageField = Field(default=None, description="The output image") - - -class MaskOutput(BaseInvocationOutput): - """Base class for invocations that output a mask""" - type: Literal['mask'] = 'mask' - - mask: ImageField = Field(default=None, description="The output mask") - - -# TODO: this isn't really necessary anymore -class LoadImageInvocation(BaseInvocation): - """Load an image from a filename and provide it as output.""" - type: Literal['load_image'] = 'load_image' - - # Inputs - image_type: ImageType = Field(description="The type of the image") - image_name: str = Field(description="The name of the image") - - def invoke(self, context: InvocationContext) -> ImageOutput: - return ImageOutput( - image = ImageField(image_type = self.image_type, image_name = self.image_name) - ) - - -class ShowImageInvocation(BaseInvocation): - """Displays a provided image, and passes it forward in the pipeline.""" - type: Literal['show_image'] = 'show_image' - - # Inputs - image: ImageField = Field(default=None, description="The image to show") - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get(self.image.image_type, self.image.image_name) - if image: - image.show() - - # TODO: how to handle failure? - - return ImageOutput( - image = ImageField(image_type = self.image.image_type, image_name = self.image.image_name) - ) - - -class CropImageInvocation(BaseInvocation): - """Crops an image to a specified box. 
The box can be outside of the image.""" - type: Literal['crop'] = 'crop' - - # Inputs - image: ImageField = Field(default=None, description="The image to crop") - x: int = Field(default=0, description="The left x coordinate of the crop rectangle") - y: int = Field(default=0, description="The top y coordinate of the crop rectangle") - width: int = Field(default=512, gt=0, description="The width of the crop rectangle") - height: int = Field(default=512, gt=0, description="The height of the crop rectangle") - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get(self.image.image_type, self.image.image_name) - - image_crop = Image.new(mode = 'RGBA', size = (self.width, self.height), color = (0, 0, 0, 0)) - image_crop.paste(image, (-self.x, -self.y)) - - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) - context.services.images.save(image_type, image_name, image_crop) - return ImageOutput( - image = ImageField(image_type = image_type, image_name = image_name) - ) - - -class PasteImageInvocation(BaseInvocation): - """Pastes an image into another image.""" - type: Literal['paste'] = 'paste' - - # Inputs - base_image: ImageField = Field(default=None, description="The base image") - image: ImageField = Field(default=None, description="The image to paste") - mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") - x: int = Field(default=0, description="The left x coordinate at which to paste the image") - y: int = Field(default=0, description="The top y coordinate at which to paste the image") - - def invoke(self, context: InvocationContext) -> ImageOutput: - base_image = context.services.images.get(self.base_image.image_type, self.base_image.image_name) - image = context.services.images.get(self.image.image_type, self.image.image_name) - mask = None if self.mask is None else ImageOps.invert(services.images.get(self.mask.image_type, self.mask.image_name)) - # TODO: probably shouldn't invert mask here... should user be required to do it? 
-
-        min_x = min(0, self.x)
-        min_y = min(0, self.y)
-        max_x = max(base_image.width, image.width + self.x)
-        max_y = max(base_image.height, image.height + self.y)
-
-        new_image = Image.new(mode = 'RGBA', size = (max_x - min_x, max_y - min_y), color = (0, 0, 0, 0))
-        new_image.paste(base_image, (abs(min_x), abs(min_y)))
-        new_image.paste(image, (max(0, self.x), max(0, self.y)), mask = mask)
-
-        image_type = ImageType.RESULT
-        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
-        context.services.images.save(image_type, image_name, new_image)
-        return ImageOutput(
-            image = ImageField(image_type = image_type, image_name = image_name)
-        )
-
-
-class MaskFromAlphaInvocation(BaseInvocation):
-    """Extracts the alpha channel of an image as a mask."""
-    type: Literal['tomask'] = 'tomask'
-
-    # Inputs
-    image: ImageField = Field(default=None, description="The image to create the mask from")
-    invert: bool = Field(default=False, description="Whether or not to invert the mask")
-
-    def invoke(self, context: InvocationContext) -> MaskOutput:
-        image = context.services.images.get(self.image.image_type, self.image.image_name)
-
-        image_mask = image.split()[-1]
-        if self.invert:
-            image_mask = ImageOps.invert(image_mask)
-
-        image_type = ImageType.INTERMEDIATE
-        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
-        context.services.images.save(image_type, image_name, image_mask)
-        return MaskOutput(
-            mask = ImageField(image_type = image_type, image_name = image_name)
-        )
-
-
-class BlurInvocation(BaseInvocation):
-    """Blurs an image"""
-    type: Literal['blur'] = 'blur'
-
-    # Inputs
-    image: ImageField = Field(default=None, description="The image to blur")
-    radius: float = Field(default=8.0, ge=0, description="The blur radius")
-    blur_type: Literal['gaussian', 'box'] = Field(default='gaussian', description="The type of blur")
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get(self.image.image_type, self.image.image_name)
-
-        blur = ImageFilter.GaussianBlur(self.radius) if self.blur_type == 'gaussian' else ImageFilter.BoxBlur(self.radius)
-        blur_image = image.filter(blur)
-
-        image_type = ImageType.INTERMEDIATE
-        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
-        context.services.images.save(image_type, image_name, blur_image)
-        return ImageOutput(
-            image = ImageField(image_type = image_type, image_name = image_name)
-        )
-
-
-class LerpInvocation(BaseInvocation):
-    """Linear interpolation of all pixels of an image"""
-    type: Literal['lerp'] = 'lerp'
-
-    # Inputs
-    image: ImageField = Field(default=None, description="The image to lerp")
-    min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
-    max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get(self.image.image_type, self.image.image_name)
-
-        image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
-        image_arr = image_arr * (self.max - self.min) + self.min
-
-        lerp_image = Image.fromarray(numpy.uint8(image_arr))
-
-        image_type = ImageType.INTERMEDIATE
-        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
-        context.services.images.save(image_type, image_name, lerp_image)
-        return ImageOutput(
-            image = ImageField(image_type = image_type, image_name = image_name)
-        )
-
-
-class
InverseLerpInvocation(BaseInvocation): - """Inverse linear interpolation of all pixels of an image""" - type: Literal['ilerp'] = 'ilerp' - - # Inputs - image: ImageField = Field(default=None, description="The image to lerp") - min: int = Field(default=0, ge=0, le=255, description="The minimum input value") - max: int = Field(default=255, ge=0, le=255, description="The maximum input value") - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get(self.image.image_type, self.image.image_name) - - image_arr = numpy.asarray(image, dtype=numpy.float32) - image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 - - ilerp_image = Image.fromarray(numpy.uint8(image_arr)) - - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) - context.services.images.save(image_type, image_name, ilerp_image) - return ImageOutput( - image = ImageField(image_type = image_type, image_name = image_name) - ) diff --git a/ldm/invoke/app/invocations/prompt.py b/ldm/invoke/app/invocations/prompt.py deleted file mode 100644 index 029cad9660..0000000000 --- a/ldm/invoke/app/invocations/prompt.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Literal -from pydantic.fields import Field -from .baseinvocation import BaseInvocationOutput - -class PromptOutput(BaseInvocationOutput): - """Base class for invocations that output a prompt""" - type: Literal['prompt'] = 'prompt' - - prompt: str = Field(default=None, description="The output prompt") diff --git a/ldm/invoke/app/invocations/reconstruct.py b/ldm/invoke/app/invocations/reconstruct.py deleted file mode 100644 index 98201ce837..0000000000 --- a/ldm/invoke/app/invocations/reconstruct.py +++ /dev/null @@ -1,36 +0,0 @@ -from datetime import datetime, timezone -from typing import Literal, Union -from pydantic import Field -from .image import ImageField, ImageOutput -from .baseinvocation import BaseInvocation, InvocationContext -from ..services.image_storage import ImageType -from ..services.invocation_services import InvocationServices - - -class RestoreFaceInvocation(BaseInvocation): - """Restores faces in an image.""" - type: Literal['restore_face'] = 'restore_face' - - # Inputs - image: Union[ImageField,None] = Field(description="The input image") - strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration") - - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get(self.image.image_type, self.image.image_name) - results = context.services.generate.upscale_and_reconstruct( - image_list = [[image, 0]], - upscale = None, - strength = self.strength, # GFPGAN strength - save_original = False, - image_callback = None, - ) - - # Results are image and seed, unwrap for now - # TODO: can this return multiple results? 
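-        # upscale_and_reconstruct returns a list of (image, seed) pairs; only the
-        # first image is kept and saved below.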
- image_type = ImageType.RESULT - image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) - context.services.images.save(image_type, image_name, results[0][0]) - return ImageOutput( - image = ImageField(image_type = image_type, image_name = image_name) - ) diff --git a/ldm/invoke/app/invocations/upscale.py b/ldm/invoke/app/invocations/upscale.py deleted file mode 100644 index 1df8c44ea8..0000000000 --- a/ldm/invoke/app/invocations/upscale.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from datetime import datetime, timezone -from typing import Literal, Union -from pydantic import Field -from .image import ImageField, ImageOutput -from .baseinvocation import BaseInvocation, InvocationContext -from ..services.image_storage import ImageType -from ..services.invocation_services import InvocationServices - - -class UpscaleInvocation(BaseInvocation): - """Upscales an image.""" - type: Literal['upscale'] = 'upscale' - - # Inputs - image: Union[ImageField,None] = Field(description="The input image", default=None) - strength: float = Field(default=0.75, gt=0, le=1, description="The strength") - level: Literal[2,4] = Field(default=2, description = "The upscale level") - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get(self.image.image_type, self.image.image_name) - results = context.services.generate.upscale_and_reconstruct( - image_list = [[image, 0]], - upscale = (self.level, self.strength), - strength = 0.0, # GFPGAN strength - save_original = False, - image_callback = None, - ) - - # Results are image and seed, unwrap for now - # TODO: can this return multiple results? - image_type = ImageType.RESULT - image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) - context.services.images.save(image_type, image_name, results[0][0]) - return ImageOutput( - image = ImageField(image_type = image_type, image_name = image_name) - ) diff --git a/ldm/invoke/app/services/__init__.py b/ldm/invoke/app/services/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/ldm/invoke/app/services/events.py b/ldm/invoke/app/services/events.py deleted file mode 100644 index 7b850b61ac..0000000000 --- a/ldm/invoke/app/services/events.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from typing import Any, Dict - - -class EventServiceBase: - session_event: str = 'session_event' - - """Basic event bus, to have an empty stand-in when not needed""" - def dispatch(self, event_name: str, payload: Any) -> None: - pass - - def __emit_session_event(self, - event_name: str, - payload: Dict) -> None: - self.dispatch( - event_name = EventServiceBase.session_event, - payload = dict( - event = event_name, - data = payload - ) - ) - - # Define events here for every event in the system. - # This will make them easier to integrate until we find a schema generator. 
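-    # As an illustration (not part of this class), a concrete bus could override
-    # dispatch() to forward events to a real transport, e.g.:
-    #
-    #     class QueueEventService(EventServiceBase):
-    #         def __init__(self, queue):
-    #             self.queue = queue
-    #         def dispatch(self, event_name: str, payload: Any) -> None:
-    #             self.queue.put(dict(event_name=event_name, payload=payload))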
- def emit_generator_progress(self, - graph_execution_state_id: str, - invocation_id: str, - step: int, - percent: float - ) -> None: - """Emitted when there is generation progress""" - self.__emit_session_event( - event_name = 'generator_progress', - payload = dict( - graph_execution_state_id = graph_execution_state_id, - invocation_id = invocation_id, - step = step, - percent = percent - ) - ) - - def emit_invocation_complete(self, - graph_execution_state_id: str, - invocation_id: str, - result: Dict - ) -> None: - """Emitted when an invocation has completed""" - self.__emit_session_event( - event_name = 'invocation_complete', - payload = dict( - graph_execution_state_id = graph_execution_state_id, - invocation_id = invocation_id, - result = result - ) - ) - - def emit_invocation_started(self, - graph_execution_state_id: str, - invocation_id: str - ) -> None: - """Emitted when an invocation has started""" - self.__emit_session_event( - event_name = 'invocation_started', - payload = dict( - graph_execution_state_id = graph_execution_state_id, - invocation_id = invocation_id - ) - ) - - def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None: - """Emitted when a session has completed all invocations""" - self.__emit_session_event( - event_name = 'graph_execution_state_complete', - payload = dict( - graph_execution_state_id = graph_execution_state_id - ) - ) diff --git a/ldm/invoke/app/services/generate_initializer.py b/ldm/invoke/app/services/generate_initializer.py deleted file mode 100644 index 39c0fe491e..0000000000 --- a/ldm/invoke/app/services/generate_initializer.py +++ /dev/null @@ -1,233 +0,0 @@ -from argparse import Namespace -import os -import sys -import traceback - -from ...model_manager import ModelManager - -from ...globals import Globals -from ....generate import Generate -import ldm.invoke - - -# TODO: most of this code should be split into individual services as the Generate.py code is deprecated -def get_generate(args, config) -> Generate: - if not args.conf: - config_file = os.path.join(Globals.root,'configs','models.yaml') - if not os.path.exists(config_file): - report_model_error(args, FileNotFoundError(f"The file {config_file} could not be found.")) - - print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}') - print(f'>> InvokeAI runtime directory is "{Globals.root}"') - - # these two lines prevent a horrible warning message from appearing - # when the frozen CLIP tokenizer is imported - import transformers # type: ignore - transformers.logging.set_verbosity_error() - import diffusers - diffusers.logging.set_verbosity_error() - - # Loading Face Restoration and ESRGAN Modules - gfpgan,codeformer,esrgan = load_face_restoration(args) - - # normalize the config directory relative to root - if not os.path.isabs(args.conf): - args.conf = os.path.normpath(os.path.join(Globals.root,args.conf)) - - if args.embeddings: - if not os.path.isabs(args.embedding_path): - embedding_path = os.path.normpath(os.path.join(Globals.root,args.embedding_path)) - else: - embedding_path = args.embedding_path - else: - embedding_path = None - - # migrate legacy models - ModelManager.migrate_models() - - # load the infile as a list of lines - if args.infile: - try: - if os.path.isfile(args.infile): - infile = open(args.infile, 'r', encoding='utf-8') - elif args.infile == '-': # stdin - infile = sys.stdin - else: - raise FileNotFoundError(f'{args.infile} not found.') - except (FileNotFoundError, IOError) as e: - print(f'{e}. 
Aborting.')
-            sys.exit(-1)
-
-    # creating a Generate object:
-    try:
-        gen = Generate(
-            conf = args.conf,
-            model = args.model,
-            sampler_name = args.sampler_name,
-            embedding_path = embedding_path,
-            full_precision = args.full_precision,
-            precision = args.precision,
-            gfpgan = gfpgan,
-            codeformer = codeformer,
-            esrgan = esrgan,
-            free_gpu_mem = args.free_gpu_mem,
-            safety_checker = args.safety_checker,
-            max_loaded_models = args.max_loaded_models,
-        )
-    except (FileNotFoundError, TypeError, AssertionError) as e:
-        report_model_error(args, e)
-    except (IOError, KeyError) as e:
-        print(f'{e}. Aborting.')
-        sys.exit(-1)
-
-    if args.seamless:
-        print(">> changed to seamless tiling mode")
-
-    # preload the model
-    try:
-        gen.load_model()
-    except KeyError:
-        pass
-    except Exception as e:
-        report_model_error(args, e)
-
-    # try to autoconvert new models
-    # autoimport new .ckpt files
-    if path := args.autoconvert:
-        gen.model_manager.autoconvert_weights(
-            conf_path=args.conf,
-            weights_directory=path,
-        )
-
-    return gen
-
-
-def load_face_restoration(opt):
-    try:
-        gfpgan, codeformer, esrgan = None, None, None
-        if opt.restore or opt.esrgan:
-            from ldm.invoke.restoration import Restoration
-            restoration = Restoration()
-            if opt.restore:
-                gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path)
-            else:
-                print('>> Face restoration disabled')
-            if opt.esrgan:
-                esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
-            else:
-                print('>> Upscaling disabled')
-        else:
-            print('>> Face restoration and upscaling disabled')
-    except (ModuleNotFoundError, ImportError):
-        print(traceback.format_exc(), file=sys.stderr)
-        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
-    return gfpgan,codeformer,esrgan
-
-
-def report_model_error(opt:Namespace, e:Exception):
-    print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
-    print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.')
-    yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
-    if yes_to_all:
-        print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE')
-    else:
-        response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ')
-        if response.startswith(('n', 'N')):
-            return
-
-    print('invokeai-configure is launching....\n')
-
-    # Match arguments that were set on the CLI
-    # only the arguments accepted by the configuration script are parsed
-    root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
-    config = ["--config", opt.conf] if opt.conf is not None else []
-    previous_args = sys.argv
-    sys.argv = [ 'invokeai-configure' ]
-    sys.argv.extend(root_dir)
-    sys.argv.extend(config)
-    if yes_to_all is not None:
-        for arg in yes_to_all.split():
-            sys.argv.append(arg)
-
-    from ldm.invoke.config import invokeai_configure
-    invokeai_configure.main()
-    # TODO: Figure out how to restart
-    # print('** InvokeAI will now restart')
-    # sys.argv = previous_args
-    # main() # would rather do a os.exec(), but doesn't exist?
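-    # (os.execv does exist and could restart in place, e.g.
-    #  os.execv(sys.executable, [sys.executable] + previous_args))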
-    # sys.exit(0)
-
-
-# Temporary initializer for Generate until we migrate off of it
-def old_get_generate(args, config) -> Generate:
-    # TODO: Remove the need for globals
-    from ldm.invoke.globals import Globals
-
-    # alert - setting globals here
-    Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
-    Globals.try_patchmatch = args.patchmatch
-
-    print(f'>> InvokeAI runtime directory is "{Globals.root}"')
-
-    # these two lines prevent a horrible warning message from appearing
-    # when the frozen CLIP tokenizer is imported
-    import transformers
-    transformers.logging.set_verbosity_error()
-
-    # Loading Face Restoration and ESRGAN Modules
-    gfpgan, codeformer, esrgan = None, None, None
-    try:
-        if config.restore or config.esrgan:
-            from ldm.invoke.restoration import Restoration
-            restoration = Restoration()
-            if config.restore:
-                gfpgan, codeformer = restoration.load_face_restore_models(config.gfpgan_model_path)
-            else:
-                print('>> Face restoration disabled')
-            if config.esrgan:
-                esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
-            else:
-                print('>> Upscaling disabled')
-        else:
-            print('>> Face restoration and upscaling disabled')
-    except (ModuleNotFoundError, ImportError):
-        print(traceback.format_exc(), file=sys.stderr)
-        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
-
-    # normalize the config directory relative to root
-    if not os.path.isabs(config.conf):
-        config.conf = os.path.normpath(os.path.join(Globals.root,config.conf))
-
-    if config.embeddings:
-        if not os.path.isabs(config.embedding_path):
-            embedding_path = os.path.normpath(os.path.join(Globals.root,config.embedding_path))
-        else:
-            embedding_path = config.embedding_path
-    else:
-        embedding_path = None
-
-    # TODO: lazy-initialize this by wrapping it
-    try:
-        generate = Generate(
-            conf = config.conf,
-            model = config.model,
-            sampler_name = config.sampler_name,
-            embedding_path = embedding_path,
-            full_precision = config.full_precision,
-            precision = config.precision,
-            gfpgan = gfpgan,
-            codeformer = codeformer,
-            esrgan = esrgan,
-            free_gpu_mem = config.free_gpu_mem,
-            safety_checker = config.safety_checker,
-            max_loaded_models = config.max_loaded_models,
-        )
-    except (FileNotFoundError, TypeError, AssertionError):
-        #emergency_model_reconfigure() # TODO?
-        sys.exit(-1)
-    except (IOError, KeyError) as e:
-        print(f'{e}.
Aborting.') - sys.exit(-1) - - generate.free_gpu_mem = config.free_gpu_mem - - return generate diff --git a/ldm/invoke/app/services/graph.py b/ldm/invoke/app/services/graph.py deleted file mode 100644 index 8d1583fc8b..0000000000 --- a/ldm/invoke/app/services/graph.py +++ /dev/null @@ -1,797 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import copy -import itertools -from types import NoneType -import uuid -import networkx as nx -from pydantic import BaseModel, validator -from pydantic.fields import Field -from typing import Any, Literal, Optional, Union, get_args, get_origin, get_type_hints, Annotated - -from .invocation_services import InvocationServices -from ..invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from ..invocations import * - - -class EdgeConnection(BaseModel): - node_id: str = Field(description="The id of the node for this edge connection") - field: str = Field(description="The field for this connection") - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - getattr(other, 'node_id', None) == self.node_id and - getattr(other, 'field', None) == self.field) - - def __hash__(self): - return hash(f'{self.node_id}.{self.field}') - - -def get_output_field(node: BaseInvocation, field: str) -> Any: - node_type = type(node) - node_outputs = get_type_hints(node_type.get_output_type()) - node_output_field = node_outputs.get(field) or None - return node_output_field - - -def get_input_field(node: BaseInvocation, field: str) -> Any: - node_type = type(node) - node_inputs = get_type_hints(node_type) - node_input_field = node_inputs.get(field) or None - return node_input_field - - -def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: - if not from_type: - return False - if not to_type: - return False - - # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such) - if from_type and to_type: - # Ports are compatible - if (from_type == to_type or - from_type == Any or - to_type == Any or - Any in get_args(from_type) or - Any in get_args(to_type)): - return True - - if from_type in get_args(to_type): - return True - - if to_type in get_args(from_type): - return True - - if not issubclass(from_type, to_type): - return False - else: - return False - - return True - - -def are_connections_compatible( - from_node: BaseInvocation, - from_field: str, - to_node: BaseInvocation, - to_field: str) -> bool: - """Determines if a connection between fields of two nodes is compatible.""" - - # TODO: handle iterators and collectors - from_node_field = get_output_field(from_node, from_field) - to_node_field = get_input_field(to_node, to_field) - - return are_connection_types_compatible(from_node_field, to_node_field) - - -class NodeAlreadyInGraphError(Exception): - pass - - -class InvalidEdgeError(Exception): - pass - -class NodeNotFoundError(Exception): - pass - -class NodeAlreadyExecutedError(Exception): - pass - - -# TODO: Create and use an Empty output? 
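-# Illustrative examples of the compatibility rules above:
-#   are_connection_types_compatible(ImageField, ImageField)   -> True (exact match)
-#   are_connection_types_compatible(ImageField, Any)          -> True (Any accepts all)
-#   are_connection_types_compatible(int, Union[int, None])    -> True (member of the Union)
-#   are_connection_types_compatible(str, int)                 -> False (no subclass relation)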
-class GraphInvocationOutput(BaseInvocationOutput): - type: Literal['graph_output'] = 'graph_output' - - -# TODO: Fill this out and move to invocations -class GraphInvocation(BaseInvocation): - type: Literal['graph'] = 'graph' - - # TODO: figure out how to create a default here - graph: 'Graph' = Field(description="The graph to run", default=None) - - def invoke(self, context: InvocationContext) -> GraphInvocationOutput: - """Invoke with provided services and return outputs.""" - return GraphInvocationOutput() - - -class IterateInvocationOutput(BaseInvocationOutput): - """Used to connect iteration outputs. Will be expanded to a specific output.""" - type: Literal['iterate_output'] = 'iterate_output' - - item: Any = Field(description="The item being iterated over") - - -# TODO: Fill this out and move to invocations -class IterateInvocation(BaseInvocation): - type: Literal['iterate'] = 'iterate' - - collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list) - index: int = Field(description="The index, will be provided on executed iterators", default=0) - - def invoke(self, context: InvocationContext) -> IterateInvocationOutput: - """Produces the outputs as values""" - return IterateInvocationOutput(item = self.collection[self.index]) - - -class CollectInvocationOutput(BaseInvocationOutput): - type: Literal['collect_output'] = 'collect_output' - - collection: list[Any] = Field(description="The collection of input items") - - -class CollectInvocation(BaseInvocation): - """Collects values into a collection""" - type: Literal['collect'] = 'collect' - - item: Any = Field(description="The item to collect (all inputs must be of the same type)", default=None) - collection: list[Any] = Field(description="The collection, will be provided on execution", default_factory=list) - - def invoke(self, context: InvocationContext) -> CollectInvocationOutput: - """Invoke with provided services and return outputs.""" - return CollectInvocationOutput(collection = copy.copy(self.collection)) - - -InvocationsUnion = Union[BaseInvocation.get_invocations()] -InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] - - -class Graph(BaseModel): - id: str = Field(description="The id of this graph", default_factory=uuid.uuid4) - # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me - nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(description="The nodes in this graph", default_factory=dict) - edges: list[tuple[EdgeConnection,EdgeConnection]] = Field(description="The connections between nodes and their fields in this graph", default_factory=list) - - def add_node(self, node: BaseInvocation) -> None: - """Adds a node to a graph - - :raises NodeAlreadyInGraphError: the node is already present in the graph. - """ - - if node.id in self.nodes: - raise NodeAlreadyInGraphError() - - self.nodes[node.id] = node - - - def _get_graph_and_node(self, node_path: str) -> tuple['Graph', str]: - """Returns the graph and node id for a node path.""" - # Materialized graphs may have nodes at the top level - if node_path in self.nodes: - return (self, node_path) - - node_id = node_path if '.' 
not in node_path else node_path[:node_path.index('.')]
-        if node_id not in self.nodes:
-            raise NodeNotFoundError(f'Node {node_path} not found in graph')
-
-        node = self.nodes[node_id]
-
-        if not isinstance(node, GraphInvocation):
-            # There's more node path left but this isn't a graph - failure
-            raise NodeNotFoundError('Node path terminated early at a non-graph node')
-
-        return node.graph._get_graph_and_node(node_path[node_path.index('.')+1:])
-
-
-    def delete_node(self, node_path: str) -> None:
-        """Deletes a node from a graph"""
-
-        try:
-            graph, node_id = self._get_graph_and_node(node_path)
-
-            # Delete edges for this node
-            input_edges = self._get_input_edges_and_graphs(node_path)
-            output_edges = self._get_output_edges_and_graphs(node_path)
-
-            for edge_graph,_,edge in input_edges:
-                edge_graph.delete_edge(edge)
-
-            for edge_graph,_,edge in output_edges:
-                edge_graph.delete_edge(edge)
-
-            del graph.nodes[node_id]
-
-        except NodeNotFoundError:
-            pass # Ignore, node doesn't exist (should this throw?)
-
-
-    def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
-        """Adds an edge to a graph
-
-        :raises InvalidEdgeError: the provided edge is invalid.
-        """
-
-        if self._is_edge_valid(edge) and edge not in self.edges:
-            self.edges.append(edge)
-        else:
-            raise InvalidEdgeError()
-
-
-    def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
-        """Deletes an edge from a graph"""
-
-        try:
-            self.edges.remove(edge)
-        except ValueError:
-            pass
-
-
-    def is_valid(self) -> bool:
-        """Validates the graph."""
-
-        # Validate all subgraphs
-        for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
-            if not gn.graph.is_valid():
-                return False
-
-        # Validate all edges reference nodes in the graph
-        node_ids = set([e[0].node_id for e in self.edges]+[e[1].node_id for e in self.edges])
-        if not all((self.has_node(node_id) for node_id in node_ids)):
-            return False
-
-        # Validate there are no cycles
-        g = self.nx_graph_flat()
-        if not nx.is_directed_acyclic_graph(g):
-            return False
-
-        # Validate all edge connections are valid
-        if not all((are_connections_compatible(
-            self.get_node(e[0].node_id), e[0].field,
-            self.get_node(e[1].node_id), e[1].field
-        ) for e in self.edges)):
-            return False
-
-        # Validate all iterators
-        # TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
-        if not all((self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))):
-            return False
-
-        # Validate all collectors
-        # TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
-        if not all((self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))):
-            return False
-
-        return True
-
-    def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
-        """Validates that a new edge doesn't create a cycle in the graph"""
-
-        # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
-        try:
-            from_node = self.get_node(edge[0].node_id)
-            to_node = self.get_node(edge[1].node_id)
-        except NodeNotFoundError:
-            return False
-
-        # Validate that an edge to this node+field doesn't already exist
-        input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
-        if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
-            return False
-
-        # Validate that no cycles would be created
-        g = self.nx_graph_flat()
-        g.add_edge(edge[0].node_id, edge[1].node_id)
-        if not nx.is_directed_acyclic_graph(g):
-            return False
-
-        # Validate that the field types are compatible
-        if not are_connections_compatible(from_node, edge[0].field, to_node, edge[1].field):
-            return False
-
-        # Validate if iterator output type matches iterator input type (if this edge results in both being set)
-        if isinstance(to_node, IterateInvocation) and edge[1].field == 'collection':
-            if not self._is_iterator_connection_valid(edge[1].node_id, new_input = edge[0]):
-                return False
-
-        # Validate if iterator input type matches output type (if this edge results in both being set)
-        if isinstance(from_node, IterateInvocation) and edge[0].field == 'item':
-            if not self._is_iterator_connection_valid(edge[0].node_id, new_output = edge[1]):
-                return False
-
-        # Validate if collector input type matches output type (if this edge results in both being set)
-        if isinstance(to_node, CollectInvocation) and edge[1].field == 'item':
-            if not self._is_collector_connection_valid(edge[1].node_id, new_input = edge[0]):
-                return False
-
-        # Validate if collector output type matches input type (if this edge results in both being set)
-        if isinstance(from_node, CollectInvocation) and edge[0].field == 'collection':
-            if not self._is_collector_connection_valid(edge[0].node_id, new_output = edge[1]):
-                return False
-
-        return True
-
-    def has_node(self, node_path: str) -> bool:
-        """Determines whether or not a node exists in the graph."""
-        try:
-            n = self.get_node(node_path)
-            if n is not None:
-                return True
-            else:
-                return False
-        except NodeNotFoundError:
-            return False
-
-    def get_node(self, node_path: str) -> InvocationsUnion:
-        """Gets a node from the graph using a node path."""
-        # Materialized graphs may have nodes at the top level
-        graph, node_id = self._get_graph_and_node(node_path)
-        return graph.nodes[node_id]
-
-
-    def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
-        return node_id if prefix is None or prefix == '' else f'{prefix}.{node_id}'
-
-
-    def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
-        """Updates a node in the graph."""
-        graph, node_id = self._get_graph_and_node(node_path)
-        node = graph.nodes[node_id]
-
-        # Ensure the node type matches the new node
-        if type(node) != type(new_node):
-            raise TypeError(f'Node {node_path} is type {type(node)} but new node is type {type(new_node)}')
-
-        # Ensure the new id is either the same or is not in the graph
-        prefix = None if '.' not in node_path else node_path[:node_path.rindex('.')]
-        new_path = self._get_node_path(new_node.id, prefix = prefix)
-        if new_node.id != node.id and self.has_node(new_path):
-            raise NodeAlreadyInGraphError(f'Node with id {new_node.id} already exists in graph')
-
-        # Set the new node in the graph
-        graph.nodes[new_node.id] = new_node
-        if new_node.id != node.id:
-            input_edges = self._get_input_edges_and_graphs(node_path)
-            output_edges = self._get_output_edges_and_graphs(node_path)
-
-            # Delete node and all edges
-            graph.delete_node(node_path)
-
-            # Create new edges for each input and output
-            for graph,_,edge in input_edges:
-                # Remove the graph prefix from the node path
-                new_graph_node_path = new_node.id if '.'
not in edge[1].node_id else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}' - graph.add_edge((edge[0], EdgeConnection(node_id = new_graph_node_path, field = edge[1].field))) - - for graph,_,edge in output_edges: - # Remove the graph prefix from the node path - new_graph_node_path = new_node.id if '.' not in edge[0].node_id else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}' - graph.add_edge((EdgeConnection(node_id = new_graph_node_path, field = edge[0].field), edge[1])) - - - def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[tuple[EdgeConnection,EdgeConnection]]: - """Gets all input edges for a node""" - edges = self._get_input_edges_and_graphs(node_path) - - # Filter to edges that match the field - filtered_edges = (e for e in edges if field is None or e[2][1].field == field) - - # Create full node paths for each edge - return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges] - - - def _get_input_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]: - """Gets all input edges for a node along with the graph they are in and the graph's path""" - edges = list() - - # Return any input edges that appear in this graph - edges.extend([(self, prefix, e) for e in self.edges if e[1].node_id == node_path]) - - node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')] - node = self.nodes[node_id] - - if isinstance(node, GraphInvocation): - graph = node.graph - graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix) - graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path) - edges.extend(graph_edges) - - return edges - - - def _get_output_edges(self, node_path: str, field: str) -> list[tuple[EdgeConnection,EdgeConnection]]: - """Gets all output edges for a node""" - edges = self._get_output_edges_and_graphs(node_path) - - # Filter to edges that match the field - filtered_edges = (e for e in edges if e[2][0].field == field) - - # Create full node paths for each edge - return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges] - - - def _get_output_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]: - """Gets all output edges for a node along with the graph they are in and the graph's path""" - edges = list() - - # Return any input edges that appear in this graph - edges.extend([(self, prefix, e) for e in self.edges if e[0].node_id == node_path]) - - node_id = node_path if '.' 
not in node_path else node_path[:node_path.index('.')] - node = self.nodes[node_id] - - if isinstance(node, GraphInvocation): - graph = node.graph - graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix) - graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path) - edges.extend(graph_edges) - - return edges - - - def _is_iterator_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool: - inputs = list([e[0] for e in self._get_input_edges(node_path, 'collection')]) - outputs = list([e[1] for e in self._get_output_edges(node_path, 'item')]) - - if new_input is not None: - inputs.append(new_input) - if new_output is not None: - outputs.append(new_output) - - # Only one input is allowed for iterators - if len(inputs) > 1: - return False - - # Get input and output fields (the fields linked to the iterator's input/output) - input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field) - output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) - - # Input type must be a list - if get_origin(input_field) != list: - return False - - # Validate that all outputs match the input type - input_field_item_type = get_args(input_field)[0] - if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)): - return False - - return True - - def _is_collector_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool: - inputs = list([e[0] for e in self._get_input_edges(node_path, 'item')]) - outputs = list([e[1] for e in self._get_output_edges(node_path, 'collection')]) - - if new_input is not None: - inputs.append(new_input) - if new_output is not None: - outputs.append(new_output) - - # Get input and output fields (the fields linked to the iterator's input/output) - input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs]) - output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) - - # Validate that all inputs are derived from or match a single type - input_field_types = set([t for input_field in input_fields for t in ([input_field] if get_origin(input_field) == None else get_args(input_field)) if t != NoneType]) # Get unique types - type_tree = nx.DiGraph() - type_tree.add_nodes_from(input_field_types) - type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])]) - type_degrees = type_tree.in_degree(type_tree.nodes) - if sum((t[1] == 0 for t in type_degrees)) != 1: - return False # There is more than one root type - - # Get the input root type - input_root_type = next(t[0] for t in type_degrees if t[1] == 0) - - # Verify that all outputs are lists - if not all((get_origin(f) == list for f in output_fields)): - return False - - # Verify that all outputs match the input type (are a base class or the same class) - if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)): - return False - - return True - - def nx_graph(self) -> nx.DiGraph: - """Returns a NetworkX DiGraph representing the layout of this graph""" - # TODO: Cache this? 
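-        # Node ids become vertices; the set() below collapses multiple field-level
-        # edges between the same pair of nodes into one node-level edge.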
- g = nx.DiGraph() - g.add_nodes_from([n for n in self.nodes.keys()]) - g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges])) - return g - - def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph: - """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)""" - g = nx_graph or nx.DiGraph() - - # Add all nodes from this graph except graph/iteration nodes - g.add_nodes_from([self._get_node_path(n.id, prefix) for n in self.nodes.values() if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)]) - - # Expand graph nodes - for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)): - sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) - - # TODO: figure out if iteration nodes need to be expanded - - unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges]) - g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]) - return g - - -class GraphExecutionState(BaseModel): - """Tracks the state of a graph execution""" - id: str = Field(description="The id of the execution state", default_factory=uuid.uuid4) - - # TODO: Store a reference to the graph instead of the actual graph? - graph: Graph = Field(description="The graph being executed") - - # The graph of materialized nodes - execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes", default_factory=Graph) - - # Nodes that have been executed - executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set) - executed_history: list[str] = Field(description="The list of node ids that have been executed, in order of execution", default_factory=list) - - # The results of executed nodes - results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict) - - # Map of prepared/executed nodes to their original nodes - prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict) - - # Map of original nodes to prepared nodes - source_prepared_mapping: dict[str, set[str]] = Field(description="The map of original graph nodes to prepared nodes", default_factory=dict) - - def next(self) -> BaseInvocation | None: - """Gets the next node ready to execute.""" - - # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes - # possibly with a timeout? - - # If there are no prepared nodes, prepare some nodes - next_node = self._get_next_node() - if next_node is None: - prepared_id = self._prepare() - - # TODO: prepare multiple nodes at once? - # while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation): - # prepared_id = self._prepare() - - if prepared_id is not None: - next_node = self._get_next_node() - - # Get values from edges - if next_node is not None: - self._prepare_inputs(next_node) - - # If next is still none, there's no next node, return None - return next_node - - def complete(self, node_id: str, output: InvocationOutputsUnion): - """Marks a node as complete""" - - if node_id not in self.execution_graph.nodes: - return # TODO: log error? 
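-        # A single source node may have been expanded into several prepared nodes
-        # (one per iteration); the source only counts as executed once all of its
-        # prepared copies have executed.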
- - # Mark node as executed - self.executed.add(node_id) - self.results[node_id] = output - - # Check if source node is complete (all prepared nodes are complete) - source_node = self.prepared_source_mapping[node_id] - prepared_nodes = self.source_prepared_mapping[source_node] - - if all([n in self.executed for n in prepared_nodes]): - self.executed.add(source_node) - self.executed_history.append(source_node) - - def is_complete(self) -> bool: - """Returns true if the graph is complete""" - return all((k in self.executed for k in self.graph.nodes)) - - def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: - """Prepares an iteration node and connects all edges, returning the new node id""" - - node = self.graph.get_node(node_path) - - self_iteration_count = -1 - - # If this is an iterator node, we must create a copy for each iteration - if isinstance(node, IterateInvocation): - # Get input collection edge (should error if there are no inputs) - input_collection_edge = next(iter(self.graph._get_input_edges(node_path, 'collection'))) - input_collection_prepared_node_id = next(n[1] for n in iteration_node_map if n[0] == input_collection_edge[0].node_id) - input_collection_prepared_node_output = self.results[input_collection_prepared_node_id] - input_collection = getattr(input_collection_prepared_node_output, input_collection_edge[0].field) - self_iteration_count = len(input_collection) - - new_nodes = list() - if self_iteration_count == 0: - # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid. - return new_nodes - - # Get all input edges - input_edges = self.graph._get_input_edges(node_path) - - # Create new edges for this iteration - # For collect nodes, this may contain multiple inputs to the same field - new_edges = list() - for edge in input_edges: - for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge[0].node_id): - new_edge = (EdgeConnection(node_id = input_node_id, field = edge[0].field), EdgeConnection(node_id = '', field = edge[1].field)) - new_edges.append(new_edge) - - # Create a new node (or one for each iteration of this iterator) - for i in (range(self_iteration_count) if self_iteration_count > 0 else [-1]): - # Create a new node - new_node = copy.deepcopy(node) - - # Create the node id (use a random uuid) - new_node.id = str(uuid.uuid4()) - - # Set the iteration index for iteration invocations - if isinstance(new_node, IterateInvocation): - new_node.index = i - - # Add to execution graph - self.execution_graph.add_node(new_node) - self.prepared_source_mapping[new_node.id] = node_path - if node_path not in self.source_prepared_mapping: - self.source_prepared_mapping[node_path] = set() - self.source_prepared_mapping[node_path].add(new_node.id) - - # Add new edges to execution graph - for edge in new_edges: - new_edge = (edge[0], EdgeConnection(node_id = new_node.id, field = edge[1].field)) - self.execution_graph.add_edge(new_edge) - - new_nodes.append(new_node.id) - - return new_nodes - - def _iterator_graph(self) -> nx.DiGraph: - """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" - g = self.graph.nx_graph() - collectors = (n for n in self.graph.nodes if isinstance(self.graph.nodes[n], CollectInvocation)) - for c in collectors: - g.remove_edges_from(list(g.in_edges(c))) - return g - - - def _get_node_iterators(self, node_id: str) -> list[str]: - """Gets iterators for a node""" - g = 
self._iterator_graph() - iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.nodes[n], IterateInvocation)] - return iterators - - - def _prepare(self) -> Optional[str]: - # Get flattened source graph - g = self.graph.nx_graph_flat() - - # Find next unprepared node where all source nodes are executed - sorted_nodes = nx.topological_sort(g) - next_node_id = next((n for n in sorted_nodes if n not in self.source_prepared_mapping and all((e[0] in self.executed for e in g.in_edges(n)))), None) - - if next_node_id == None: - return None - - # Get all parents of the next node - next_node_parents = [e[0] for e in g.in_edges(next_node_id)] - - # Create execution nodes - next_node = self.graph.get_node(next_node_id) - new_node_ids = list() - if isinstance(next_node, CollectInvocation): - # Collapse all iterator input mappings and create a single execution node for the collect invocation - all_iteration_mappings = list(itertools.chain(*(((s,p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))) - #all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings))) - create_results = self._create_execution_node(next_node_id, all_iteration_mappings) - if create_results is not None: - new_node_ids.extend(create_results) - else: # Iterators or normal nodes - # Get all iterator combinations for this node - # Will produce a list of lists of prepared iterator nodes, from which results can be iterated - iterator_nodes = self._get_node_iterators(next_node_id) - iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes] - iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared)) - - # Select the correct prepared parents for each iteration - # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator - # TODO: Handle a node mapping to none - eg = self.execution_graph.nx_graph_flat() - prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] - - # Create execution node for each iteration - for iteration_mappings in prepared_parent_mappings: - create_results = self._create_execution_node(next_node_id, iteration_mappings) - if create_results is not None: - new_node_ids.extend(create_results) - - return next(iter(new_node_ids), None) - - def _get_iteration_node(self, source_node_path: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str]) -> Optional[str]: - """Gets the prepared version of the specified source node that matches every iteration specified""" - prepared_nodes = self.source_prepared_mapping[source_node_path] - if len(prepared_nodes) == 1: - return next(iter(prepared_nodes)) - - # Check if the requested node is an iterator - prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None) - if prepared_iterator is not None: - return prepared_iterator - - # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) - iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] - parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)] - - return next((n for n in prepared_nodes if all(pit for pit in parent_iterators if nx.has_path(execution_graph, pit[0], n))), None) - - def _get_next_node(self) -> Optional[BaseInvocation]: - g = 
self.execution_graph.nx_graph()
-        sorted_nodes = nx.topological_sort(g)
-        next_node = next((n for n in sorted_nodes if n not in self.executed), None)
-        if next_node is None:
-            return None
-
-        return self.execution_graph.nodes[next_node]
-
-    def _prepare_inputs(self, node: BaseInvocation):
-        input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
-        if isinstance(node, CollectInvocation):
-            output_collection = [getattr(self.results[edge[0].node_id], edge[0].field) for edge in input_edges if edge[1].field == 'item']
-            setattr(node, 'collection', output_collection)
-        else:
-            for edge in input_edges:
-                output_value = getattr(self.results[edge[0].node_id], edge[0].field)
-                setattr(node, edge[1].field, output_value)
-
-    # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
-    def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
-        if not self.graph._is_edge_valid(edge):
-            return False
-
-        # Invalid if destination has already been prepared or executed
-        if edge[1].node_id in self.source_prepared_mapping:
-            return False
-
-        # Otherwise, the edge is valid
-        return True
-
-    def _is_node_updatable(self, node_id: str) -> bool:
-        # The node is updatable as long as it hasn't been prepared or executed
-        return node_id not in self.source_prepared_mapping
-
-    def add_node(self, node: BaseInvocation) -> None:
-        self.graph.add_node(node)
-
-    def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
-        if not self._is_node_updatable(node_path):
-            raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be updated')
-        self.graph.update_node(node_path, new_node)
-
-    def delete_node(self, node_path: str) -> None:
-        if not self._is_node_updatable(node_path):
-            raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be deleted')
-        self.graph.delete_node(node_path)
-
-    def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
-        if not self._is_node_updatable(edge[1].node_id):
-            raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to')
-        self.graph.add_edge(edge)
-
-    def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
-        if not self._is_node_updatable(edge[1].node_id):
-            raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted')
-        self.graph.delete_edge(edge)
-
-GraphInvocation.update_forward_refs()
diff --git a/ldm/invoke/app/services/image_storage.py b/ldm/invoke/app/services/image_storage.py
deleted file mode 100644
index 03227d870b..0000000000
--- a/ldm/invoke/app/services/image_storage.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-
-from abc import ABC, abstractmethod
-from enum import Enum
-import datetime
-import os
-from pathlib import Path
-from queue import Queue
-from typing import Dict
-from PIL.Image import Image
-from ...pngwriter import PngWriter
-
-
-class ImageType(str, Enum):
-    RESULT = 'results'
-    INTERMEDIATE = 'intermediates'
-    UPLOAD = 'uploads'
-
-
-class ImageStorageBase(ABC):
-    """Responsible for storing and retrieving images."""
-
-    @abstractmethod
-    def get(self, image_type: ImageType, image_name: str) -> Image:
-        pass
-
-    # TODO: make this a bit more flexible for e.g.
cloud storage - @abstractmethod - def get_path(self, image_type: ImageType, image_name: str) -> str: - pass - - @abstractmethod - def save(self, image_type: ImageType, image_name: str, image: Image) -> None: - pass - - @abstractmethod - def delete(self, image_type: ImageType, image_name: str) -> None: - pass - - def create_name(self, context_id: str, node_id: str) -> str: - return f'{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png' - - -class DiskImageStorage(ImageStorageBase): - """Stores images on disk""" - __output_folder: str - __pngWriter: PngWriter - __cache_ids: Queue # TODO: this is an incredibly naive cache - __cache: Dict[str, Image] - __max_cache_size: int - - def __init__(self, output_folder: str): - self.__output_folder = output_folder - self.__pngWriter = PngWriter(output_folder) - self.__cache = dict() - self.__cache_ids = Queue() - self.__max_cache_size = 10 # TODO: get this from config - - Path(output_folder).mkdir(parents=True, exist_ok=True) - - # TODO: don't hard-code. get/save/delete should maybe take subpath? - for image_type in ImageType: - Path(os.path.join(output_folder, image_type)).mkdir(parents=True, exist_ok=True) - - def get(self, image_type: ImageType, image_name: str) -> Image: - image_path = self.get_path(image_type, image_name) - cache_item = self.__get_cache(image_path) - if cache_item: - return cache_item - - image = Image.open(image_path) - self.__set_cache(image_path, image) - return image - - # TODO: make this a bit more flexible for e.g. cloud storage - def get_path(self, image_type: ImageType, image_name: str) -> str: - path = os.path.join(self.__output_folder, image_type, image_name) - return path - - def save(self, image_type: ImageType, image_name: str, image: Image) -> None: - image_subpath = os.path.join(image_type, image_name) - self.__pngWriter.save_image_and_prompt_to_png(image, "", image_subpath, None) # TODO: just pass full path to png writer - - image_path = self.get_path(image_type, image_name) - self.__set_cache(image_path, image) - - def delete(self, image_type: ImageType, image_name: str) -> None: - image_path = self.get_path(image_type, image_name) - if os.path.exists(image_path): - os.remove(image_path) - - if image_path in self.__cache: - del self.__cache[image_path] - - def __get_cache(self, image_name: str) -> Image: - return None if image_name not in self.__cache else self.__cache[image_name] - - def __set_cache(self, image_name: str, image: Image): - if not image_name in self.__cache: - self.__cache[image_name] = image - self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache - if len(self.__cache) > self.__max_cache_size: - cache_id = self.__cache_ids.get() - del self.__cache[cache_id] diff --git a/ldm/invoke/app/services/invocation_queue.py b/ldm/invoke/app/services/invocation_queue.py deleted file mode 100644 index 0a5b5ae3bb..0000000000 --- a/ldm/invoke/app/services/invocation_queue.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC, abstractmethod -from queue import Queue - - -# TODO: make this serializable -class InvocationQueueItem: - #session_id: str - graph_execution_state_id: str - invocation_id: str - invoke_all: bool - - def __init__(self, - #session_id: str, - graph_execution_state_id: str, - invocation_id: str, - invoke_all: bool = False): - #self.session_id = session_id - self.graph_execution_state_id = graph_execution_state_id - self.invocation_id = invocation_id - 
self.invoke_all = invoke_all - - -class InvocationQueueABC(ABC): - """Abstract base class for all invocation queues""" - @abstractmethod - def get(self) -> InvocationQueueItem: - pass - - @abstractmethod - def put(self, item: InvocationQueueItem|None) -> None: - pass - - -class MemoryInvocationQueue(InvocationQueueABC): - __queue: Queue - - def __init__(self): - self.__queue = Queue() - - def get(self) -> InvocationQueueItem: - return self.__queue.get() - - def put(self, item: InvocationQueueItem|None) -> None: - self.__queue.put(item) diff --git a/ldm/invoke/app/services/invocation_services.py b/ldm/invoke/app/services/invocation_services.py deleted file mode 100644 index 9eb5309d3d..0000000000 --- a/ldm/invoke/app/services/invocation_services.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from .image_storage import ImageStorageBase -from .events import EventServiceBase -from ....generate import Generate - - -class InvocationServices(): - """Services that can be used by invocations""" - generate: Generate # TODO: wrap Generate, or split it up from model? - events: EventServiceBase - images: ImageStorageBase - - def __init__(self, - generate: Generate, - events: EventServiceBase, - images: ImageStorageBase - ): - self.generate = generate - self.events = events - self.images = images diff --git a/ldm/invoke/app/services/invoker.py b/ldm/invoke/app/services/invoker.py deleted file mode 100644 index 796f541781..0000000000 --- a/ldm/invoke/app/services/invoker.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC -from threading import Event, Thread -from .graph import Graph, GraphExecutionState -from .item_storage import ItemStorageABC -from ..invocations.baseinvocation import InvocationContext -from .invocation_services import InvocationServices -from .invocation_queue import InvocationQueueABC, InvocationQueueItem - - -class InvokerServices: - """Services used by the Invoker for execution""" - - queue: InvocationQueueABC - graph_execution_manager: ItemStorageABC[GraphExecutionState] - processor: 'InvocationProcessorABC' - - def __init__(self, - queue: InvocationQueueABC, - graph_execution_manager: ItemStorageABC[GraphExecutionState], - processor: 'InvocationProcessorABC'): - self.queue = queue - self.graph_execution_manager = graph_execution_manager - self.processor = processor - - -class Invoker: - """The invoker, used to execute invocations""" - - services: InvocationServices - invoker_services: InvokerServices - - def __init__(self, - services: InvocationServices, # Services used by nodes to perform invocations - invoker_services: InvokerServices # Services used by the invoker for orchestration - ): - self.services = services - self.invoker_services = invoker_services - self._start() - - - def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None: - """Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute""" - - # Get the next invocation - invocation = graph_execution_state.next() - if not invocation: - return None - - # Save the execution state - self.invoker_services.graph_execution_manager.set(graph_execution_state) - - # Queue the invocation - print(f'queueing item {invocation.id}') - self.invoker_services.queue.put(InvocationQueueItem( - #session_id = session.id, - graph_execution_state_id = graph_execution_state.id, - invocation_id = invocation.id, - 
invoke_all = invoke_all - )) - - return invocation.id - - - def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState: - """Creates a new execution state for the given graph""" - new_state = GraphExecutionState(graph = Graph() if graph is None else graph) - self.invoker_services.graph_execution_manager.set(new_state) - return new_state - - - def __start_service(self, service) -> None: - # Call start() method on any services that have it - start_op = getattr(service, 'start', None) - if callable(start_op): - start_op(self) - - - def __stop_service(self, service) -> None: - # Call stop() method on any services that have it - stop_op = getattr(service, 'stop', None) - if callable(stop_op): - stop_op(self) - - - def _start(self) -> None: - """Starts the invoker. This is called automatically when the invoker is created.""" - for service in vars(self.invoker_services): - self.__start_service(getattr(self.invoker_services, service)) - - for service in vars(self.services): - self.__start_service(getattr(self.services, service)) - - - def stop(self) -> None: - """Stops the invoker. A new invoker will have to be created to execute further.""" - # First stop all services - for service in vars(self.services): - self.__stop_service(getattr(self.services, service)) - - for service in vars(self.invoker_services): - self.__stop_service(getattr(self.invoker_services, service)) - - self.invoker_services.queue.put(None) - - -class InvocationProcessorABC(ABC): - pass \ No newline at end of file diff --git a/ldm/invoke/app/services/item_storage.py b/ldm/invoke/app/services/item_storage.py deleted file mode 100644 index 738f06cb7e..0000000000 --- a/ldm/invoke/app/services/item_storage.py +++ /dev/null @@ -1,57 +0,0 @@ - -from typing import Callable, TypeVar, Generic -from pydantic import BaseModel, Field -from pydantic.generics import GenericModel -from abc import ABC, abstractmethod - -T = TypeVar('T', bound=BaseModel) - -class PaginatedResults(GenericModel, Generic[T]): - """Paginated results""" - items: list[T] = Field(description = "Items") - page: int = Field(description = "Current Page") - pages: int = Field(description = "Total number of pages") - per_page: int = Field(description = "Number of items per page") - total: int = Field(description = "Total number of items in result") - - -class ItemStorageABC(ABC, Generic[T]): - _on_changed_callbacks: list[Callable[[T], None]] - _on_deleted_callbacks: list[Callable[[str], None]] - - def __init__(self) -> None: - self._on_changed_callbacks = list() - self._on_deleted_callbacks = list() - - """Base item storage class""" - @abstractmethod - def get(self, item_id: str) -> T: - pass - - @abstractmethod - def set(self, item: T) -> None: - pass - - @abstractmethod - def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: - pass - - @abstractmethod - def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: - pass - - def on_changed(self, on_changed: Callable[[T], None]) -> None: - """Register a callback for when an item is changed""" - self._on_changed_callbacks.append(on_changed) - - def on_deleted(self, on_deleted: Callable[[str], None]) -> None: - """Register a callback for when an item is deleted""" - self._on_deleted_callbacks.append(on_deleted) - - def _on_changed(self, item: T) -> None: - for callback in self._on_changed_callbacks: - callback(item) - - def _on_deleted(self, item_id: str) -> None: - for callback in self._on_deleted_callbacks: - callback(item_id) diff --git 
a/ldm/invoke/app/services/processor.py b/ldm/invoke/app/services/processor.py deleted file mode 100644 index 9b51a6bcbc..0000000000 --- a/ldm/invoke/app/services/processor.py +++ /dev/null @@ -1,78 +0,0 @@ -from threading import Event, Thread -from ..invocations.baseinvocation import InvocationContext -from .invocation_queue import InvocationQueueItem -from .invoker import InvocationProcessorABC, Invoker - - -class DefaultInvocationProcessor(InvocationProcessorABC): - __invoker_thread: Thread - __stop_event: Event - __invoker: Invoker - - def start(self, invoker) -> None: - self.__invoker = invoker - self.__stop_event = Event() - self.__invoker_thread = Thread( - name = "invoker_processor", - target = self.__process, - kwargs = dict(stop_event = self.__stop_event) - ) - self.__invoker_thread.daemon = True # TODO: probably better to just not use threads? - self.__invoker_thread.start() - - - def stop(self, *args, **kwargs) -> None: - self.__stop_event.set() - - - def __process(self, stop_event: Event): - try: - while not stop_event.is_set(): - queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get() - if not queue_item: # Probably stopping - continue - - graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id) - invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) - - # Send starting event - self.__invoker.services.events.emit_invocation_started( - graph_execution_state_id = graph_execution_state.id, - invocation_id = invocation.id - ) - - # Invoke - try: - outputs = invocation.invoke(InvocationContext( - services = self.__invoker.services, - graph_execution_state_id = graph_execution_state.id - )) - - # Save outputs and history - graph_execution_state.complete(invocation.id, outputs) - - # Save the state changes - self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state) - - # Send complete event - self.__invoker.services.events.emit_invocation_complete( - graph_execution_state_id = graph_execution_state.id, - invocation_id = invocation.id, - result = outputs.dict() - ) - - # Queue any further commands if invoking all - is_complete = graph_execution_state.is_complete() - if queue_item.invoke_all and not is_complete: - self.__invoker.invoke(graph_execution_state, invoke_all = True) - elif is_complete: - self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id) - except KeyboardInterrupt: - pass - except Exception as e: - # TODO: Log the error, mark the invocation as failed, and emit an event - print(f'Error invoking {invocation.id}: {e}') - pass - - except KeyboardInterrupt: - ... # Log something? 
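
Taken together, the services deleted above made up a small execution stack: a MemoryInvocationQueue feeds the DefaultInvocationProcessor worker thread, which loads graph state from an ItemStorageABC and reports progress through the event service. A rough wiring sketch follows (not part of the patch; it assumes the module paths as they stood before this deletion and mirrors the fixtures in tests/nodes/test_invoker.py further down — the None services are placeholders that a real run would have to replace):

from ldm.invoke.app.invocations.generate import TextToImageInvocation
from ldm.invoke.app.services.graph import Graph, GraphExecutionState
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.invoker import Invoker, InvokerServices
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory

# Orchestration services: work queue, execution-state store, worker thread.
invoker_services = InvokerServices(
    queue=MemoryInvocationQueue(),
    graph_execution_manager=SqliteItemStorage[GraphExecutionState](
        filename=sqlite_memory, table_name='graph_executions'),
    processor=DefaultInvocationProcessor())

# Node-facing services; a real run needs actual generate/event/image services.
services = InvocationServices(generate=None, events=None, images=None)

invoker = Invoker(services=services, invoker_services=invoker_services)

g = Graph()
g.add_node(TextToImageInvocation(id='1', prompt='Banana sushi'))
state = invoker.create_execution_state(graph=g)
invoker.invoke(state, invoke_all=True)  # keeps queueing nodes until the graph completes
invoker.stop()
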
diff --git a/ldm/invoke/app/services/sqlite.py b/ldm/invoke/app/services/sqlite.py
deleted file mode 100644
index 8858bbd874..0000000000
--- a/ldm/invoke/app/services/sqlite.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import sqlite3
-from threading import Lock
-from typing import Generic, TypeVar, Union, get_args
-from pydantic import BaseModel, parse_raw_as
-from .item_storage import ItemStorageABC, PaginatedResults
-
-T = TypeVar('T', bound=BaseModel)
-
-sqlite_memory = ':memory:'
-
-class SqliteItemStorage(ItemStorageABC, Generic[T]):
-    _filename: str
-    _table_name: str
-    _conn: sqlite3.Connection
-    _cursor: sqlite3.Cursor
-    _id_field: str
-    _lock: Lock
-
-    def __init__(self, filename: str, table_name: str, id_field: str = 'id'):
-        super().__init__()
-
-        self._filename = filename
-        self._table_name = table_name
-        self._id_field = id_field # TODO: validate that T has this field
-        self._lock = Lock()
-
-        self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution
-        self._cursor = self._conn.cursor()
-
-        self._create_table()
-
-    def _create_table(self):
-        try:
-            self._lock.acquire()
-            self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} (
-                item TEXT,
-                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''')
-            self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''')
-        finally:
-            self._lock.release()
-
-    def _parse_item(self, item: str) -> T:
-        item_type = get_args(self.__orig_class__)[0]
-        return parse_raw_as(item_type, item)
-
-    def set(self, item: T):
-        try:
-            self._lock.acquire()
-            self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),))
-        finally:
-            self._lock.release()
-        self._on_changed(item)
-
-    def get(self, id: str) -> Union[T, None]:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),))
-            result = self._cursor.fetchone()
-        finally:
-            self._lock.release()
-
-        if not result:
-            return None
-
-        return self._parse_item(result[0])
-
-    def delete(self, id: str):
-        try:
-            self._lock.acquire()
-            self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),))
-        finally:
-            self._lock.release()
-        self._on_deleted(id)
-
-    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(f'''SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;''', (per_page, page * per_page))
-            result = self._cursor.fetchall()
-
-            items = list(map(lambda r: self._parse_item(r[0]), result))
-
-            self._cursor.execute(f'''SELECT count(*) FROM {self._table_name};''')
-            count = self._cursor.fetchone()[0]
-        finally:
-            self._lock.release()
-
-        pageCount = (count + per_page - 1) // per_page  # ceiling division; int(count / per_page) + 1 over-counted when count filled its pages exactly
-
-        return PaginatedResults[T](
-            items = items,
-            page = page,
-            pages = pageCount,
-            per_page = per_page,
-            total = count
-        )
-
-    def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;''', (f'%{query}%', per_page, page * per_page))
-            result = self._cursor.fetchall()
-
-            items = list(map(lambda r: self._parse_item(r[0]), result))
-
-            self._cursor.execute(f'''SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;''', (f'%{query}%',))
-            count = self._cursor.fetchone()[0]
-        finally:
-            self._lock.release()
-
-        pageCount = (count + per_page - 1) // per_page  # ceiling division, matching the fix in list() above
-
-        return PaginatedResults[T](
-            items = items,
-            page = page,
-            pages = pageCount,
-            per_page = per_page,
-            total = count
-        )
diff --git a/pyproject.toml b/pyproject.toml
index bfa36ff7d6..b544b9eb9c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,9 +39,6 @@ dependencies = [
     "einops",
     "eventlet",
     "facexlib",
-    "fastapi==0.85.0",
-    "fastapi-events==0.6.0",
-    "fastapi-socketio==0.0.9",
     "flask==2.1.3",
     "flask_cors==3.0.10",
     "flask_socketio==5.3.0",
@@ -63,7 +60,6 @@ dependencies = [
     "pudb",
     "pypatchmatch",
     "pyreadline3",
-    "python-multipart==0.0.5",
     "pytorch-lightning==1.7.7",
     "realesrgan",
     "requests==2.28.2",
@@ -78,7 +74,6 @@ dependencies = [
     "torchmetrics",
     "torchvision>=0.14.1",
     "transformers~=4.25",
-    "uvicorn[standard]==0.20.0",
     "windows-curses; sys_platform=='win32'",
 ]
 description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
diff --git a/scripts/invoke-new.py b/scripts/invoke-new.py
deleted file mode 100644
index 2bc5330a5c..0000000000
--- a/scripts/invoke-new.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-
-import os
-import sys
-
-def main():
-    # Change working directory to the repo root
-    os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-
-    if '--web' in sys.argv:
-        from ldm.invoke.app.api_app import invoke_api
-        invoke_api()
-    else:
-        # TODO: Parse some top-level args here.
-        from ldm.invoke.app.cli_app import invoke_cli
-        invoke_cli()
-
-
-if __name__ == '__main__':
-    main()
diff --git a/static/dream_web/test.html b/static/dream_web/test.html
deleted file mode 100644
index e99abb3703..0000000000
--- a/static/dream_web/test.html
+++ /dev/null
@@ -1,206 +0,0 @@
-[206 deleted lines of the "InvokeAI Test" page; the HTML markup did not survive extraction]
- - - - - \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/nodes/__init__.py b/tests/nodes/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py deleted file mode 100644 index 0a5dcc7734..0000000000 --- a/tests/nodes/test_graph_execution_state.py +++ /dev/null @@ -1,114 +0,0 @@ -from .test_invoker import create_edge -from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation -from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from ldm.invoke.app.services.invocation_services import InvocationServices -from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState -from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation -from ldm.invoke.app.invocations.upscale import UpscaleInvocation -import pytest - - -@pytest.fixture -def simple_graph(): - g = Graph() - g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) - g.add_node(ImageTestInvocation(id = "2")) - g.add_edge(create_edge("1", "prompt", "2", "prompt")) - return g - -@pytest.fixture -def mock_services(): - # NOTE: none of these are actually called by the test invocations - return InvocationServices(generate = None, events = None, images = None) - -def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: - n = g.next() - if n is None: - return (None, None) - - print(f'invoking {n.id}: {type(n)}') - o = n.invoke(InvocationContext(services, "1")) - g.complete(n.id, o) - - return (n, o) - -def test_graph_state_executes_in_order(simple_graph, mock_services): - g = GraphExecutionState(graph = simple_graph) - - n1 = invoke_next(g, mock_services) - n2 = invoke_next(g, mock_services) - n3 = g.next() - - assert g.prepared_source_mapping[n1[0].id] == "1" - assert g.prepared_source_mapping[n2[0].id] == "2" - assert n3 is None - assert g.results[n1[0].id].prompt == n1[0].prompt - assert n2[0].prompt == n1[0].prompt - -def test_graph_is_complete(simple_graph, mock_services): - g = GraphExecutionState(graph = simple_graph) - n1 = invoke_next(g, mock_services) - n2 = invoke_next(g, mock_services) - n3 = g.next() - - assert g.is_complete() - -def test_graph_is_not_complete(simple_graph, mock_services): - g = GraphExecutionState(graph = simple_graph) - n1 = invoke_next(g, mock_services) - n2 = g.next() - - assert not g.is_complete() - -# TODO: test completion with iterators/subgraphs - -def test_graph_state_expands_iterator(mock_services): - graph = Graph() - test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) - graph.add_node(IterateInvocation(id = "2")) - graph.add_node(ImageTestInvocation(id = "3")) - graph.add_edge(create_edge("1", "collection", "2", "collection")) - graph.add_edge(create_edge("2", "item", "3", "prompt")) - - g = GraphExecutionState(graph = graph) - n1 = invoke_next(g, mock_services) - n2 = invoke_next(g, mock_services) - n3 = invoke_next(g, mock_services) - n4 = invoke_next(g, mock_services) - n5 = invoke_next(g, mock_services) - - assert 
g.prepared_source_mapping[n1[0].id] == "1" - assert g.prepared_source_mapping[n2[0].id] == "2" - assert g.prepared_source_mapping[n3[0].id] == "2" - assert g.prepared_source_mapping[n4[0].id] == "3" - assert g.prepared_source_mapping[n5[0].id] == "3" - - assert isinstance(n4[0], ImageTestInvocation) - assert isinstance(n5[0], ImageTestInvocation) - - prompts = [n4[0].prompt, n5[0].prompt] - assert sorted(prompts) == sorted(test_prompts) - -def test_graph_state_collects(mock_services): - graph = Graph() - test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) - graph.add_node(IterateInvocation(id = "2")) - graph.add_node(PromptTestInvocation(id = "3")) - graph.add_node(CollectInvocation(id = "4")) - graph.add_edge(create_edge("1", "collection", "2", "collection")) - graph.add_edge(create_edge("2", "item", "3", "prompt")) - graph.add_edge(create_edge("3", "prompt", "4", "item")) - - g = GraphExecutionState(graph = graph) - n1 = invoke_next(g, mock_services) - n2 = invoke_next(g, mock_services) - n3 = invoke_next(g, mock_services) - n4 = invoke_next(g, mock_services) - n5 = invoke_next(g, mock_services) - n6 = invoke_next(g, mock_services) - - assert isinstance(n6[0], CollectInvocation) - - assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py deleted file mode 100644 index a6d96f61c0..0000000000 --- a/tests/nodes/test_invoker.py +++ /dev/null @@ -1,85 +0,0 @@ -from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until -from ldm.invoke.app.services.processor import DefaultInvocationProcessor -from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory -from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue -from ldm.invoke.app.services.invoker import Invoker, InvokerServices -from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from ldm.invoke.app.services.invocation_services import InvocationServices -from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState -from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation -from ldm.invoke.app.invocations.upscale import UpscaleInvocation -import pytest - - -@pytest.fixture -def simple_graph(): - g = Graph() - g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) - g.add_node(ImageTestInvocation(id = "2")) - g.add_edge(create_edge("1", "prompt", "2", "prompt")) - return g - -@pytest.fixture -def mock_services() -> InvocationServices: - # NOTE: none of these are actually called by the test invocations - return InvocationServices(generate = None, events = TestEventService(), images = None) - -@pytest.fixture() -def mock_invoker_services() -> InvokerServices: - return InvokerServices( - queue = MemoryInvocationQueue(), - graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() - ) - -@pytest.fixture() -def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker: - return Invoker( - services = mock_services, - 
invoker_services = mock_invoker_services - ) - -def test_can_create_graph_state(mock_invoker: Invoker): - g = mock_invoker.create_execution_state() - mock_invoker.stop() - - assert g is not None - assert isinstance(g, GraphExecutionState) - -def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph = simple_graph) - mock_invoker.stop() - - assert g is not None - assert isinstance(g, GraphExecutionState) - assert g.graph == simple_graph - -def test_can_invoke(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph = simple_graph) - invocation_id = mock_invoker.invoke(g) - assert invocation_id is not None - - def has_executed_any(g: GraphExecutionState): - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) - return len(g.executed) > 0 - - wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) - mock_invoker.stop() - - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) - assert len(g.executed) > 0 - -def test_can_invoke_all(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph = simple_graph) - invocation_id = mock_invoker.invoke(g, invoke_all = True) - assert invocation_id is not None - - def has_executed_all(g: GraphExecutionState): - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) - return g.is_complete() - - wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) - mock_invoker.stop() - - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) - assert g.is_complete() diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py deleted file mode 100644 index 1b5b341192..0000000000 --- a/tests/nodes/test_node_graph.py +++ /dev/null @@ -1,501 +0,0 @@ -from ldm.invoke.app.invocations.image import * - -from .test_nodes import ListPassThroughInvocation, PromptTestInvocation -from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation -from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation -from ldm.invoke.app.invocations.upscale import UpscaleInvocation -import pytest - - -# Helpers -def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: - return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) - -# Tests -def test_connections_are_compatible(): - from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") - from_field = "image" - to_node = UpscaleInvocation(id = "2") - to_field = "image" - - result = are_connections_compatible(from_node, from_field, to_node, to_field) - - assert result == True - -def test_connections_are_incompatible(): - from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") - from_field = "image" - to_node = UpscaleInvocation(id = "2") - to_field = "strength" - - result = are_connections_compatible(from_node, from_field, to_node, to_field) - - assert result == False - -def test_connections_incompatible_with_invalid_fields(): - from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") - from_field = "invalid_field" - to_node = UpscaleInvocation(id = "2") - to_field = "image" - - # From field is invalid - result = are_connections_compatible(from_node, from_field, to_node, to_field) - assert result == False - 
- # To field is invalid - from_field = "image" - to_field = "invalid_field" - - result = are_connections_compatible(from_node, from_field, to_node, to_field) - assert result == False - -def test_graph_can_add_node(): - g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") - g.add_node(n) - - assert n.id in g.nodes - -def test_graph_fails_to_add_node_with_duplicate_id(): - g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") - g.add_node(n) - n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second") - - with pytest.raises(NodeAlreadyInGraphError): - g.add_node(n2) - -def test_graph_updates_node(): - g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") - g.add_node(n) - n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") - g.add_node(n2) - - nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated") - - g.update_node("1", nu) - - assert g.nodes["1"].prompt == "Banana sushi updated" - -def test_graph_fails_to_update_node_if_type_changes(): - g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") - g.add_node(n) - n2 = UpscaleInvocation(id = "2") - g.add_node(n2) - - nu = UpscaleInvocation(id = "1") - - with pytest.raises(TypeError): - g.update_node("1", nu) - -def test_graph_allows_non_conflicting_id_change(): - g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") - g.add_node(n) - n2 = UpscaleInvocation(id = "2") - g.add_node(n2) - e1 = create_edge(n.id,"image",n2.id,"image") - g.add_edge(e1) - - nu = TextToImageInvocation(id = "3", prompt = "Banana sushi") - g.update_node("1", nu) - - with pytest.raises(NodeNotFoundError): - g.get_node("1") - - assert g.get_node("3").prompt == "Banana sushi" - - assert len(g.edges) == 1 - assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges - -def test_graph_fails_to_update_node_id_if_conflict(): - g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") - g.add_node(n) - n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") - g.add_node(n2) - - nu = TextToImageInvocation(id = "2", prompt = "Banana sushi") - with pytest.raises(NodeAlreadyInGraphError): - g.update_node("1", nu) - -def test_graph_adds_edge(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - g.add_node(n1) - g.add_node(n2) - e = create_edge(n1.id,"image",n2.id,"image") - - g.add_edge(e) - - assert e in g.edges - -def test_graph_fails_to_add_edge_with_cycle(): - g = Graph() - n1 = UpscaleInvocation(id = "1") - g.add_node(n1) - e = create_edge(n1.id,"image",n1.id,"image") - with pytest.raises(InvalidEdgeError): - g.add_edge(e) - -def test_graph_fails_to_add_edge_with_long_cycle(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - n3 = UpscaleInvocation(id = "3") - g.add_node(n1) - g.add_node(n2) - g.add_node(n3) - e1 = create_edge(n1.id,"image",n2.id,"image") - e2 = create_edge(n2.id,"image",n3.id,"image") - e3 = create_edge(n3.id,"image",n2.id,"image") - g.add_edge(e1) - g.add_edge(e2) - with pytest.raises(InvalidEdgeError): - g.add_edge(e3) - -def test_graph_fails_to_add_edge_with_missing_node_id(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - g.add_node(n1) - g.add_node(n2) - e1 = create_edge("1","image","3","image") - e2 = 
create_edge("3","image","1","image") - with pytest.raises(InvalidEdgeError): - g.add_edge(e1) - with pytest.raises(InvalidEdgeError): - g.add_edge(e2) - -def test_graph_fails_to_add_edge_when_destination_exists(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - n3 = UpscaleInvocation(id = "3") - g.add_node(n1) - g.add_node(n2) - g.add_node(n3) - e1 = create_edge(n1.id,"image",n2.id,"image") - e2 = create_edge(n1.id,"image",n3.id,"image") - e3 = create_edge(n2.id,"image",n3.id,"image") - g.add_edge(e1) - g.add_edge(e2) - with pytest.raises(InvalidEdgeError): - g.add_edge(e3) - - -def test_graph_fails_to_add_edge_with_mismatched_types(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - g.add_node(n1) - g.add_node(n2) - e1 = create_edge("1","image","2","strength") - with pytest.raises(InvalidEdgeError): - g.add_edge(e1) - -def test_graph_connects_collector(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2") - n3 = CollectInvocation(id = "3") - n4 = ListPassThroughInvocation(id = "4") - g.add_node(n1) - g.add_node(n2) - g.add_node(n3) - g.add_node(n4) - - e1 = create_edge("1","image","3","item") - e2 = create_edge("2","image","3","item") - e3 = create_edge("3","collection","4","collection") - g.add_edge(e1) - g.add_edge(e2) - g.add_edge(e3) - -# TODO: test that derived types mixed with base types are compatible - -def test_graph_collector_invalid_with_varying_input_types(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2") - n3 = CollectInvocation(id = "3") - g.add_node(n1) - g.add_node(n2) - g.add_node(n3) - - e1 = create_edge("1","image","3","item") - e2 = create_edge("2","prompt","3","item") - g.add_edge(e1) - - with pytest.raises(InvalidEdgeError): - g.add_edge(e2) - -def test_graph_collector_invalid_with_varying_input_output(): - g = Graph() - n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi") - n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2") - n3 = CollectInvocation(id = "3") - n4 = ListPassThroughInvocation(id = "4") - g.add_node(n1) - g.add_node(n2) - g.add_node(n3) - g.add_node(n4) - - e1 = create_edge("1","prompt","3","item") - e2 = create_edge("2","prompt","3","item") - e3 = create_edge("3","collection","4","collection") - g.add_edge(e1) - g.add_edge(e2) - - with pytest.raises(InvalidEdgeError): - g.add_edge(e3) - -def test_graph_collector_invalid_with_non_list_output(): - g = Graph() - n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi") - n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2") - n3 = CollectInvocation(id = "3") - n4 = PromptTestInvocation(id = "4") - g.add_node(n1) - g.add_node(n2) - g.add_node(n3) - g.add_node(n4) - - e1 = create_edge("1","prompt","3","item") - e2 = create_edge("2","prompt","3","item") - e3 = create_edge("3","collection","4","prompt") - g.add_edge(e1) - g.add_edge(e2) - - with pytest.raises(InvalidEdgeError): - g.add_edge(e3) - -def test_graph_connects_iterator(): - g = Graph() - n1 = ListPassThroughInvocation(id = "1") - n2 = IterateInvocation(id = "2") - n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") - g.add_node(n1) - g.add_node(n2) - g.add_node(n3) - - e1 = create_edge("1","collection","2","collection") - e2 = create_edge("2","item","3","image") - g.add_edge(e1) - 
g.add_edge(e2)
-
-# TODO: TEST INVALID ITERATOR SCENARIOS
-
-def test_graph_iterator_invalid_if_multiple_inputs():
-    g = Graph()
-    n1 = ListPassThroughInvocation(id = "1")
-    n2 = IterateInvocation(id = "2")
-    n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
-    n4 = ListPassThroughInvocation(id = "4")
-    g.add_node(n1)
-    g.add_node(n2)
-    g.add_node(n3)
-    g.add_node(n4)
-
-    e1 = create_edge("1","collection","2","collection")
-    e2 = create_edge("2","item","3","image")
-    e3 = create_edge("4","collection","2","collection")
-    g.add_edge(e1)
-    g.add_edge(e2)
-
-    with pytest.raises(InvalidEdgeError):
-        g.add_edge(e3)
-
-def test_graph_iterator_invalid_if_input_not_list():
-    g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
-    n2 = IterateInvocation(id = "2")
-    g.add_node(n1)
-    g.add_node(n2)
-
-    e1 = create_edge("1","collection","2","collection")
-
-    with pytest.raises(InvalidEdgeError):
-        g.add_edge(e1)
-
-def test_graph_iterator_invalid_if_output_and_input_types_different():
-    g = Graph()
-    n1 = ListPassThroughInvocation(id = "1")
-    n2 = IterateInvocation(id = "2")
-    n3 = PromptTestInvocation(id = "3", prompt = "Banana sushi")
-    g.add_node(n1)
-    g.add_node(n2)
-    g.add_node(n3)
-
-    e1 = create_edge("1","collection","2","collection")
-    e2 = create_edge("2","item","3","prompt")
-    g.add_edge(e1)
-
-    with pytest.raises(InvalidEdgeError):
-        g.add_edge(e2)
-
-def test_graph_validates():
-    g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
-    n2 = UpscaleInvocation(id = "2")
-    g.add_node(n1)
-    g.add_node(n2)
-    e1 = create_edge("1","image","2","image")
-    g.add_edge(e1)
-
-    assert g.is_valid() == True
-
-def test_graph_invalid_if_edges_reference_missing_nodes():
-    g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
-    g.nodes[n1.id] = n1
-    e1 = create_edge("1","image","2","image")
-    g.edges.append(e1)
-
-    assert g.is_valid() == False
-
-def test_graph_invalid_if_subgraph_invalid():
-    g = Graph()
-    n1 = GraphInvocation(id = "1")
-    n1.graph = Graph()
-
-    n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi")
-    n1.graph.nodes[n1_1.id] = n1_1
-    e1 = create_edge("1","image","2","image")
-    n1.graph.edges.append(e1)
-
-    g.nodes[n1.id] = n1
-
-    assert g.is_valid() == False
-
-def test_graph_invalid_if_has_cycle():
-    g = Graph()
-    n1 = UpscaleInvocation(id = "1")
-    n2 = UpscaleInvocation(id = "2")
-    g.nodes[n1.id] = n1
-    g.nodes[n2.id] = n2
-    e1 = create_edge("1","image","2","image")
-    e2 = create_edge("2","image","1","image")
-    g.edges.append(e1)
-    g.edges.append(e2)
-
-    assert g.is_valid() == False
-
-def test_graph_invalid_with_invalid_connection():
-    g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
-    n2 = UpscaleInvocation(id = "2")
-    g.nodes[n1.id] = n1
-    g.nodes[n2.id] = n2
-    e1 = create_edge("1","image","2","strength")
-    g.edges.append(e1)
-
-    assert g.is_valid() == False
-
-
-# TODO: Subgraph operations
-def test_graph_gets_subgraph_node():
-    g = Graph()
-    n1 = GraphInvocation(id = "1")
-    n1.graph = Graph()
-    n1.graph.add_node
-
-    n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
-    n1.graph.add_node(n1_1)
-
-    g.add_node(n1)
-
-    result = g.get_node('1.1')
-
-    assert result is not None
-    assert result.id == '1'
-    assert result == n1_1
-
-def test_graph_fails_to_get_missing_subgraph_node():
-    g = Graph()
-    n1 = GraphInvocation(id = "1")
-    n1.graph = Graph()
-    n1.graph.add_node
-
-    n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
-    n1.graph.add_node(n1_1)
-
-
g.add_node(n1) - - with pytest.raises(NodeNotFoundError): - result = g.get_node('1.2') - -def test_graph_fails_to_enumerate_non_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id = "1") - n1.graph = Graph() - n1.graph.add_node - - n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n1.graph.add_node(n1_1) - - g.add_node(n1) - - n2 = UpscaleInvocation(id = "2") - g.add_node(n2) - - with pytest.raises(NodeNotFoundError): - result = g.get_node('2.1') - -def test_graph_gets_networkx_graph(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - g.add_node(n1) - g.add_node(n2) - e = create_edge(n1.id,"image",n2.id,"image") - g.add_edge(e) - - nxg = g.nx_graph() - - assert '1' in nxg.nodes - assert '2' in nxg.nodes - assert ('1','2') in nxg.edges - - -# TODO: Graph serializes and deserializes -def test_graph_can_serialize(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - g.add_node(n1) - g.add_node(n2) - e = create_edge(n1.id,"image",n2.id,"image") - g.add_edge(e) - - # Not throwing on this line is sufficient - json = g.json() - -def test_graph_can_deserialize(): - g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - g.add_node(n1) - g.add_node(n2) - e = create_edge(n1.id,"image",n2.id,"image") - g.add_edge(e) - - json = g.json() - g2 = Graph.parse_raw(json) - - assert g2 is not None - assert g2.nodes['1'] is not None - assert g2.nodes['2'] is not None - assert len(g2.edges) == 1 - assert g2.edges[0][0].node_id == '1' - assert g2.edges[0][0].field == 'image' - assert g2.edges[0][1].node_id == '2' - assert g2.edges[0][1].field == 'image' - -def test_graph_can_generate_schema(): - # Not throwing on this line is sufficient - # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation - schema = Graph.schema_json(indent = 2) diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py deleted file mode 100644 index fea2e75e95..0000000000 --- a/tests/nodes/test_nodes.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import Any, Callable, Literal -from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from ldm.invoke.app.invocations.image import ImageField -from ldm.invoke.app.services.invocation_services import InvocationServices -from pydantic import Field -import pytest - -# Define test invocations before importing anything that uses invocations -class ListPassThroughInvocationOutput(BaseInvocationOutput): - type: Literal['test_list_output'] = 'test_list_output' - - collection: list[ImageField] = Field(default_factory=list) - -class ListPassThroughInvocation(BaseInvocation): - type: Literal['test_list'] = 'test_list' - - collection: list[ImageField] = Field(default_factory=list) - - def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: - return ListPassThroughInvocationOutput(collection = self.collection) - -class PromptTestInvocationOutput(BaseInvocationOutput): - type: Literal['test_prompt_output'] = 'test_prompt_output' - - prompt: str = Field(default = "") - -class PromptTestInvocation(BaseInvocation): - type: Literal['test_prompt'] = 'test_prompt' - - prompt: str = Field(default = "") - - def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: - return PromptTestInvocationOutput(prompt = self.prompt) - -class 
ImageTestInvocationOutput(BaseInvocationOutput): - type: Literal['test_image_output'] = 'test_image_output' - - image: ImageField = Field() - -class ImageTestInvocation(BaseInvocation): - type: Literal['test_image'] = 'test_image' - - prompt: str = Field(default = "") - - def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: - return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) - -class PromptCollectionTestInvocationOutput(BaseInvocationOutput): - type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output' - collection: list[str] = Field(default_factory=list) - -class PromptCollectionTestInvocation(BaseInvocation): - type: Literal['test_prompt_collection'] = 'test_prompt_collection' - collection: list[str] = Field() - - def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: - return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) - - -from ldm.invoke.app.services.events import EventServiceBase -from ldm.invoke.app.services.graph import EdgeConnection - -def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: - return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) - - -class TestEvent: - event_name: str - payload: Any - - def __init__(self, event_name: str, payload: Any): - self.event_name = event_name - self.payload = payload - -class TestEventService(EventServiceBase): - events: list - - def __init__(self): - super().__init__() - self.events = list() - - def dispatch(self, event_name: str, payload: Any) -> None: - pass - -def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float = 0.1) -> None: - import time - start_time = time.time() - while time.time() - start_time < timeout: - if condition(): - return - time.sleep(interval) - raise TimeoutError("Condition not met") \ No newline at end of file diff --git a/tests/nodes/test_sqlite.py b/tests/nodes/test_sqlite.py deleted file mode 100644 index e499bbce12..0000000000 --- a/tests/nodes/test_sqlite.py +++ /dev/null @@ -1,112 +0,0 @@ -from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory -from pydantic import BaseModel, Field - - -class TestModel(BaseModel): - id: str = Field(description = "ID") - name: str = Field(description = "Name") - - -def test_sqlite_service_can_create_and_get(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - assert db.get('1') == TestModel(id = '1', name = 'Test') - -def test_sqlite_service_can_list(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.list() - assert results.page == 0 - assert results.pages == 1 - assert results.per_page == 10 - assert results.total == 3 - assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')] - -def test_sqlite_service_can_delete(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.delete('1') - assert db.get('1') is None - -def test_sqlite_service_calls_set_callback(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - called = False - def on_changed(item: TestModel): - nonlocal called - called = True - 
db.on_changed(on_changed) - db.set(TestModel(id = '1', name = 'Test')) - assert called - -def test_sqlite_service_calls_delete_callback(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - called = False - def on_deleted(item_id: str): - nonlocal called - called = True - db.on_deleted(on_deleted) - db.set(TestModel(id = '1', name = 'Test')) - db.delete('1') - assert called - -def test_sqlite_service_can_list_with_pagination(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.list(page = 0, per_page = 2) - assert results.page == 0 - assert results.pages == 2 - assert results.per_page == 2 - assert results.total == 3 - assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')] - -def test_sqlite_service_can_list_with_pagination_and_offset(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.list(page = 1, per_page = 2) - assert results.page == 1 - assert results.pages == 2 - assert results.per_page == 2 - assert results.total == 3 - assert results.items == [TestModel(id = '3', name = 'Test')] - -def test_sqlite_service_can_search(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.search(query = 'Test') - assert results.page == 0 - assert results.pages == 1 - assert results.per_page == 10 - assert results.total == 3 - assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')] - -def test_sqlite_service_can_search_with_pagination(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.search(query = 'Test', page = 0, per_page = 2) - assert results.page == 0 - assert results.pages == 2 - assert results.per_page == 2 - assert results.total == 3 - assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')] - -def test_sqlite_service_can_search_with_pagination_and_offset(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.search(query = 'Test', page = 1, per_page = 2) - assert results.page == 1 - assert results.pages == 2 - assert results.per_page == 2 - assert results.total == 3 - assert results.items == [TestModel(id = '3', name = 'Test')] From 8dc56471efbd4c481335b8e671b07aad9d863100 Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 26 Feb 2023 22:08:07 +0100 Subject: [PATCH 56/57] fix compel version in pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b544b9eb9c..a8aba32b90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "albumentations", "click", "clip_anytorch", - "compel>=0.1.6", + "compel==0.1.7", "datasets", "diffusers[torch]~=0.13", "dnspython==2.2.1", From 
70283f7d8d6c6d707fe489abb60e03cda48edc7a Mon Sep 17 00:00:00 2001 From: mauwii Date: Sun, 26 Feb 2023 22:11:11 +0100 Subject: [PATCH 57/57] increase line_length to 120 --- .editorconfig | 2 +- .flake8 | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.editorconfig b/.editorconfig index 28e2100bab..0ded504342 100644 --- a/.editorconfig +++ b/.editorconfig @@ -13,7 +13,7 @@ trim_trailing_whitespace = true # Python [*.py] indent_size = 4 -max_line_length = 88 +max_line_length = 120 # css [*.css] diff --git a/.flake8 b/.flake8 index 81d8d82bfb..2159b9dcc6 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -max-line-length = 88 +max-line-length = 120 extend-ignore = # See https://github.com/PyCQA/pycodestyle/issues/373 E203, diff --git a/pyproject.toml b/pyproject.toml index a8aba32b90..6b866f80ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,7 +161,7 @@ target-version = ['py39'] atomic = true extend_skip_glob = ["scripts/orig_scripts/*"] filter_files = true -line_length = 88 +line_length = 120 profile = "black" py_version = 39 remove_redundant_aliases = true
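
Note that .editorconfig, flake8, and isort now agree on 120, but black enforces its own width. If black is also configured in this pyproject.toml (the target-version context line above suggests a [tool.black] section, though its line length is not visible in this hunk), the matching setting would be a hypothetical companion change along these lines:

[tool.black]
# assumed companion setting; keeps the formatter at the same width as
# .editorconfig, .flake8, and isort above
line-length = 120
target-version = ['py39']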