Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00.

Compare commits: maryhipp/u...bugfix/cli (2 commits)

Commits: ec1e66dcd3, 69543c23d0
.github/workflows/pyflakes.yml (vendored, new file, 20 lines)

@@ -0,0 +1,20 @@
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+      - development
+      - 'release-candidate-*'
+
+jobs:
+  pyflakes:
+    name: runner / pyflakes
+    if: github.event.pull_request.draft == false
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: pyflakes
+        uses: reviewdog/action-pyflakes@v1
+        with:
+          github_token: ${{ secrets.GITHUB_TOKEN }}
+          reporter: github-pr-review

.github/workflows/style-checks.yml (vendored, 9 lines changed)

@@ -6,7 +6,7 @@ on:
     branches: main

 jobs:
-  ruff:
+  black:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
@@ -18,7 +18,8 @@ jobs:

       - name: Install dependencies with pip
         run: |
-          pip install ruff
+          pip install black flake8 Flake8-pyproject isort

-      - run: ruff check --output-format=github .
-      - run: ruff format --check .
+      - run: isort --check-only .
+      - run: black --check .
+      - run: flake8
@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
 _For Windows/Linux with an NVIDIA GPU:_

 ```terminal
-pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
+pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
 ```

 _For Linux with an AMD GPU:_
@@ -175,7 +175,7 @@ the command `npm install -g yarn` if needed)
 pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
 ```

-_For Macintoshes, either Intel or M1/M2/M3:_
+_For Macintoshes, either Intel or M1/M2:_

 ```sh
 pip install InvokeAI --use-pep517

@@ -11,5 +11,5 @@ INVOKEAI_ROOT=
 # HUGGING_FACE_HUB_TOKEN=

 ## optional variables specific to the docker setup.
-# GPU_DRIVER=cuda # or rocm
+# GPU_DRIVER=cuda
 # CONTAINER_UID=1000

@@ -18,8 +18,8 @@ ENV INVOKEAI_SRC=/opt/invokeai
 ENV VIRTUAL_ENV=/opt/venv/invokeai

 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
-ARG TORCH_VERSION=2.1.0
-ARG TORCHVISION_VERSION=0.16
+ARG TORCH_VERSION=2.0.1
+ARG TORCHVISION_VERSION=0.15.2
 ARG GPU_DRIVER=cuda
 ARG TARGETPLATFORM="linux/amd64"
 # unused but available
@@ -35,7 +35,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
         extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
     elif [ "$GPU_DRIVER" = "rocm" ]; then \
-        extra_index_url_arg="--index-url https://download.pytorch.org/whl/rocm5.6"; \
+        extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; \
     else \
         extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \
     fi &&\

@@ -15,10 +15,6 @@ services:
           - driver: nvidia
             count: 1
             capabilities: [gpu]
-    # For AMD support, comment out the deploy section above and uncomment the devices section below:
-    #devices:
-    #  - /dev/kfd:/dev/kfd
-    #  - /dev/dri:/dev/dri
     build:
       context: ..
       dockerfile: docker/Dockerfile

@@ -7,5 +7,5 @@ set -e
 SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
 cd "$SCRIPTDIR" || exit 1

-docker compose up -d
+docker compose up --build -d
 docker compose logs -f

(File diff suppressed because it is too large.)
@@ -198,7 +198,6 @@ The list of schedulers has been completely revamped and brought up to date:
 | **dpmpp_2m** | DPMSolverMultistepScheduler | original noise scnedule |
 | **dpmpp_2m_k** | DPMSolverMultistepScheduler | using karras noise schedule |
 | **unipc** | UniPCMultistepScheduler | CPU only |
-| **lcm** | LCMScheduler | |

 Please see [3.0.0 Release Notes](https://github.com/invoke-ai/InvokeAI/releases/tag/v3.0.0) for further details.


@@ -179,7 +179,7 @@ experimental versions later.
 you will have the choice of CUDA (NVidia cards), ROCm (AMD cards),
 or CPU (no graphics acceleration). On Windows, you'll have the
 choice of CUDA vs CPU, and on Macs you'll be offered CPU only. When
-you select CPU on M1/M2/M3 Macintoshes, you will get MPS-based
+you select CPU on M1 or M2 Macintoshes, you will get MPS-based
 graphics acceleration without installing additional drivers. If you
 are unsure what GPU you are using, you can ask the installer to
 guess.
@@ -471,7 +471,7 @@ Then type the following commands:

 === "NVIDIA System"
     ```bash
-    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu121
+    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
     pip install xformers
     ```


@@ -148,7 +148,7 @@ manager, please follow these steps:
 === "CUDA (NVidia)"

     ```bash
-    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
+    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
     ```

 === "ROCm (AMD)"
@@ -327,7 +327,7 @@ installation protocol (important!)

 === "CUDA (NVidia)"
     ```bash
-    pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
+    pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
     ```

 === "ROCm (AMD)"
@@ -375,7 +375,7 @@ you can do so using this unsupported recipe:
 mkdir ~/invokeai
 conda create -n invokeai python=3.10
 conda activate invokeai
-pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
+pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
 invokeai-configure --root ~/invokeai
 invokeai --root ~/invokeai --web
 ```

@@ -85,7 +85,7 @@ You can find which version you should download from [this link](https://docs.nvi

 When installing torch and torchvision manually with `pip`, remember to provide
 the argument `--extra-index-url
-https://download.pytorch.org/whl/cu121` as described in the [Manual
+https://download.pytorch.org/whl/cu118` as described in the [Manual
 Installation Guide](020_INSTALL_MANUAL.md).

 ## :simple-amd: ROCm

@@ -30,7 +30,7 @@ methodology for details on why running applications in such a stateless fashion
 The container is configured for CUDA by default, but can be built to support AMD GPUs
 by setting the `GPU_DRIVER=rocm` environment variable at Docker image build time.

-Developers on Apple silicon (M1/M2/M3): You
+Developers on Apple silicon (M1/M2): You
 [can't access your GPU cores from Docker containers](https://github.com/pytorch/pytorch/issues/81224)
 and performance is reduced compared with running it directly on macOS but for
 development purposes it's fine. Once you're done with development tasks on your

@@ -28,7 +28,7 @@ command line, then just be sure to activate it's virtual environment.
 Then run the following three commands:

 ```sh
-pip install xformers~=0.0.22
+pip install xformers~=0.0.19
 pip install triton # WON'T WORK ON WINDOWS
 python -m xformers.info output
 ```
@@ -42,7 +42,7 @@ If all goes well, you'll see a report like the
 following:

 ```sh
-xFormers 0.0.22
+xFormers 0.0.20
 memory_efficient_attention.cutlassF: available
 memory_efficient_attention.cutlassB: available
 memory_efficient_attention.flshattF: available
@@ -59,14 +59,14 @@ swiglu.gemm_fused_operand_sum: available
 swiglu.fused.p.cpp: available
 is_triton_available: True
 is_functorch_available: False
-pytorch.version: 2.1.0+cu121
+pytorch.version: 2.0.1+cu118
 pytorch.cuda: available
 gpu.compute_capability: 8.9
 gpu.name: NVIDIA GeForce RTX 4070
 build.info: available
 build.cuda_version: 1108
 build.python_version: 3.10.11
-build.torch_version: 2.1.0+cu121
+build.torch_version: 2.0.1+cu118
 build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
 build.env.XFORMERS_BUILD_TYPE: Release
 build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
@@ -92,22 +92,33 @@ installed from source. These instructions were written for a system
 running Ubuntu 22.04, but other Linux distributions should be able to
 adapt this recipe.

-#### 1. Install CUDA Toolkit 12.1
+#### 1. Install CUDA Toolkit 11.8

 You will need the CUDA developer's toolkit in order to compile and
 install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
 package.** It is out of date and will cause conflicts among the NVIDIA
 driver and binaries. Instead install the CUDA Toolkit package provided
-by NVIDIA itself. Go to [CUDA Toolkit 12.1
-Downloads](https://developer.nvidia.com/cuda-12-1-0-download-archive)
+by NVIDIA itself. Go to [CUDA Toolkit 11.8
+Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
 and use the target selection wizard to choose your platform and Linux
 distribution. Select an installer type of "runfile (local)" at the
 last step.

 This will provide you with a recipe for downloading and running a
-install shell script that will install the toolkit and drivers.
+install shell script that will install the toolkit and drivers. For
+example, the install script recipe for Ubuntu 22.04 running on a
+x86_64 system is:

-#### 2. Confirm/Install pyTorch 2.1.0 with CUDA 12.1 support
+```
+wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
+sudo sh cuda_11.8.0_520.61.05_linux.run
+```
+
+Rather than cut-and-paste this example, We recommend that you walk
+through the toolkit wizard in order to get the most up to date
+installer for your system.
+
+#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support

 If you are using InvokeAI 3.0.2 or higher, these will already be
 installed. If not, you can check whether you have the needed libraries
@@ -122,7 +133,7 @@ Then run the command:
 python -c 'exec("import torch\nprint(torch.__version__)")'
 ```

-If it prints __2.1.0+cu121__ you're good. If not, you can install the
+If it prints __1.13.1+cu118__ you're good. If not, you can install the
 most up to date libraries with this command:

 ```sh
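
As an aside to the documentation hunk above, a minimal Python sketch of the same version check, assuming only that torch is installed; the version string you expect depends on which side of this diff you followed (cu121 vs cu118):

```python
# Quick equivalent of the docs' shell one-liner: print torch's version and
# confirm CUDA is usable. Nothing here is InvokeAI-specific.
import torch

print("torch version:", torch.__version__)        # e.g. "2.1.0+cu121" or "2.0.1+cu118"
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA runtime:", torch.version.cuda)
    print("device:", torch.cuda.get_device_name(0))
```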
@@ -244,7 +244,7 @@ class InvokeAiInstance:
             "numpy~=1.24.0",  # choose versions that won't be uninstalled during phase 2
             "urllib3~=1.26.0",
             "requests~=2.28.0",
-            "torch~=2.1.0",
+            "torch~=2.0.0",
             "torchmetrics==0.11.4",
             "torchvision>=0.14.1",
             "--force-reinstall",
@@ -460,10 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
         url = "https://download.pytorch.org/whl/cpu"

     if device == "cuda":
-        url = "https://download.pytorch.org/whl/cu121"
+        url = "https://download.pytorch.org/whl/cu118"
         optional_modules = "[xformers,onnx-cuda]"
     if device == "cuda_and_dml":
-        url = "https://download.pytorch.org/whl/cu121"
+        url = "https://download.pytorch.org/whl/cu118"
         optional_modules = "[xformers,onnx-directml]"

     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
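
For reference, a small illustrative sketch, not the installer's actual code, of the device-to-wheel-index mapping that `get_torch_source()` implements in the hunk above. The URLs are the ones appearing on the two sides of this diff; the default `optional_modules` value and the function name are assumptions for illustration:

```python
# Illustrative only: mirrors the device -> PyTorch wheel index mapping above.
from typing import Optional, Tuple

def pick_torch_source(device: str, cuda_index: str = "cu121") -> Tuple[Optional[str], str]:
    url: Optional[str] = None
    optional_modules = "[onnx]"  # assumed default, for illustration
    if device == "cpu":
        url = "https://download.pytorch.org/whl/cpu"
    if device == "cuda":
        url = f"https://download.pytorch.org/whl/{cuda_index}"
        optional_modules = "[xformers,onnx-cuda]"
    if device == "cuda_and_dml":
        url = f"https://download.pytorch.org/whl/{cuda_index}"
        optional_modules = "[xformers,onnx-directml]"
    # in all other cases, torch wheels come straight from PyPI
    return url, optional_modules

print(pick_torch_source("cuda"))           # newer side of the diff (cu121)
print(pick_torch_source("cuda", "cu118"))  # older side of the diff (cu118)
```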
@@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
     path_completer = PathCompleter(
         only_directories=True,
         expanduser=True,
-        get_paths=lambda: [browse_start],  # noqa: B023
+        get_paths=lambda: [browse_start],
         # get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
     )

@@ -149,7 +149,7 @@ def dest_path(dest=None) -> Path:
         completer=path_completer,
         default=str(browse_start) + os.sep,
         vi_mode=True,
-        complete_while_typing=True,
+        complete_while_typing=True
         # Test that this is not needed on Windows
         # complete_style=CompleteStyle.READLINE_LIKE,
     )

@@ -24,7 +24,6 @@ from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
 from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
 from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
 from ..services.model_manager.model_manager_default import ModelManagerService
-from ..services.model_records import ModelRecordServiceSQL
 from ..services.names.names_default import SimpleNameService
 from ..services.session_processor.session_processor_default import DefaultSessionProcessor
 from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -86,7 +85,6 @@ class ApiDependencies:
         invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
         model_manager = ModelManagerService(config, logger)
-        model_record_service = ModelRecordServiceSQL(db=db)
         names = SimpleNameService()
         performance_statistics = InvocationStatsService()
         processor = DefaultInvocationProcessor()
@@ -113,7 +111,6 @@ class ApiDependencies:
             latents=latents,
             logger=logger,
             model_manager=model_manager,
-            model_records=model_record_service,
             names=names,
             performance_statistics=performance_statistics,
             processor=processor,

@@ -28,7 +28,7 @@ class FastAPIEventService(EventServiceBase):
         self.__queue.put(None)

     def dispatch(self, event_name: str, payload: Any) -> None:
-        self.__queue.put({"event_name": event_name, "payload": payload})
+        self.__queue.put(dict(event_name=event_name, payload=payload))

     async def __dispatch_from_queue(self, stop_event: threading.Event):
         """Get events on from the queue and dispatch them, from the correct thread"""
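
The `dispatch()` change above is only a dict-construction style change; as a self-contained sketch of the underlying producer/queue pattern it relies on, with invented class and method names (only the payload shape follows the diff):

```python
# Illustrative sketch: events are put on a thread-safe queue by dispatch()
# and drained elsewhere. TinyEventBus is made up for this example.
import queue
from typing import Any, List

class TinyEventBus:
    def __init__(self) -> None:
        self._queue: "queue.Queue[Any]" = queue.Queue()

    def dispatch(self, event_name: str, payload: Any) -> None:
        # Both spellings in the diff build the same dict:
        # {"event_name": ..., "payload": ...}  vs  dict(event_name=..., payload=...)
        self._queue.put({"event_name": event_name, "payload": payload})

    def drain(self) -> List[Any]:
        items: List[Any] = []
        while not self._queue.empty():
            items.append(self._queue.get())
        return items

bus = TinyEventBus()
bus.dispatch("invocation_started", {"graph_execution_state_id": "abc"})
print(bus.drain())
```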
@@ -366,7 +366,7 @@ class ImagesDownloaded(BaseModel):
     )


-@images_router.post("/export", operation_id="download_images_from_list", response_model=ImagesDownloaded)
+@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded)
 async def download_images_from_list(
     image_names: list[str] = Body(description="The list of names of images to download", embed=True),
     board_id: Optional[str] = Body(

@@ -1,164 +0,0 @@
-# Copyright (c) 2023 Lincoln D. Stein
-"""FastAPI route for model configuration records."""
-
-
-from hashlib import sha1
-from random import randbytes
-from typing import List, Optional
-
-from fastapi import Body, Path, Query, Response
-from fastapi.routing import APIRouter
-from pydantic import BaseModel, ConfigDict
-from starlette.exceptions import HTTPException
-from typing_extensions import Annotated
-
-from invokeai.app.services.model_records import (
-    DuplicateModelException,
-    InvalidModelException,
-    UnknownModelException,
-)
-from invokeai.backend.model_manager.config import (
-    AnyModelConfig,
-    BaseModelType,
-    ModelType,
-)
-
-from ..dependencies import ApiDependencies
-
-model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
-
-
-class ModelsList(BaseModel):
-    """Return list of configs."""
-
-    models: list[AnyModelConfig]
-
-    model_config = ConfigDict(use_enum_values=True)
-
-
-@model_records_router.get(
-    "/",
-    operation_id="list_model_records",
-)
-async def list_model_records(
-    base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
-    model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
-) -> ModelsList:
-    """Get a list of models."""
-    record_store = ApiDependencies.invoker.services.model_records
-    found_models: list[AnyModelConfig] = []
-    if base_models:
-        for base_model in base_models:
-            found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
-    else:
-        found_models.extend(record_store.search_by_attr(model_type=model_type))
-    return ModelsList(models=found_models)
-
-
-@model_records_router.get(
-    "/i/{key}",
-    operation_id="get_model_record",
-    responses={
-        200: {"description": "Success"},
-        400: {"description": "Bad request"},
-        404: {"description": "The model could not be found"},
-    },
-)
-async def get_model_record(
-    key: str = Path(description="Key of the model record to fetch."),
-) -> AnyModelConfig:
-    """Get a model record"""
-    record_store = ApiDependencies.invoker.services.model_records
-    try:
-        return record_store.get_model(key)
-    except UnknownModelException as e:
-        raise HTTPException(status_code=404, detail=str(e))
-
-
-@model_records_router.patch(
-    "/i/{key}",
-    operation_id="update_model_record",
-    responses={
-        200: {"description": "The model was updated successfully"},
-        400: {"description": "Bad request"},
-        404: {"description": "The model could not be found"},
-        409: {"description": "There is already a model corresponding to the new name"},
-    },
-    status_code=200,
-    response_model=AnyModelConfig,
-)
-async def update_model_record(
-    key: Annotated[str, Path(description="Unique key of model")],
-    info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
-) -> AnyModelConfig:
-    """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
-    logger = ApiDependencies.invoker.services.logger
-    record_store = ApiDependencies.invoker.services.model_records
-    try:
-        model_response = record_store.update_model(key, config=info)
-        logger.info(f"Updated model: {key}")
-    except UnknownModelException as e:
-        raise HTTPException(status_code=404, detail=str(e))
-    except ValueError as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=409, detail=str(e))
-    return model_response
-
-
-@model_records_router.delete(
-    "/i/{key}",
-    operation_id="del_model_record",
-    responses={
-        204: {"description": "Model deleted successfully"},
-        404: {"description": "Model not found"},
-    },
-    status_code=204,
-)
-async def del_model_record(
-    key: str = Path(description="Unique key of model to remove from model registry."),
-) -> Response:
-    """Delete Model"""
-    logger = ApiDependencies.invoker.services.logger
-
-    try:
-        record_store = ApiDependencies.invoker.services.model_records
-        record_store.del_model(key)
-        logger.info(f"Deleted model: {key}")
-        return Response(status_code=204)
-    except UnknownModelException as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=404, detail=str(e))
-
-
-@model_records_router.post(
-    "/i/",
-    operation_id="add_model_record",
-    responses={
-        201: {"description": "The model added successfully"},
-        409: {"description": "There is already a model corresponding to this path or repo_id"},
-        415: {"description": "Unrecognized file/folder format"},
-    },
-    status_code=201,
-)
-async def add_model_record(
-    config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
-) -> AnyModelConfig:
-    """
-    Add a model using the configuration information appropriate for its type.
-    """
-    logger = ApiDependencies.invoker.services.logger
-    record_store = ApiDependencies.invoker.services.model_records
-    if config.key == "<NOKEY>":
-        config.key = sha1(randbytes(100)).hexdigest()
-        logger.info(f"Created model {config.key} for {config.name}")
-    try:
-        record_store.add_model(config.key, config)
-    except DuplicateModelException as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=409, detail=str(e))
-    except InvalidModelException as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=415)
-
-    # now fetch it out
-    return record_store.get_model(config.key)
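
The removed module above defines a small REST surface that, per the api_app.py hunk later in this diff, was mounted under `/api/v1/model/record`. A hypothetical client call against a locally running instance could look like the sketch below; the host, port, and timeout are assumptions rather than values taken from the repo, and the response field names follow the `ModelsList` and `AnyModelConfig` models shown above:

```python
# Hypothetical client sketch for the routes defined in the removed router.
import requests

BASE = "http://localhost:9090/api/v1/model/record"  # assumed local address

# GET /  ->  {"models": [...]}
models = requests.get(f"{BASE}/", timeout=30).json()["models"]
print(f"{len(models)} model records")

# GET /i/{key}  ->  a single model config
if models:
    key = models[0]["key"]
    record = requests.get(f"{BASE}/i/{key}", timeout=30).json()
    print(record.get("name"), record.get("type"))
```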
@@ -1,5 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein

+
 import pathlib
 from typing import Annotated, List, Literal, Optional, Union

@@ -54,7 +55,7 @@ async def list_models(
 ) -> ModelsList:
     """Gets a list of models"""
     if base_models and len(base_models) > 0:
-        models_raw = []
+        models_raw = list()
         for base_model in base_models:
             models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
     else:

@@ -34,4 +34,4 @@ class SocketIO:

     async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
         if "queue_id" in data:
-            await self.__sio.leave_room(sid, data["queue_id"])
+            await self.__sio.enter_room(sid, data["queue_id"])

@@ -43,7 +43,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
         board_images,
         boards,
         images,
-        model_records,
         models,
         session_queue,
         sessions,
@@ -107,7 +106,6 @@ app.include_router(sessions.session_router, prefix="/api")

 app.include_router(utilities.utilities_router, prefix="/api")
 app.include_router(models.models_router, prefix="/api")
-app.include_router(model_records.model_records_router, prefix="/api")
 app.include_router(images.images_router, prefix="/api")
 app.include_router(boards.boards_router, prefix="/api")
 app.include_router(board_images.board_images_router, prefix="/api")
@@ -132,7 +130,7 @@ def custom_openapi() -> dict[str, Any]:
     # Add all outputs
     all_invocations = BaseInvocation.get_invocations()
     output_types = set()
-    output_type_titles = {}
+    output_type_titles = dict()
     for invoker in all_invocations:
         output_type = signature(invoker.invoke).return_annotation
         output_types.add(output_type)
@@ -173,12 +171,12 @@ def custom_openapi() -> dict[str, Any]:
             # print(f"Config with name {name} already defined")
             continue

-        openapi_schema["components"]["schemas"][name] = {
-            "title": name,
-            "description": "An enumeration.",
-            "type": "string",
-            "enum": [v.value for v in model_config_format_enum],
-        }
+        openapi_schema["components"]["schemas"][name] = dict(
+            title=name,
+            description="An enumeration.",
+            type="string",
+            enum=list(v.value for v in model_config_format_enum),
+        )

     app.openapi_schema = openapi_schema
     return app.openapi_schema

@@ -25,4 +25,4 @@ spec.loader.exec_module(module)

 # add core nodes to __all__
 python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
-__all__ = [f.stem for f in python_files]  # type: ignore
+__all__ = list(f.stem for f in python_files)  # type: ignore

@@ -16,7 +16,6 @@ from pydantic.fields import FieldInfo, _Unset
 from pydantic_core import PydanticUndefined

 from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.app.util.misc import uuid_string

 if TYPE_CHECKING:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FieldDescriptions:
|
||||||
|
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||||
|
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||||
|
cfg_scale = "Classifier-Free Guidance scale"
|
||||||
|
scheduler = "Scheduler to use during inference"
|
||||||
|
positive_cond = "Positive conditioning tensor"
|
||||||
|
negative_cond = "Negative conditioning tensor"
|
||||||
|
noise = "Noise tensor"
|
||||||
|
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||||
|
unet = "UNet (scheduler, LoRAs)"
|
||||||
|
vae = "VAE"
|
||||||
|
cond = "Conditioning tensor"
|
||||||
|
controlnet_model = "ControlNet model to load"
|
||||||
|
vae_model = "VAE model to load"
|
||||||
|
lora_model = "LoRA model to load"
|
||||||
|
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||||
|
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||||
|
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||||
|
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||||
|
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||||
|
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||||
|
raw_prompt = "Raw prompt text (no parsing)"
|
||||||
|
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||||
|
skipped_layers = "Number of layers to skip in text encoder"
|
||||||
|
seed = "Seed for random number generation"
|
||||||
|
steps = "Number of steps to run"
|
||||||
|
width = "Width of output (px)"
|
||||||
|
height = "Height of output (px)"
|
||||||
|
control = "ControlNet(s) to apply"
|
||||||
|
ip_adapter = "IP-Adapter to apply"
|
||||||
|
t2i_adapter = "T2I-Adapter(s) to apply"
|
||||||
|
denoised_latents = "Denoised latents tensor"
|
||||||
|
latents = "Latents tensor"
|
||||||
|
strength = "Strength of denoising (proportional to steps)"
|
||||||
|
metadata = "Optional metadata to be saved with the image"
|
||||||
|
metadata_collection = "Collection of Metadata"
|
||||||
|
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
||||||
|
metadata_item_label = "Label for this metadata item"
|
||||||
|
metadata_item_value = "The value for this metadata item (may be any type)"
|
||||||
|
workflow = "Optional workflow to be saved with the image"
|
||||||
|
interp_mode = "Interpolation mode"
|
||||||
|
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||||
|
fp32 = "Whether or not to use full float32 precision"
|
||||||
|
precision = "Precision to use"
|
||||||
|
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||||
|
detect_res = "Pixel resolution for detection"
|
||||||
|
image_res = "Pixel resolution for output image"
|
||||||
|
safe_mode = "Whether or not to use safe mode"
|
||||||
|
scribble_mode = "Whether or not to use scribble mode"
|
||||||
|
scale_factor = "The factor by which to scale"
|
||||||
|
blend_alpha = (
|
||||||
|
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
||||||
|
)
|
||||||
|
num_1 = "The first number"
|
||||||
|
num_2 = "The second number"
|
||||||
|
mask = "The mask to use for the operation"
|
||||||
|
board = "The board to save the image to"
|
||||||
|
image = "The image to process"
|
||||||
|
tile_size = "Tile size"
|
||||||
|
inclusive_low = "The inclusive low value"
|
||||||
|
exclusive_high = "The exclusive high value"
|
||||||
|
decimal_places = "The number of decimal places to round to"
|
||||||
|
|
||||||
|
|
||||||
class Input(str, Enum):
|
class Input(str, Enum):
|
||||||
"""
|
"""
|
||||||
The type of input a field accepts.
|
The type of input a field accepts.
|
||||||
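
Since the hunk above introduces `FieldDescriptions` as a plain namespace of description strings, here is a minimal sketch of how such constants are typically consumed (the settings model below is invented for illustration; only the two description strings come from the diff):

```python
# Illustrative sketch: description constants become pydantic field descriptions.
from pydantic import BaseModel, Field

class FieldDescriptions:
    seed = "Seed for random number generation"
    steps = "Number of steps to run"

class ExampleNoiseSettings(BaseModel):  # hypothetical model, not an InvokeAI invocation
    seed: int = Field(default=0, description=FieldDescriptions.seed)
    steps: int = Field(default=30, description=FieldDescriptions.steps)

print(ExampleNoiseSettings.model_fields["seed"].description)
```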
@@ -236,35 +299,35 @@ def InputField(
         Ignored for non-collection fields.
     """

-    json_schema_extra_: dict[str, Any] = {
-        "input": input,
-        "ui_type": ui_type,
-        "ui_component": ui_component,
-        "ui_hidden": ui_hidden,
-        "ui_order": ui_order,
-        "item_default": item_default,
-        "ui_choice_labels": ui_choice_labels,
-        "_field_kind": "input",
-    }
+    json_schema_extra_: dict[str, Any] = dict(
+        input=input,
+        ui_type=ui_type,
+        ui_component=ui_component,
+        ui_hidden=ui_hidden,
+        ui_order=ui_order,
+        item_default=item_default,
+        ui_choice_labels=ui_choice_labels,
+        _field_kind="input",
+    )

-    field_args = {
-        "default": default,
-        "default_factory": default_factory,
-        "title": title,
-        "description": description,
-        "pattern": pattern,
-        "strict": strict,
-        "gt": gt,
-        "ge": ge,
-        "lt": lt,
-        "le": le,
-        "multiple_of": multiple_of,
-        "allow_inf_nan": allow_inf_nan,
-        "max_digits": max_digits,
-        "decimal_places": decimal_places,
-        "min_length": min_length,
-        "max_length": max_length,
-    }
+    field_args = dict(
+        default=default,
+        default_factory=default_factory,
+        title=title,
+        description=description,
+        pattern=pattern,
+        strict=strict,
+        gt=gt,
+        ge=ge,
+        lt=lt,
+        le=le,
+        multiple_of=multiple_of,
+        allow_inf_nan=allow_inf_nan,
+        max_digits=max_digits,
+        decimal_places=decimal_places,
+        min_length=min_length,
+        max_length=max_length,
+    )

     """
     Invocation definitions have their fields typed correctly for their `invoke()` functions.
@@ -299,24 +362,24 @@ def InputField(

     # because we are manually making fields optional, we need to store the original required bool for reference later
     if default is PydanticUndefined and default_factory is PydanticUndefined:
-        json_schema_extra_.update({"orig_required": True})
+        json_schema_extra_.update(dict(orig_required=True))
     else:
-        json_schema_extra_.update({"orig_required": False})
+        json_schema_extra_.update(dict(orig_required=False))

     # make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
     if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
         default_ = None if default is PydanticUndefined else default
-        provided_args.update({"default": default_})
+        provided_args.update(dict(default=default_))
         if default is not PydanticUndefined:
             # before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
-            json_schema_extra_.update({"default": default})
-            json_schema_extra_.update({"orig_default": default})
+            json_schema_extra_.update(dict(default=default))
+            json_schema_extra_.update(dict(orig_default=default))
     elif default is not PydanticUndefined and default_factory is PydanticUndefined:
         default_ = default
-        provided_args.update({"default": default_})
-        json_schema_extra_.update({"orig_default": default_})
+        provided_args.update(dict(default=default_))
+        json_schema_extra_.update(dict(orig_default=default_))
     elif default_factory is not PydanticUndefined:
-        provided_args.update({"default_factory": default_factory})
+        provided_args.update(dict(default_factory=default_factory))
         # TODO: cannot serialize default_factory...
         # json_schema_extra_.update(dict(orig_default_factory=default_factory))

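
The bulk of the two hunks above swaps dict literals for `dict()` calls; a quick check, not taken from the repo, that the two spellings build identical dictionaries at runtime:

```python
# Not repo code: confirm the literal and dict() spellings are interchangeable.
literal_style = {"orig_required": True, "_field_kind": "input"}
call_style = dict(orig_required=True, _field_kind="input")
assert literal_style == call_style
print(literal_style == call_style)  # True
```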
@@ -383,12 +446,12 @@ def OutputField(
         decimal_places=decimal_places,
         min_length=min_length,
         max_length=max_length,
-        json_schema_extra={
-            "ui_type": ui_type,
-            "ui_hidden": ui_hidden,
-            "ui_order": ui_order,
-            "_field_kind": "output",
-        },
+        json_schema_extra=dict(
+            ui_type=ui_type,
+            ui_hidden=ui_hidden,
+            ui_order=ui_order,
+            _field_kind="output",
+        ),
     )


@@ -460,14 +523,14 @@ class BaseInvocationOutput(BaseModel):

     @classmethod
     def get_output_types(cls) -> Iterable[str]:
-        return (get_type(i) for i in BaseInvocationOutput.get_outputs())
+        return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())

     @staticmethod
     def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
         # Because we use a pydantic Literal field with default value for the invocation type,
         # it will be typed as optional in the OpenAPI schema. Make it required manually.
         if "required" not in schema or not isinstance(schema["required"], list):
-            schema["required"] = []
+            schema["required"] = list()
         schema["required"].extend(["type"])

     model_config = ConfigDict(
@@ -527,11 +590,16 @@ class BaseInvocation(ABC, BaseModel):
     @classmethod
     def get_invocations_map(cls) -> dict[str, BaseInvocation]:
         # Get the type strings out of the literals and into a dictionary
-        return {get_type(i): i for i in BaseInvocation.get_invocations()}
+        return dict(
+            map(
+                lambda i: (get_type(i), i),
+                BaseInvocation.get_invocations(),
+            )
+        )

     @classmethod
     def get_invocation_types(cls) -> Iterable[str]:
-        return (get_type(i) for i in BaseInvocation.get_invocations())
+        return map(lambda i: get_type(i), BaseInvocation.get_invocations())

     @classmethod
     def get_output_type(cls) -> BaseInvocationOutput:
@@ -550,7 +618,7 @@ class BaseInvocation(ABC, BaseModel):
         if uiconfig and hasattr(uiconfig, "version"):
             schema["version"] = uiconfig.version
         if "required" not in schema or not isinstance(schema["required"], list):
-            schema["required"] = []
+            schema["required"] = list()
         schema["required"].extend(["type", "id"])

     @abstractmethod
@@ -604,15 +672,15 @@ class BaseInvocation(ABC, BaseModel):
     id: str = Field(
         default_factory=uuid_string,
         description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
-        json_schema_extra={"_field_kind": "internal"},
+        json_schema_extra=dict(_field_kind="internal"),
     )
     is_intermediate: bool = Field(
         default=False,
         description="Whether or not this is an intermediate invocation.",
-        json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
+        json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
     )
     use_cache: bool = Field(
-        default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
+        default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
     )

     UIConfig: ClassVar[Type[UIConfigBase]]
|
|||||||
|
|
||||||
|
|
||||||
# Get all pydantic model attrs, methods, etc
|
# Get all pydantic model attrs, methods, etc
|
||||||
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
|
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
|
||||||
|
|
||||||
|
|
||||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||||
@ -661,7 +729,9 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
|||||||
|
|
||||||
field_kind = (
|
field_kind = (
|
||||||
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
||||||
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
|
field.json_schema_extra.get("_field_kind", None)
|
||||||
|
if field.json_schema_extra
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# must have a field_kind
|
# must have a field_kind
|
||||||
@ -722,7 +792,7 @@ def invocation(
|
|||||||
# Add OpenAPI schema extras
|
# Add OpenAPI schema extras
|
||||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
|
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||||
if title is not None:
|
if title is not None:
|
||||||
cls.UIConfig.title = title
|
cls.UIConfig.title = title
|
||||||
if tags is not None:
|
if tags is not None:
|
||||||
@ -749,7 +819,7 @@ def invocation(
|
|||||||
|
|
||||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||||
invocation_type_field = Field(
|
invocation_type_field = Field(
|
||||||
title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
|
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
|
||||||
)
|
)
|
||||||
|
|
||||||
docstring = cls.__doc__
|
docstring = cls.__doc__
|
||||||
@ -795,7 +865,7 @@ def invocation_output(
|
|||||||
# Add the output type to the model.
|
# Add the output type to the model.
|
||||||
|
|
||||||
output_type_annotation = Literal[output_type] # type: ignore
|
output_type_annotation = Literal[output_type] # type: ignore
|
||||||
output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
|
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
|
||||||
|
|
||||||
docstring = cls.__doc__
|
docstring = cls.__doc__
|
||||||
cls = create_model(
|
cls = create_model(
|
||||||
@ -827,7 +897,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
|||||||
|
|
||||||
class WithWorkflow(BaseModel):
|
class WithWorkflow(BaseModel):
|
||||||
workflow: Optional[WorkflowField] = Field(
|
workflow: Optional[WorkflowField] = Field(
|
||||||
default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
|
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -845,5 +915,5 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
|
|||||||
|
|
||||||
class WithMetadata(BaseModel):
|
class WithMetadata(BaseModel):
|
||||||
metadata: Optional[MetadataField] = Field(
|
metadata: Optional[MetadataField] = Field(
|
||||||
default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
|
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
|
||||||
)
|
)
|
||||||
|
@ -7,7 +7,6 @@ from compel import Compel, ReturnedEmbeddingsType
|
|||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
ExtraConditioningInfo,
|
ExtraConditioningInfo,
|
||||||
@ -20,6 +19,7 @@ from ...backend.util.devices import torch_dtype
|
|||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
@ -112,11 +112,10 @@ class CompelInvocation(BaseInvocation):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
),
|
),
|
||||||
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
||||||
text_encoder_info as text_encoder,
|
text_encoder_info as text_encoder,
|
||||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
|
||||||
):
|
):
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -235,11 +234,10 @@ class SDXLPromptInvocationBase:
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
),
|
),
|
||||||
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
||||||
text_encoder_info as text_encoder,
|
text_encoder_info as text_encoder,
|
||||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
|
||||||
):
|
):
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
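
The two hunks above only reorder the context managers inside a single `with` statement; a tiny sketch (toy stand-ins, not InvokeAI's ModelPatcher helpers) of why that ordering is observable:

```python
# Context managers listed in one `with` statement are entered left to right
# and exited in reverse order, so moving apply_clip_skip before or after the
# LoRA patcher changes when each patch is applied and unwound.
from contextlib import contextmanager

@contextmanager
def step(name: str):
    print("enter", name)
    try:
        yield name
    finally:
        print("exit", name)

# Entering "lora" before "clip_skip" mirrors one side of the diff; swapping
# the two names mirrors the other.
with step("lora"), step("clip_skip"):
    pass
# Output: enter lora, enter clip_skip, exit clip_skip, exit lora
```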
|
@ -28,12 +28,12 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType
|
from ...backend.model_management import BaseModelType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
@@ -131,7 +131,7 @@ def prepare_faces_list(
     deduped_faces: list[FaceResultData] = []
 
     if len(face_result_list) == 0:
-        return []
+        return list()
 
     for candidate in face_result_list:
         should_add = True
@@ -210,7 +210,7 @@ def generate_face_box_mask(
         # Check if any face is detected.
         if results.multi_face_landmarks:  # type: ignore # this are via protobuf and not typed
             # Search for the face_id in the detected faces.
-            for _face_id, face_landmarks in enumerate(results.multi_face_landmarks):  # type: ignore #this are via protobuf and not typed
+            for face_id, face_landmarks in enumerate(results.multi_face_landmarks):  # type: ignore #this are via protobuf and not typed
                 # Get the bounding box of the face mesh.
                 x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
                 y_coordinates = [landmark.y for landmark in face_landmarks.landmark]
@@ -9,11 +9,19 @@ from PIL import Image, ImageChops, ImageFilter, ImageOps
 
 from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
-from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
 from invokeai.backend.image_util.safety_checker import SafetyChecker
 
-from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
+from .baseinvocation import (
+    BaseInvocation,
+    FieldDescriptions,
+    Input,
+    InputField,
+    InvocationContext,
+    WithMetadata,
+    WithWorkflow,
+    invocation,
+)
 
 
 @invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
@@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    FieldDescriptions,
     Input,
     InputField,
     InvocationContext,
@@ -16,7 +17,6 @@ from invokeai.app.invocations.baseinvocation import (
     invocation_output,
 )
 from invokeai.app.invocations.primitives import ImageField
-from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.backend.model_management.models.base import BaseModelType, ModelType
 from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
 
@@ -10,7 +10,7 @@ import torch
 import torchvision.transforms as T
 from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.adapter import T2IAdapter
+from diffusers.models.adapter import FullAdapterXL, T2IAdapter
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -34,7 +34,6 @@ from invokeai.app.invocations.primitives import (
 )
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
-from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
@@ -58,6 +57,7 @@ from ...backend.util.devices import choose_precision, choose_torch_device
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    FieldDescriptions,
     Input,
     InputField,
     InvocationContext,
@@ -77,7 +77,7 @@ if choose_torch_device() == torch.device("mps"):
 
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
-SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
+SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
 
 
 @invocation_output("scheduler_output")
@@ -562,6 +562,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         t2i_adapter_model: T2IAdapter
         with t2i_adapter_model_info as t2i_adapter_model:
             total_downscale_factor = t2i_adapter_model.total_downscale_factor
+            if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
+                # HACK(ryand): Work around a bug in FullAdapterXL. This is being addressed upstream in diffusers by
+                # this PR: https://github.com/huggingface/diffusers/pull/5134.
+                total_downscale_factor = total_downscale_factor // 2
 
             # Resize the T2I-Adapter input image.
             # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
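Note: the `+` side above halves `total_downscale_factor` to work around a FullAdapterXL bug. As a rough, hypothetical sketch (the helper name and rounding policy are illustrative, not taken from this codebase) of how such a factor typically drives the control-image resize:

```python
def resize_dims_for_adapter(height: int, width: int, total_downscale_factor: int) -> tuple[int, int]:
    """Snap an image size up to the nearest multiple of the adapter's downscale factor."""
    down_h = -(-height // total_downscale_factor)  # ceiling division
    down_w = -(-width // total_downscale_factor)
    return down_h * total_downscale_factor, down_w * total_downscale_factor


# Example: with a factor of 16, a 513x769 request becomes 528x784;
# if the factor were halved (as in the workaround), 520x776 would suffice.
print(resize_dims_for_adapter(513, 769, 16))
print(resize_dims_for_adapter(513, 769, 8))
```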
@@ -706,8 +710,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
         )
         with (
             ExitStack() as exit_stack,
-            ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
-            ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
             set_seamless(unet_info.context.model, self.unet.seamless_axes),
             unet_info as unet,
             # Apply the LoRA after unet has been moved to its target device for faster patching.
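Note: the hunk above drops two patch context managers from a parenthesized `with (...)` stack. A self-contained sketch of the pattern (generic context managers, not InvokeAI's API), showing that teardown runs in reverse order and that `ExitStack` lets more managers be pushed mid-block:

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def patch(name: str):
    print(f"apply {name}")
    try:
        yield name
    finally:
        print(f"undo {name}")


with ExitStack() as stack, patch("lora") as lora, patch("seamless") as seamless:
    stack.enter_context(patch("freeu"))  # pushed dynamically, undone when the stack exits
    print("denoise with", lora, "and", seamless)
# teardown order here: seamless, then lora, then the ExitStack unwinds freeu
```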
@@ -1105,7 +1107,7 @@ class BlendLatentsInvocation(BaseInvocation):
         latents_b = context.services.latents.get(self.latents_b.latents_name)
 
         if latents_a.shape != latents_b.shape:
-            raise Exception("Latents to blend must be the same size.")
+            raise "Latents to blend must be the same size."
 
         # TODO:
         device = choose_torch_device()
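Note: the `+` side above raises a bare string; in Python 3 that itself fails with `TypeError: exceptions must derive from BaseException`, masking the real problem. A minimal sketch of the working form (array shapes are made up for illustration):

```python
import numpy as np


def blend_latents(a: np.ndarray, b: np.ndarray, alpha: float) -> np.ndarray:
    if a.shape != b.shape:
        # Raise an Exception subclass; `raise "some string"` is a runtime TypeError.
        raise Exception("Latents to blend must be the same size.")
    return (1 - alpha) * a + alpha * b


print(blend_latents(np.zeros((1, 4, 8, 8)), np.ones((1, 4, 8, 8)), 0.25).mean())  # 0.25
```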
@@ -6,9 +6,8 @@ import numpy as np
 from pydantic import ValidationInfo, field_validator
 
 from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
-from invokeai.app.shared.fields import FieldDescriptions
 
-from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
+from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
 
 
 @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
@@ -145,17 +144,17 @@ INTEGER_OPERATIONS = Literal[
 ]
 
 
-INTEGER_OPERATIONS_LABELS = {
-    "ADD": "Add A+B",
-    "SUB": "Subtract A-B",
-    "MUL": "Multiply A*B",
-    "DIV": "Divide A/B",
-    "EXP": "Exponentiate A^B",
-    "MOD": "Modulus A%B",
-    "ABS": "Absolute Value of A",
-    "MIN": "Minimum(A,B)",
-    "MAX": "Maximum(A,B)",
-}
+INTEGER_OPERATIONS_LABELS = dict(
+    ADD="Add A+B",
+    SUB="Subtract A-B",
+    MUL="Multiply A*B",
+    DIV="Divide A/B",
+    EXP="Exponentiate A^B",
+    MOD="Modulus A%B",
+    ABS="Absolute Value of A",
+    MIN="Minimum(A,B)",
+    MAX="Maximum(A,B)",
+)
 
 
 @invocation(
@@ -183,8 +182,8 @@ class IntegerMathInvocation(BaseInvocation):
     operation: INTEGER_OPERATIONS = InputField(
         default="ADD", description="The operation to perform", ui_choice_labels=INTEGER_OPERATIONS_LABELS
     )
-    a: int = InputField(default=1, description=FieldDescriptions.num_1)
-    b: int = InputField(default=1, description=FieldDescriptions.num_2)
+    a: int = InputField(default=0, description=FieldDescriptions.num_1)
+    b: int = InputField(default=0, description=FieldDescriptions.num_2)
 
     @field_validator("b")
     def no_unrepresentable_results(cls, v: int, info: ValidationInfo):
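Note: this pair of hunks trades dict literals for `dict(...)` keyword calls (and tweaks the InputField defaults). The two dict spellings build the same mapping; the keyword form only works while every key is a valid identifier. A tiny sketch:

```python
labels_literal = {"ADD": "Add A+B", "SUB": "Subtract A-B"}
labels_call = dict(ADD="Add A+B", SUB="Subtract A-B")
assert labels_literal == labels_call

# A key that is not an identifier forces the literal form:
categories = {"Memory/Performance": 1}  # dict(Memory/Performance=1) would be a SyntaxError
print(labels_call["ADD"], categories)
```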
@@ -231,17 +230,17 @@ FLOAT_OPERATIONS = Literal[
 ]
 
 
-FLOAT_OPERATIONS_LABELS = {
-    "ADD": "Add A+B",
-    "SUB": "Subtract A-B",
-    "MUL": "Multiply A*B",
-    "DIV": "Divide A/B",
-    "EXP": "Exponentiate A^B",
-    "ABS": "Absolute Value of A",
-    "SQRT": "Square Root of A",
-    "MIN": "Minimum(A,B)",
-    "MAX": "Maximum(A,B)",
-}
+FLOAT_OPERATIONS_LABELS = dict(
+    ADD="Add A+B",
+    SUB="Subtract A-B",
+    MUL="Multiply A*B",
+    DIV="Divide A/B",
+    EXP="Exponentiate A^B",
+    ABS="Absolute Value of A",
+    SQRT="Square Root of A",
+    MIN="Minimum(A,B)",
+    MAX="Maximum(A,B)",
+)
 
 
 @invocation(
@@ -257,8 +256,8 @@ class FloatMathInvocation(BaseInvocation):
     operation: FLOAT_OPERATIONS = InputField(
         default="ADD", description="The operation to perform", ui_choice_labels=FLOAT_OPERATIONS_LABELS
     )
-    a: float = InputField(default=1, description=FieldDescriptions.num_1)
-    b: float = InputField(default=1, description=FieldDescriptions.num_2)
+    a: float = InputField(default=0, description=FieldDescriptions.num_1)
+    b: float = InputField(default=0, description=FieldDescriptions.num_2)
 
     @field_validator("b")
     def no_unrepresentable_results(cls, v: float, info: ValidationInfo):
@@ -266,7 +265,7 @@ class FloatMathInvocation(BaseInvocation):
             raise ValueError("Cannot divide by zero")
         elif info.data["operation"] == "EXP" and info.data["a"] == 0 and v < 0:
             raise ValueError("Cannot raise zero to a negative power")
-        elif info.data["operation"] == "EXP" and isinstance(info.data["a"] ** v, complex):
+        elif info.data["operation"] == "EXP" and type(info.data["a"] ** v) is complex:
             raise ValueError("Root operation resulted in a complex number")
         return v
 
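Note: the hunk above switches between `isinstance(x, complex)` and `type(x) is complex` for detecting a complex result of exponentiation; both catch the case where a negative base meets a fractional exponent. Sketch:

```python
def checked_pow(a: float, b: float) -> float:
    result = a**b
    # In Python 3, a negative float base with a fractional exponent yields a complex number.
    if isinstance(result, complex):
        raise ValueError("Root operation resulted in a complex number")
    return result


print(checked_pow(9.0, 0.5))  # 3.0
# checked_pow(-9.0, 0.5) would raise, since (-9.0) ** 0.5 is complex
```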
@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    FieldDescriptions,
     InputField,
     InvocationContext,
     MetadataField,
@@ -18,7 +19,6 @@ from invokeai.app.invocations.ip_adapter import IPAdapterModelField
 from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
-from invokeai.app.shared.fields import FieldDescriptions
 
 from ...version import __version__
 
@@ -160,14 +160,13 @@ class CoreMetadataInvocation(BaseInvocation):
     )
 
     # High resolution fix metadata.
-    hrf_enabled: Optional[float] = InputField(
+    hrf_width: Optional[int] = InputField(
         default=None,
-        description="Whether or not high resolution fix was enabled.",
+        description="The high resolution fix height and width multipler.",
     )
-    # TODO: should this be stricter or do we just let the UI handle it?
-    hrf_method: Optional[str] = InputField(
+    hrf_height: Optional[int] = InputField(
         default=None,
-        description="The high resolution fix upscale method.",
+        description="The high resolution fix height and width multipler.",
     )
     hrf_strength: Optional[float] = InputField(
         default=None,
@@ -3,13 +3,11 @@ from typing import List, Optional
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from invokeai.app.shared.fields import FieldDescriptions
-from invokeai.app.shared.models import FreeUConfig
-
 from ...backend.model_management import BaseModelType, ModelType, SubModelType
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    FieldDescriptions,
     Input,
     InputField,
     InvocationContext,
@@ -38,7 +36,6 @@ class UNetField(BaseModel):
     scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
     loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
-    freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
 
 
 class ClipField(BaseModel):
@@ -54,32 +51,13 @@ class VaeField(BaseModel):
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
 
 
-@invocation_output("unet_output")
-class UNetOutput(BaseInvocationOutput):
-    """Base class for invocations that output a UNet field"""
-
-    unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
-
-
-@invocation_output("vae_output")
-class VAEOutput(BaseInvocationOutput):
-    """Base class for invocations that output a VAE field"""
-
-    vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
-
-
-@invocation_output("clip_output")
-class CLIPOutput(BaseInvocationOutput):
-    """Base class for invocations that output a CLIP field"""
-
-    clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
-
-
 @invocation_output("model_loader_output")
-class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
+class ModelLoaderOutput(BaseInvocationOutput):
     """Model loader output"""
 
-    pass
+    unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
+    clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+    vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
 
 
 class MainModelField(BaseModel):
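Note: the hunk above folds three single-field output classes (previously combined through multiple inheritance) back into one flat output class. A generic sketch of the two shapes with plain dataclasses (the real classes are pydantic models; names here are illustrative):

```python
from dataclasses import dataclass


@dataclass
class UNetOut:
    unet: str


@dataclass
class ClipOut:
    clip: str


@dataclass
class ComposedLoaderOut(UNetOut, ClipOut):
    """Reuses the per-field outputs; other nodes can return just UNetOut or ClipOut."""


@dataclass
class FlatLoaderOut:
    """Same fields declared directly on one class."""
    unet: str
    clip: str


print(ComposedLoaderOut(unet="unet-info", clip="clip-info"))
print(FlatLoaderOut(unet="unet-info", clip="clip-info"))
```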
@@ -388,6 +366,13 @@ class VAEModelField(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
 
+@invocation_output("vae_loader_output")
+class VaeLoaderOutput(BaseInvocationOutput):
+    """VAE output"""
+
+    vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
+
+
 @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
 class VaeLoaderInvocation(BaseInvocation):
     """Loads a VAE model, outputting a VaeLoaderOutput"""
@@ -399,7 +384,7 @@ class VaeLoaderInvocation(BaseInvocation):
         title="VAE",
     )
 
-    def invoke(self, context: InvocationContext) -> VAEOutput:
+    def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
         base_model = self.vae_model.base_model
         model_name = self.vae_model.model_name
         model_type = ModelType.Vae
@@ -410,7 +395,7 @@ class VaeLoaderInvocation(BaseInvocation):
             model_type=model_type,
         ):
             raise Exception(f"Unkown vae name: {model_name}!")
-        return VAEOutput(
+        return VaeLoaderOutput(
             vae=VaeField(
                 vae=ModelInfo(
                     model_name=model_name,
@@ -472,24 +457,3 @@ class SeamlessModeInvocation(BaseInvocation):
         vae.seamless_axes = seamless_axes_list
 
         return SeamlessModeOutput(unet=unet, vae=vae)
-
-
-@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.0")
-class FreeUInvocation(BaseInvocation):
-    """
-    Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):
-
-    SD1.5: 1.2/1.4/0.9/0.2,
-    SD2: 1.1/1.2/0.9/0.2,
-    SDXL: 1.1/1.2/0.6/0.4,
-    """
-
-    unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
-    b1: float = InputField(default=1.2, ge=-1, le=3, description=FieldDescriptions.freeu_b1)
-    b2: float = InputField(default=1.4, ge=-1, le=3, description=FieldDescriptions.freeu_b2)
-    s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
-    s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
-
-    def invoke(self, context: InvocationContext) -> UNetOutput:
-        self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
-        return UNetOutput(unet=self.unet)
@@ -5,13 +5,13 @@ import torch
 from pydantic import field_validator
 
 from invokeai.app.invocations.latent import LatentsField
-from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    FieldDescriptions,
     InputField,
     InvocationContext,
     OutputField,
@@ -14,7 +14,6 @@ from tqdm import tqdm
 
 from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
-from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend import BaseModelType, ModelType, SubModelType
 
@@ -24,6 +23,7 @@ from ...backend.util import choose_torch_device
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    FieldDescriptions,
     Input,
     InputField,
     InvocationContext,
@@ -54,7 +54,7 @@ ORT_TO_NP_TYPE = {
     "tensor(double)": np.float64,
 }
 
-PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
+PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
 
 
 @invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
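Note: both sides above derive a `Literal` type from a dict's keys at import time; the `list(...)` wrapper on the `+` side is redundant because `tuple()` accepts any iterable. A runtime sketch (static type checkers generally reject dynamically built Literals, but the value works for validation):

```python
from typing import Literal, get_args

ORT_TO_NP_TYPE = {"tensor(float16)": "float16", "tensor(float)": "float32"}

PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]

print(get_args(PRECISION_VALUES))  # ('tensor(float16)', 'tensor(float)')
```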
@@ -252,7 +252,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         scheduler.set_timesteps(self.steps)
         latents = latents * np.float64(scheduler.init_noise_sigma)
 
-        extra_step_kwargs = {}
+        extra_step_kwargs = dict()
         if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
             extra_step_kwargs.update(
                 eta=0.0,
@@ -100,7 +100,7 @@ EASING_FUNCTIONS_MAP = {
     "BounceInOut": BounceEaseInOut,
 }
 
-EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
+EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
 
 
 # actually I think for now could just use CollectionOutput (which is list[Any]
@@ -161,7 +161,7 @@ class StepParamEasingInvocation(BaseInvocation):
         easing_class = EASING_FUNCTIONS_MAP[self.easing]
         if log_diagnostics:
             context.services.logger.debug("easing class: " + str(easing_class))
-        easing_list = []
+        easing_list = list()
         if self.mirror:  # "expected" mirroring
             # if number of steps is even, squeeze duration down to (number_of_steps)/2
             # and create reverse copy of list to append
@@ -178,7 +178,7 @@ class StepParamEasingInvocation(BaseInvocation):
                 end=self.end_value,
                 duration=base_easing_duration - 1,
             )
-            base_easing_vals = []
+            base_easing_vals = list()
             for step_index in range(base_easing_duration):
                 easing_val = easing_function.ease(step_index)
                 base_easing_vals.append(easing_val)
@@ -5,11 +5,10 @@ from typing import Optional, Tuple
 import torch
 from pydantic import BaseModel, Field
 
-from invokeai.app.shared.fields import FieldDescriptions
-
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    FieldDescriptions,
     Input,
     InputField,
     InvocationContext,
@@ -1,9 +1,8 @@
-from invokeai.app.shared.fields import FieldDescriptions
-
 from ...backend.model_management import ModelType, SubModelType
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    FieldDescriptions,
     Input,
     InputField,
     InvocationContext,
@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    FieldDescriptions,
     Input,
     InputField,
     InvocationContext,
@@ -15,7 +16,6 @@ from invokeai.app.invocations.baseinvocation import (
 )
 from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
 from invokeai.app.invocations.primitives import ImageField
-from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.backend.model_management.models.base import BaseModelType
 
 
@@ -139,7 +139,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
                 (board_id,),
             )
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            images = [deserialize_image_record(dict(r)) for r in result]
+            images = list(map(lambda r: deserialize_image_record(dict(r)), result))
 
             self._cursor.execute(
                 """--sql
@@ -167,7 +167,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
                 (board_id,),
             )
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            image_names = [r[0] for r in result]
+            image_names = list(map(lambda r: r[0], result))
             return image_names
         except sqlite3.Error as e:
             self._conn.rollback()
@@ -199,7 +199,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
             )
 
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            boards = [deserialize_board_record(dict(r)) for r in result]
+            boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
 
             # Get the total number of boards
             self._cursor.execute(
@@ -236,7 +236,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
             )
 
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            boards = [deserialize_board_record(dict(r)) for r in result]
+            boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
 
             return boards
 
@@ -55,7 +55,7 @@ class InvokeAISettings(BaseSettings):
         """
        cls = self.__class__
        type = get_args(get_type_hints(cls)["type"])[0]
-        field_dict = {type: {}}
+        field_dict = dict({type: dict()})
        for name, field in self.model_fields.items():
            if name in cls._excluded_from_yaml():
                continue
@@ -64,7 +64,7 @@ class InvokeAISettings(BaseSettings):
            )
            value = getattr(self, name)
            if category not in field_dict[type]:
-                field_dict[type][category] = {}
+                field_dict[type][category] = dict()
            # keep paths as strings to make it easier to read
            field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
        conf = OmegaConf.create(field_dict)
@@ -89,7 +89,7 @@ class InvokeAISettings(BaseSettings):
        # create an upcase version of the environment in
        # order to achieve case-insensitive environment
        # variables (the way Windows does)
-        upcase_environ = {}
+        upcase_environ = dict()
        for key, value in os.environ.items():
            upcase_environ[key.upper()] = value
 
@@ -188,18 +188,18 @@ DEFAULT_MAX_VRAM = 0.5
 
 
 class Categories(object):
-    WebServer = {"category": "Web Server"}
-    Features = {"category": "Features"}
-    Paths = {"category": "Paths"}
-    Logging = {"category": "Logging"}
-    Development = {"category": "Development"}
-    Other = {"category": "Other"}
-    ModelCache = {"category": "Model Cache"}
-    Device = {"category": "Device"}
-    Generation = {"category": "Generation"}
-    Queue = {"category": "Queue"}
-    Nodes = {"category": "Nodes"}
-    MemoryPerformance = {"category": "Memory/Performance"}
+    WebServer = dict(category="Web Server")
+    Features = dict(category="Features")
+    Paths = dict(category="Paths")
+    Logging = dict(category="Logging")
+    Development = dict(category="Development")
+    Other = dict(category="Other")
+    ModelCache = dict(category="Model Cache")
+    Device = dict(category="Device")
+    Generation = dict(category="Generation")
+    Queue = dict(category="Queue")
+    Nodes = dict(category="Nodes")
+    MemoryPerformance = dict(category="Memory/Performance")
 
 
 class InvokeAIAppConfig(InvokeAISettings):
@@ -482,7 +482,7 @@ def _find_root() -> Path:
     venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
     if os.environ.get("INVOKEAI_ROOT"):
         root = Path(os.environ["INVOKEAI_ROOT"])
-    elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
+    elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
         root = (venv.parent).resolve()
     else:
         root = Path("~/invokeai").expanduser().resolve()
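Note: the hunk above toggles `any()` between a generator expression and a list comprehension. Both give the same answer; the generator short-circuits on the first hit instead of checking every candidate first. Sketch:

```python
from pathlib import Path

candidates = ["invokeai.yaml", "invokeai.init"]
parent = Path(".")

found_lazy = any((parent / name).exists() for name in candidates)      # stops at first match
found_eager = any([(parent / name).exists() for name in candidates])   # builds the full list first

assert found_lazy == found_eager
print(found_lazy)
```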
@@ -27,7 +27,7 @@ class EventServiceBase:
         payload["timestamp"] = get_timestamp()
         self.dispatch(
             event_name=EventServiceBase.queue_event,
-            payload={"event": event_name, "data": payload},
+            payload=dict(event=event_name, data=payload),
         )
 
     # Define events here for every event in the system.
@ -48,18 +48,18 @@ class EventServiceBase:
|
|||||||
"""Emitted when there is generation progress"""
|
"""Emitted when there is generation progress"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="generator_progress",
|
event_name="generator_progress",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
"node_id": node.get("id"),
|
node_id=node.get("id"),
|
||||||
"source_node_id": source_node_id,
|
source_node_id=source_node_id,
|
||||||
"progress_image": progress_image.model_dump() if progress_image is not None else None,
|
progress_image=progress_image.model_dump() if progress_image is not None else None,
|
||||||
"step": step,
|
step=step,
|
||||||
"order": order,
|
order=order,
|
||||||
"total_steps": total_steps,
|
total_steps=total_steps,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_complete(
|
def emit_invocation_complete(
|
||||||
@ -75,15 +75,15 @@ class EventServiceBase:
|
|||||||
"""Emitted when an invocation has completed"""
|
"""Emitted when an invocation has completed"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_complete",
|
event_name="invocation_complete",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
"node": node,
|
node=node,
|
||||||
"source_node_id": source_node_id,
|
source_node_id=source_node_id,
|
||||||
"result": result,
|
result=result,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_error(
|
def emit_invocation_error(
|
||||||
@ -100,16 +100,16 @@ class EventServiceBase:
|
|||||||
"""Emitted when an invocation has completed"""
|
"""Emitted when an invocation has completed"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_error",
|
event_name="invocation_error",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
"node": node,
|
node=node,
|
||||||
"source_node_id": source_node_id,
|
source_node_id=source_node_id,
|
||||||
"error_type": error_type,
|
error_type=error_type,
|
||||||
"error": error,
|
error=error,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_started(
|
def emit_invocation_started(
|
||||||
@ -124,14 +124,14 @@ class EventServiceBase:
|
|||||||
"""Emitted when an invocation has started"""
|
"""Emitted when an invocation has started"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_started",
|
event_name="invocation_started",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
"node": node,
|
node=node,
|
||||||
"source_node_id": source_node_id,
|
source_node_id=source_node_id,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_graph_execution_complete(
|
def emit_graph_execution_complete(
|
||||||
@ -140,12 +140,12 @@ class EventServiceBase:
|
|||||||
"""Emitted when a session has completed all invocations"""
|
"""Emitted when a session has completed all invocations"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="graph_execution_state_complete",
|
event_name="graph_execution_state_complete",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_model_load_started(
|
def emit_model_load_started(
|
||||||
@ -162,16 +162,16 @@ class EventServiceBase:
|
|||||||
"""Emitted when a model is requested"""
|
"""Emitted when a model is requested"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="model_load_started",
|
event_name="model_load_started",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
"model_name": model_name,
|
model_name=model_name,
|
||||||
"base_model": base_model,
|
base_model=base_model,
|
||||||
"model_type": model_type,
|
model_type=model_type,
|
||||||
"submodel": submodel,
|
submodel=submodel,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_model_load_completed(
|
def emit_model_load_completed(
|
||||||
@ -189,19 +189,19 @@ class EventServiceBase:
|
|||||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="model_load_completed",
|
event_name="model_load_completed",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
"model_name": model_name,
|
model_name=model_name,
|
||||||
"base_model": base_model,
|
base_model=base_model,
|
||||||
"model_type": model_type,
|
model_type=model_type,
|
||||||
"submodel": submodel,
|
submodel=submodel,
|
||||||
"hash": model_info.hash,
|
hash=model_info.hash,
|
||||||
"location": str(model_info.location),
|
location=str(model_info.location),
|
||||||
"precision": str(model_info.precision),
|
precision=str(model_info.precision),
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_session_retrieval_error(
|
def emit_session_retrieval_error(
|
||||||
@ -216,14 +216,14 @@ class EventServiceBase:
|
|||||||
"""Emitted when session retrieval fails"""
|
"""Emitted when session retrieval fails"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="session_retrieval_error",
|
event_name="session_retrieval_error",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
"error_type": error_type,
|
error_type=error_type,
|
||||||
"error": error,
|
error=error,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_retrieval_error(
|
def emit_invocation_retrieval_error(
|
||||||
@ -239,15 +239,15 @@ class EventServiceBase:
|
|||||||
"""Emitted when invocation retrieval fails"""
|
"""Emitted when invocation retrieval fails"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_retrieval_error",
|
event_name="invocation_retrieval_error",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
"node_id": node_id,
|
node_id=node_id,
|
||||||
"error_type": error_type,
|
error_type=error_type,
|
||||||
"error": error,
|
error=error,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_session_canceled(
|
def emit_session_canceled(
|
||||||
@ -260,12 +260,12 @@ class EventServiceBase:
|
|||||||
"""Emitted when a session is canceled"""
|
"""Emitted when a session is canceled"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="session_canceled",
|
event_name="session_canceled",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_id,
|
queue_id=queue_id,
|
||||||
"queue_item_id": queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_queue_item_status_changed(
|
def emit_queue_item_status_changed(
|
||||||
@ -277,39 +277,39 @@ class EventServiceBase:
|
|||||||
"""Emitted when a queue item's status changes"""
|
"""Emitted when a queue item's status changes"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="queue_item_status_changed",
|
event_name="queue_item_status_changed",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": queue_status.queue_id,
|
queue_id=queue_status.queue_id,
|
||||||
"queue_item": {
|
queue_item=dict(
|
||||||
"queue_id": session_queue_item.queue_id,
|
queue_id=session_queue_item.queue_id,
|
||||||
"item_id": session_queue_item.item_id,
|
item_id=session_queue_item.item_id,
|
||||||
"status": session_queue_item.status,
|
status=session_queue_item.status,
|
||||||
"batch_id": session_queue_item.batch_id,
|
batch_id=session_queue_item.batch_id,
|
||||||
"session_id": session_queue_item.session_id,
|
session_id=session_queue_item.session_id,
|
||||||
"error": session_queue_item.error,
|
error=session_queue_item.error,
|
||||||
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||||
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||||
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||||
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||||
},
|
),
|
||||||
"batch_status": batch_status.model_dump(),
|
batch_status=batch_status.model_dump(),
|
||||||
"queue_status": queue_status.model_dump(),
|
queue_status=queue_status.model_dump(),
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||||
"""Emitted when a batch is enqueued"""
|
"""Emitted when a batch is enqueued"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="batch_enqueued",
|
event_name="batch_enqueued",
|
||||||
payload={
|
payload=dict(
|
||||||
"queue_id": enqueue_result.queue_id,
|
queue_id=enqueue_result.queue_id,
|
||||||
"batch_id": enqueue_result.batch.batch_id,
|
batch_id=enqueue_result.batch.batch_id,
|
||||||
"enqueued": enqueue_result.enqueued,
|
enqueued=enqueue_result.enqueued,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||||
"""Emitted when the queue is cleared"""
|
"""Emitted when the queue is cleared"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="queue_cleared",
|
event_name="queue_cleared",
|
||||||
payload={"queue_id": queue_id},
|
payload=dict(queue_id=queue_id),
|
||||||
)
|
)
|
||||||
|
@@ -25,7 +25,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
     __invoker: Invoker
 
     def __init__(self, output_folder: Union[str, Path]):
-        self.__cache = {}
+        self.__cache = dict()
         self.__cache_ids = Queue()
         self.__max_cache_size = 10  # TODO: get this from config
 
@@ -90,9 +90,10 @@ class ImageRecordDeleteException(Exception):
 
 
 IMAGE_DTO_COLS = ", ".join(
+    list(
+        map(
+            lambda c: "images." + c,
     [
-        "images." + c
-        for c in [
         "image_name",
         "image_origin",
         "image_category",
@@ -105,8 +106,9 @@ IMAGE_DTO_COLS = ", ".join(
         "updated_at",
         "deleted_at",
         "starred",
-    ]
-]
+    ],
+        )
+    )
 )
 
 
@@ -263,7 +263,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
 
             if categories is not None:
                 # Convert the enum values to unique list of strings
-                category_strings = [c.value for c in set(categories)]
+                category_strings = list(map(lambda c: c.value, set(categories)))
                 # Create the correct length of placeholders
                 placeholders = ",".join("?" * len(category_strings))
 
@@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
             # Build the list of images, deserializing each row
             self._cursor.execute(images_query, images_params)
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            images = [deserialize_image_record(dict(r)) for r in result]
+            images = list(map(lambda r: deserialize_image_record(dict(r)), result))
 
             # Set up and execute the count query, without pagination
             count_query += query_conditions + ";"
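Note: the surrounding hunks trade list comprehensions for `list(map(lambda ...))` over sqlite rows. The forms are equivalent here; the comprehension just skips one lambda call per row. Sketch with made-up rows:

```python
rows = [("img_001.png",), ("img_002.png",)]  # stand-in for sqlite3 fetchall() results

names_comprehension = [r[0] for r in rows]
names_map = list(map(lambda r: r[0], rows))

assert names_comprehension == names_map == ["img_001.png", "img_002.png"]
```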
@@ -386,7 +386,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
                 """
             )
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            image_names = [r[0] for r in result]
+            image_names = list(map(lambda r: r[0], result))
             self._cursor.execute(
                 """--sql
                 DELETE FROM images
@@ -21,8 +21,8 @@ class ImageServiceABC(ABC):
     _on_deleted_callbacks: list[Callable[[str], None]]
 
     def __init__(self) -> None:
-        self._on_changed_callbacks = []
-        self._on_deleted_callbacks = []
+        self._on_changed_callbacks = list()
+        self._on_deleted_callbacks = list()
 
     def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
         """Register a callback for when an image is changed"""
@@ -217,16 +217,18 @@ class ImageService(ImageServiceABC):
                 board_id,
             )
 
-            image_dtos = [
-                image_record_to_dto(
+            image_dtos = list(
+                map(
+                    lambda r: image_record_to_dto(
                     image_record=r,
                     image_url=self.__invoker.services.urls.get_image_url(r.image_name),
                     thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
                     board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
                     workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
+                    ),
+                    results.items,
+                )
             )
-                for r in results.items
-            ]
 
             return OffsetPaginatedResults[ImageDTO](
                 items=image_dtos,
@@ -1,5 +1,5 @@
 from abc import ABC
 
 
-class InvocationProcessorABC(ABC):  # noqa: B024
+class InvocationProcessorABC(ABC):
     pass
@@ -26,7 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
        self.__invoker_thread = Thread(
            name="invoker_processor",
            target=self.__process,
-            kwargs={"stop_event": self.__stop_event},
+            kwargs=dict(stop_event=self.__stop_event),
        )
        self.__invoker_thread.daemon = True  # TODO: make async and do not use threads
        self.__invoker_thread.start()
@@ -14,7 +14,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
 
     def __init__(self):
         self.__queue = Queue()
-        self.__cancellations = {}
+        self.__cancellations = dict()
 
     def get(self) -> InvocationQueueItem:
         item = self.__queue.get()
@@ -22,7 +22,6 @@ if TYPE_CHECKING:
     from .item_storage.item_storage_base import ItemStorageABC
     from .latents_storage.latents_storage_base import LatentsStorageBase
     from .model_manager.model_manager_base import ModelManagerServiceBase
-    from .model_records import ModelRecordServiceBase
     from .names.names_base import NameServiceBase
     from .session_processor.session_processor_base import SessionProcessorBase
     from .session_queue.session_queue_base import SessionQueueBase
@@ -50,7 +49,6 @@ class InvocationServices:
     latents: "LatentsStorageBase"
     logger: "Logger"
     model_manager: "ModelManagerServiceBase"
-    model_records: "ModelRecordServiceBase"
     processor: "InvocationProcessorABC"
     performance_statistics: "InvocationStatsServiceBase"
     queue: "InvocationQueueABC"
@@ -78,7 +76,6 @@ class InvocationServices:
        latents: "LatentsStorageBase",
        logger: "Logger",
        model_manager: "ModelManagerServiceBase",
-        model_records: "ModelRecordServiceBase",
        processor: "InvocationProcessorABC",
        performance_statistics: "InvocationStatsServiceBase",
        queue: "InvocationQueueABC",
@@ -104,7 +101,6 @@ class InvocationServices:
         self.latents = latents
         self.logger = logger
         self.model_manager = model_manager
-        self.model_records = model_records
         self.processor = processor
         self.performance_statistics = performance_statistics
         self.queue = queue
@@ -122,7 +122,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
     def log_stats(self):
         completed = set()
         errored = set()
-        for graph_id, _node_log in self._stats.items():
+        for graph_id, node_log in self._stats.items():
             try:
                 current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
             except Exception:
@@ -142,7 +142,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
             cache_stats = self._cache_stats[graph_id]
             hwm = cache_stats.high_watermark / GIG
             tot = cache_stats.cache_size / GIG
-            loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
+            loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG

             logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
             logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
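Both spellings of the `loaded` calculation above sum the same values; the wrapping list is just an intermediate. A small sketch with made-up byte counts, assuming `GIG` is a binary gigabyte:

```python
GIG = 2**30  # assumed: bytes per GiB

loaded_model_sizes = {"unet": 3 * GIG, "vae": GIG // 2}  # illustrative sizes
assert sum(list(loaded_model_sizes.values())) == sum(v for v in loaded_model_sizes.values())
print(f"{sum(loaded_model_sizes.values()) / GIG:.2f}G loaded")  # 3.50G loaded
```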
@@ -15,8 +15,8 @@ class ItemStorageABC(ABC, Generic[T]):
     _on_deleted_callbacks: list[Callable[[str], None]]

     def __init__(self) -> None:
-        self._on_changed_callbacks = []
-        self._on_deleted_callbacks = []
+        self._on_changed_callbacks = list()
+        self._on_deleted_callbacks = list()

     """Base item storage class"""

@@ -112,7 +112,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
             )
             result = self._cursor.fetchall()

-            items = [self._parse_item(r[0]) for r in result]
+            items = list(map(lambda r: self._parse_item(r[0]), result))

             self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
             count = self._cursor.fetchone()[0]
@@ -132,7 +132,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
             )
             result = self._cursor.fetchall()

-            items = [self._parse_item(r[0]) for r in result]
+            items = list(map(lambda r: self._parse_item(r[0]), result))

             self._cursor.execute(
                 f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
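The two `items = ...` forms above produce identical lists; the comprehension simply avoids the lambda. A self-contained sketch with dummy `fetchall()`-style rows (illustrative data only):

```python
import json

rows = [('{"id": "a"}',), ('{"id": "b"}',)]  # each row is a 1-tuple of serialized JSON

def parse_item(raw: str) -> dict:
    return json.loads(raw)

assert [parse_item(r[0]) for r in rows] == list(map(lambda r: parse_item(r[0]), rows))
```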
@@ -13,8 +13,8 @@ class LatentsStorageBase(ABC):
     _on_deleted_callbacks: list[Callable[[str], None]]

     def __init__(self) -> None:
-        self._on_changed_callbacks = []
-        self._on_deleted_callbacks = []
+        self._on_changed_callbacks = list()
+        self._on_deleted_callbacks = list()

     @abstractmethod
     def get(self, name: str) -> torch.Tensor:
@@ -19,7 +19,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
     def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
         super().__init__()
         self.__underlying_storage = underlying_storage
-        self.__cache = {}
+        self.__cache = dict()
         self.__cache_ids = Queue()
         self.__max_cache_size = max_cache_size

@@ -1,8 +0,0 @@
-"""Init file for model record services."""
-from .model_records_base import (  # noqa F401
-    DuplicateModelException,
-    InvalidModelException,
-    ModelRecordServiceBase,
-    UnknownModelException,
-)
-from .model_records_sql import ModelRecordServiceSQL  # noqa F401
@@ -1,169 +0,0 @@
-# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
-"""
-Abstract base class for storing and retrieving model configuration records.
-"""
-
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import List, Optional, Union
-
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
-
-# should match the InvokeAI version when this is first released.
-CONFIG_FILE_VERSION = "3.2.0"
-
-
-class DuplicateModelException(Exception):
-    """Raised on an attempt to add a model with the same key twice."""
-
-
-class InvalidModelException(Exception):
-    """Raised when an invalid model is detected."""
-
-
-class UnknownModelException(Exception):
-    """Raised on an attempt to fetch or delete a model with a nonexistent key."""
-
-
-class ConfigFileVersionMismatchException(Exception):
-    """Raised on an attempt to open a config with an incompatible version."""
-
-
-class ModelRecordServiceBase(ABC):
-    """Abstract base class for storage and retrieval of model configs."""
-
-    @property
-    @abstractmethod
-    def version(self) -> str:
-        """Return the config file/database schema version."""
-        pass
-
-    @abstractmethod
-    def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
-        """
-        Add a model to the database.
-
-        :param key: Unique key for the model
-        :param config: Model configuration record, either a dict with the
-         required fields or a ModelConfigBase instance.
-
-        Can raise DuplicateModelException and InvalidModelConfigException exceptions.
-        """
-        pass
-
-    @abstractmethod
-    def del_model(self, key: str) -> None:
-        """
-        Delete a model.
-
-        :param key: Unique key for the model to be deleted
-
-        Can raise an UnknownModelException
-        """
-        pass
-
-    @abstractmethod
-    def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
-        """
-        Update the model, returning the updated version.
-
-        :param key: Unique key for the model to be updated
-        :param config: Model configuration record. Either a dict with the
-         required fields, or a ModelConfigBase instance.
-        """
-        pass
-
-    @abstractmethod
-    def get_model(self, key: str) -> AnyModelConfig:
-        """
-        Retrieve the configuration for the indicated model.
-
-        :param key: Key of model config to be fetched.
-
-        Exceptions: UnknownModelException
-        """
-        pass
-
-    @abstractmethod
-    def exists(self, key: str) -> bool:
-        """
-        Return True if a model with the indicated key exists in the databse.
-
-        :param key: Unique key for the model to be deleted
-        """
-        pass
-
-    @abstractmethod
-    def search_by_path(
-        self,
-        path: Union[str, Path],
-    ) -> List[AnyModelConfig]:
-        """Return the model(s) having the indicated path."""
-        pass
-
-    @abstractmethod
-    def search_by_hash(
-        self,
-        hash: str,
-    ) -> List[AnyModelConfig]:
-        """Return the model(s) having the indicated original hash."""
-        pass
-
-    @abstractmethod
-    def search_by_attr(
-        self,
-        model_name: Optional[str] = None,
-        base_model: Optional[BaseModelType] = None,
-        model_type: Optional[ModelType] = None,
-    ) -> List[AnyModelConfig]:
-        """
-        Return models matching name, base and/or type.
-
-        :param model_name: Filter by name of model (optional)
-        :param base_model: Filter by base model (optional)
-        :param model_type: Filter by type of model (optional)
-
-        If none of the optional filters are passed, will return all
-        models in the database.
-        """
-        pass
-
-    def all_models(self) -> List[AnyModelConfig]:
-        """Return all the model configs in the database."""
-        return self.search_by_attr()
-
-    def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
-        """
-        Return information about a single model using its name, base type and model type.
-
-        If there are more than one model that match, raises a DuplicateModelException.
-        If no model matches, raises an UnknownModelException
-        """
-        model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
-        if len(model_configs) > 1:
-            raise DuplicateModelException(
-                f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
-            )
-        if len(model_configs) == 0:
-            raise UnknownModelException(
-                f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
-            )
-        return model_configs[0]
-
-    def rename_model(
-        self,
-        key: str,
-        new_name: str,
-    ) -> AnyModelConfig:
-        """
-        Rename the indicated model. Just a special case of update_model().
-
-        In some implementations, renaming the model may involve changing where
-        it is stored on the filesystem. So this is broken out.
-
-        :param key: Model key
-        :param new_name: New name for model
-        """
-        config = self.get_model(key)
-        config.name = new_name
-        return self.update_model(key, config)
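A hedged sketch of how a concrete store implementing the abstract interface above would be exercised, using only the methods the base class documents (`store` is assumed to be any `ModelRecordServiceBase` implementation):

```python
def describe(store, key: str) -> None:
    """Sketch only: look up a record via the documented exists()/get_model() primitives."""
    if not store.exists(key):
        raise RuntimeError(f"no model with key {key!r}")  # stand-in for UnknownModelException
    config = store.get_model(key)
    print(config.name, config.base, config.type)

def rename(store, key: str, new_name: str):
    # rename_model() is documented as a thin wrapper over update_model().
    return store.rename_model(key, new_name)
```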
@@ -1,397 +0,0 @@
-# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
-"""
-SQL Implementation of the ModelRecordServiceBase API
-
-Typical usage:
-
-  from invokeai.backend.model_manager import ModelConfigStoreSQL
-  store = ModelConfigStoreSQL(sqlite_db)
-  config = dict(
-        path='/tmp/pokemon.bin',
-        name='old name',
-        base_model='sd-1',
-        type='embedding',
-        format='embedding_file',
-     )
-
-   # adding - the key becomes the model's "key" field
-   store.add_model('key1', config)
-
-   # updating
-   config.name='new name'
-   store.update_model('key1', config)
-
-   # checking for existence
-   if store.exists('key1'):
-      print("yes")
-
-   # fetching config
-   new_config = store.get_model('key1')
-   print(new_config.name, new_config.base)
-   assert new_config.key == 'key1'
-
-   # deleting
-   store.del_model('key1')
-
-   # searching
-   configs = store.search_by_path(path='/tmp/pokemon.bin')
-   configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
-   configs = store.search_by_attr(base_model='sd-2', model_type='main')
-"""
-
-
-import json
-import sqlite3
-from pathlib import Path
-from typing import List, Optional, Union
-
-from invokeai.backend.model_manager.config import (
-    AnyModelConfig,
-    BaseModelType,
-    ModelConfigBase,
-    ModelConfigFactory,
-    ModelType,
-)
-
-from ..shared.sqlite import SqliteDatabase
-from .model_records_base import (
-    CONFIG_FILE_VERSION,
-    DuplicateModelException,
-    ModelRecordServiceBase,
-    UnknownModelException,
-)
-
-
-class ModelRecordServiceSQL(ModelRecordServiceBase):
-    """Implementation of the ModelConfigStore ABC using a SQL database."""
-
-    _db: SqliteDatabase
-    _cursor: sqlite3.Cursor
-
-    def __init__(self, db: SqliteDatabase):
-        """
-        Initialize a new object from preexisting sqlite3 connection and threading lock objects.
-
-        :param conn: sqlite3 connection object
-        :param lock: threading Lock object
-        """
-        super().__init__()
-        self._db = db
-        self._cursor = self._db.conn.cursor()
-
-        with self._db.lock:
-            # Enable foreign keys
-            self._db.conn.execute("PRAGMA foreign_keys = ON;")
-            self._create_tables()
-            self._db.conn.commit()
-            assert (
-                str(self.version) == CONFIG_FILE_VERSION
-            ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
-
-    def _create_tables(self) -> None:
-        """Create sqlite3 tables."""
-        #   model_config table breaks out the fields that are common to all config objects
-        #   and puts class-specific ones in a serialized json object
-        self._cursor.execute(
-            """--sql
-            CREATE TABLE IF NOT EXISTS model_config (
-                id TEXT NOT NULL PRIMARY KEY,
-                -- The next 3 fields are enums in python, unrestricted string here
-                base TEXT NOT NULL,
-                type TEXT NOT NULL,
-                name TEXT NOT NULL,
-                path TEXT NOT NULL,
-                original_hash TEXT, -- could be null
-                -- Serialized JSON representation of the whole config object,
-                -- which will contain additional fields from subclasses
-                config TEXT NOT NULL,
-                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Updated via trigger
-                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- unique constraint on combo of name, base and type
-                UNIQUE(name, base, type)
-            );
-            """
-        )
-
-        #  metadata table
-        self._cursor.execute(
-            """--sql
-            CREATE TABLE IF NOT EXISTS model_manager_metadata (
-                metadata_key TEXT NOT NULL PRIMARY KEY,
-                metadata_value TEXT NOT NULL
-            );
-            """
-        )
-
-        # Add trigger for `updated_at`.
-        self._cursor.execute(
-            """--sql
-            CREATE TRIGGER IF NOT EXISTS model_config_updated_at
-            AFTER UPDATE
-            ON model_config FOR EACH ROW
-            BEGIN
-                UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE id = old.id;
-            END;
-            """
-        )
-
-        # Add indexes for searchable fields
-        for stmt in [
-            "CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
-            "CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
-            "CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
-            "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
-        ]:
-            self._cursor.execute(stmt)
-
-        # Add our version to the metadata table
-        self._cursor.execute(
-            """--sql
-            INSERT OR IGNORE into model_manager_metadata (
-               metadata_key,
-               metadata_value
-            )
-            VALUES (?,?);
-            """,
-            ("version", CONFIG_FILE_VERSION),
-        )
-
-    def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
-        """
-        Add a model to the database.
-
-        :param key: Unique key for the model
-        :param config: Model configuration record, either a dict with the
-         required fields or a ModelConfigBase instance.
-
-        Can raise DuplicateModelException and InvalidModelConfigException exceptions.
-        """
-        record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config obect.
-        json_serialized = record.model_dump_json()  # and turn it into a json string.
-        with self._db.lock:
-            try:
-                self._cursor.execute(
-                    """--sql
-                    INSERT INTO model_config (
-                       id,
-                       base,
-                       type,
-                       name,
-                       path,
-                       original_hash,
-                       config
-                      )
-                    VALUES (?,?,?,?,?,?,?);
-                    """,
-                    (
-                        key,
-                        record.base,
-                        record.type,
-                        record.name,
-                        record.path,
-                        record.original_hash,
-                        json_serialized,
-                    ),
-                )
-                self._db.conn.commit()
-
-            except sqlite3.IntegrityError as e:
-                self._db.conn.rollback()
-                if "UNIQUE constraint failed" in str(e):
-                    if "model_config.path" in str(e):
-                        msg = f"A model with path '{record.path}' is already installed"
-                    elif "model_config.name" in str(e):
-                        msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
-                    else:
-                        msg = f"A model with key '{key}' is already installed"
-                    raise DuplicateModelException(msg) from e
-                else:
-                    raise e
-            except sqlite3.Error as e:
-                self._db.conn.rollback()
-                raise e
-
-        return self.get_model(key)
-
-    @property
-    def version(self) -> str:
-        """Return the version of the database schema."""
-        with self._db.lock:
-            self._cursor.execute(
-                """--sql
-                SELECT metadata_value FROM model_manager_metadata
-                WHERE metadata_key=?;
-                """,
-                ("version",),
-            )
-            rows = self._cursor.fetchone()
-            if not rows:
-                raise KeyError("Models database does not have metadata key 'version'")
-            return rows[0]
-
-    def del_model(self, key: str) -> None:
-        """
-        Delete a model.
-
-        :param key: Unique key for the model to be deleted
-
-        Can raise an UnknownModelException
-        """
-        with self._db.lock:
-            try:
-                self._cursor.execute(
-                    """--sql
-                    DELETE FROM model_config
-                    WHERE id=?;
-                    """,
-                    (key,),
-                )
-                if self._cursor.rowcount == 0:
-                    raise UnknownModelException("model not found")
-                self._db.conn.commit()
-            except sqlite3.Error as e:
-                self._db.conn.rollback()
-                raise e
-
-    def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
-        """
-        Update the model, returning the updated version.
-
-        :param key: Unique key for the model to be updated
-        :param config: Model configuration record. Either a dict with the
-         required fields, or a ModelConfigBase instance.
-        """
-        record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config obect
-        json_serialized = record.model_dump_json()  # and turn it into a json string.
-        with self._db.lock:
-            try:
-                self._cursor.execute(
-                    """--sql
-                    UPDATE model_config
-                    SET base=?,
-                        type=?,
-                        name=?,
-                        path=?,
-                        config=?
-                    WHERE id=?;
-                    """,
-                    (record.base, record.type, record.name, record.path, json_serialized, key),
-                )
-                if self._cursor.rowcount == 0:
-                    raise UnknownModelException("model not found")
-                self._db.conn.commit()
-            except sqlite3.Error as e:
-                self._db.conn.rollback()
-                raise e
-
-        return self.get_model(key)
-
-    def get_model(self, key: str) -> AnyModelConfig:
-        """
-        Retrieve the ModelConfigBase instance for the indicated model.
-
-        :param key: Key of model config to be fetched.
-
-        Exceptions: UnknownModelException
-        """
-        with self._db.lock:
-            self._cursor.execute(
-                """--sql
-                SELECT config FROM model_config
-                WHERE id=?;
-                """,
-                (key,),
-            )
-            rows = self._cursor.fetchone()
-            if not rows:
-                raise UnknownModelException("model not found")
-            model = ModelConfigFactory.make_config(json.loads(rows[0]))
-        return model
-
-    def exists(self, key: str) -> bool:
-        """
-        Return True if a model with the indicated key exists in the databse.
-
-        :param key: Unique key for the model to be deleted
-        """
-        count = 0
-        with self._db.lock:
-            self._cursor.execute(
-                """--sql
-                select count(*) FROM model_config
-                WHERE id=?;
-                """,
-                (key,),
-            )
-            count = self._cursor.fetchone()[0]
-        return count > 0
-
-    def search_by_attr(
-        self,
-        model_name: Optional[str] = None,
-        base_model: Optional[BaseModelType] = None,
-        model_type: Optional[ModelType] = None,
-    ) -> List[AnyModelConfig]:
-        """
-        Return models matching name, base and/or type.
-
-        :param model_name: Filter by name of model (optional)
-        :param base_model: Filter by base model (optional)
-        :param model_type: Filter by type of model (optional)
-
-        If none of the optional filters are passed, will return all
-        models in the database.
-        """
-        results = []
-        where_clause = []
-        bindings = []
-        if model_name:
-            where_clause.append("name=?")
-            bindings.append(model_name)
-        if base_model:
-            where_clause.append("base=?")
-            bindings.append(base_model)
-        if model_type:
-            where_clause.append("type=?")
-            bindings.append(model_type)
-        where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
-        with self._db.lock:
-            self._cursor.execute(
-                f"""--sql
-                select config FROM model_config
-                {where};
-                """,
-                tuple(bindings),
-            )
-            results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
-        return results
-
-    def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
-        """Return models with the indicated path."""
-        results = []
-        with self._db.lock:
-            self._cursor.execute(
-                """--sql
-                SELECT config FROM model_config
-                WHERE model_path=?;
-                """,
-                (str(path),),
-            )
-            results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
-        return results
-
-    def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
-        """Return models with the indicated original_hash."""
-        results = []
-        with self._db.lock:
-            self._cursor.execute(
-                """--sql
-                SELECT config FROM model_config
-                WHERE original_hash=?;
-                """,
-                (hash,),
-            )
-            results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
-        return results
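Following the module docstring in the deleted file above, a minimal hedged usage sketch; the `store` handle is assumed to be an already-constructed `ModelRecordServiceSQL`:

```python
def install_example(store) -> None:
    """Sketch only, mirroring the 'Typical usage' docstring; not runnable without the app database."""
    config = dict(
        path="/tmp/pokemon.bin",
        name="old name",
        base_model="sd-1",
        type="embedding",
        format="embedding_file",
    )
    try:
        store.add_model("key1", config)  # the key becomes the record's "key" field
    except Exception:  # DuplicateModelException when the path/name/key is already installed
        pass
    assert store.exists("key1")
```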
@@ -33,11 +33,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
         self.__thread = Thread(
             name="session_processor",
             target=self.__process,
-            kwargs={
-                "stop_event": self.__stop_event,
-                "poll_now_event": self.__poll_now_event,
-                "resume_event": self.__resume_event,
-            },
+            kwargs=dict(
+                stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
+            ),
         )
         self.__thread.start()

@@ -129,12 +129,12 @@ class Batch(BaseModel):
         return v

     model_config = ConfigDict(
-        json_schema_extra={
-            "required": [
+        json_schema_extra=dict(
+            required=[
                 "graph",
                 "runs",
             ]
-        }
+        )
     )


@@ -191,8 +191,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
         return SessionQueueItemDTO(**queue_item_dict)

     model_config = ConfigDict(
-        json_schema_extra={
-            "required": [
+        json_schema_extra=dict(
+            required=[
                 "item_id",
                 "status",
                 "batch_id",
@@ -203,7 +203,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
                 "created_at",
                 "updated_at",
             ]
-        }
+        )
     )


@@ -222,8 +222,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
         return SessionQueueItem(**queue_item_dict)

     model_config = ConfigDict(
-        json_schema_extra={
-            "required": [
+        json_schema_extra=dict(
+            required=[
                 "item_id",
                 "status",
                 "batch_id",
@@ -235,7 +235,7 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
                 "created_at",
                 "updated_at",
             ]
-        }
+        )
     )


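Both spellings of `json_schema_extra` above hand Pydantic the same mapping. A short hedged sketch with a hypothetical model (not part of the app):

```python
from pydantic import BaseModel, ConfigDict

class ExampleItem(BaseModel):  # hypothetical
    item_id: int = 0
    status: str = "pending"

    model_config = ConfigDict(json_schema_extra={"required": ["item_id", "status"]})

# The extra dict is merged into the generated schema.
print(ExampleItem.model_json_schema().get("required"))
```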
@@ -355,7 +355,7 @@ def create_session_nfv_tuples(
                 for item in batch_datum.items
             ]
             node_field_values_to_zip.append(node_field_values)
-        data.append(list(zip(*node_field_values_to_zip, strict=True)))  # type: ignore [arg-type]
+        data.append(list(zip(*node_field_values_to_zip)))  # type: ignore [arg-type]

     # create generator to yield session,nfv tuples
     count = 0
@@ -383,7 +383,7 @@ def calc_session_count(batch: Batch) -> int:
         for batch_datum in batch_datum_list:
             batch_data_items = range(len(batch_datum.items))
             to_zip.append(batch_data_items)
-        data.append(list(zip(*to_zip, strict=True)))
+        data.append(list(zip(*to_zip)))
     data_product = list(product(*data))
     return len(data_product) * batch.runs

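`strict=True` (Python 3.10+) only adds a length check; with equal-length inputs the two `zip()` calls above yield the same pairs. A short sketch with illustrative batch data:

```python
prompts = ["prompt-1", "prompt-2"]
seeds = [111, 222]

assert list(zip(prompts, seeds, strict=True)) == list(zip(prompts, seeds))

try:
    list(zip(prompts, [111], strict=True))  # mismatched lengths now fail loudly
except ValueError as err:
    print(err)
```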
@@ -78,7 +78,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
     """Creates the default system graphs, or adds new versions if the old ones don't match"""

     # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
-    graphs: list[LibraryGraph] = []
+    graphs: list[LibraryGraph] = list()

     text_to_image = graph_library.get(default_text_to_image_graph_id)

@@ -352,7 +352,7 @@ class Graph(BaseModel):

         # Validate that all node ids are unique
         node_ids = [n.id for n in self.nodes.values()]
-        duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
+        duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
         if duplicate_node_ids:
             raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")

@@ -616,7 +616,7 @@ class Graph(BaseModel):
         self, node_path: str, prefix: Optional[str] = None
     ) -> list[tuple["Graph", Union[str, None], Edge]]:
         """Gets all input edges for a node along with the graph they are in and the graph's path"""
-        edges = []
+        edges = list()

         # Return any input edges that appear in this graph
         edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
@@ -658,7 +658,7 @@ class Graph(BaseModel):
         self, node_path: str, prefix: Optional[str] = None
     ) -> list[tuple["Graph", Union[str, None], Edge]]:
         """Gets all output edges for a node along with the graph they are in and the graph's path"""
-        edges = []
+        edges = list()

         # Return any input edges that appear in this graph
         edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
@@ -680,8 +680,8 @@ class Graph(BaseModel):
         new_input: Optional[EdgeConnection] = None,
         new_output: Optional[EdgeConnection] = None,
     ) -> bool:
-        inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
-        outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
+        inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
+        outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])

         if new_input is not None:
             inputs.append(new_input)
@@ -694,7 +694,7 @@ class Graph(BaseModel):

         # Get input and output fields (the fields linked to the iterator's input/output)
         input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
-        output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
+        output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])

         # Input type must be a list
         if get_origin(input_field) != list:
@@ -713,8 +713,8 @@ class Graph(BaseModel):
         new_input: Optional[EdgeConnection] = None,
         new_output: Optional[EdgeConnection] = None,
     ) -> bool:
-        inputs = [e.source for e in self._get_input_edges(node_path, "item")]
-        outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
+        inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
+        outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])

         if new_input is not None:
             inputs.append(new_input)
@@ -722,16 +722,18 @@ class Graph(BaseModel):
         outputs.append(new_output)

         # Get input and output fields (the fields linked to the iterator's input/output)
-        input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
-        output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
+        input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
+        output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])

         # Validate that all inputs are derived from or match a single type
-        input_field_types = {
+        input_field_types = set(
+            [
             t
             for input_field in input_fields
             for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
             if t != NoneType
-        }  # Get unique types
+            ]
+        )  # Get unique types
         type_tree = nx.DiGraph()
         type_tree.add_nodes_from(input_field_types)
         type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
@@ -759,15 +761,15 @@ class Graph(BaseModel):
         """Returns a NetworkX DiGraph representing the layout of this graph"""
         # TODO: Cache this?
         g = nx.DiGraph()
-        g.add_nodes_from(list(self.nodes.keys()))
-        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
+        g.add_nodes_from([n for n in self.nodes.keys()])
+        g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
         return g

     def nx_graph_with_data(self) -> nx.DiGraph:
         """Returns a NetworkX DiGraph representing the data and layout of this graph"""
         g = nx.DiGraph()
-        g.add_nodes_from(list(self.nodes.items()))
-        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
+        g.add_nodes_from([n for n in self.nodes.items()])
+        g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
         return g

     def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
@@ -789,7 +791,7 @@ class Graph(BaseModel):

         # TODO: figure out if iteration nodes need to be expanded

-        unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
+        unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
         g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
         return g

@@ -841,8 +843,8 @@ class GraphExecutionState(BaseModel):
         return v

     model_config = ConfigDict(
-        json_schema_extra={
-            "required": [
+        json_schema_extra=dict(
+            required=[
                 "id",
                 "graph",
                 "execution_graph",
@@ -853,7 +855,7 @@ class GraphExecutionState(BaseModel):
                 "prepared_source_mapping",
                 "source_prepared_mapping",
             ]
-        }
+        )
     )

     def next(self) -> Optional[BaseInvocation]:
@@ -893,7 +895,7 @@ class GraphExecutionState(BaseModel):
             source_node = self.prepared_source_mapping[node_id]
             prepared_nodes = self.source_prepared_mapping[source_node]

-            if all(n in self.executed for n in prepared_nodes):
+            if all([n in self.executed for n in prepared_nodes]):
                 self.executed.add(source_node)
                 self.executed_history.append(source_node)

@@ -928,7 +930,7 @@ class GraphExecutionState(BaseModel):
         input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
         self_iteration_count = len(input_collection)

-        new_nodes: list[str] = []
+        new_nodes: list[str] = list()
         if self_iteration_count == 0:
             # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
             return new_nodes
@@ -938,7 +940,7 @@ class GraphExecutionState(BaseModel):

         # Create new edges for this iteration
         # For collect nodes, this may contain multiple inputs to the same field
-        new_edges: list[Edge] = []
+        new_edges: list[Edge] = list()
         for edge in input_edges:
             for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
                 new_edge = Edge(
@@ -1032,7 +1034,7 @@ class GraphExecutionState(BaseModel):

         # Create execution nodes
         next_node = self.graph.get_node(next_node_id)
-        new_node_ids = []
+        new_node_ids = list()
         if isinstance(next_node, CollectInvocation):
             # Collapse all iterator input mappings and create a single execution node for the collect invocation
             all_iteration_mappings = list(
@@ -1053,10 +1055,7 @@ class GraphExecutionState(BaseModel):
             # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
             # TODO: Handle a node mapping to none
             eg = self.execution_graph.nx_graph_flat()
-            prepared_parent_mappings = [
-                [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
-                for it in iterator_node_prepared_combinations
-            ]  # type: ignore
+            prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations]  # type: ignore

             # Create execution node for each iteration
             for iteration_mappings in prepared_parent_mappings:
@@ -1122,7 +1121,7 @@ class GraphExecutionState(BaseModel):
                 for edge in input_edges
                 if edge.destination.field == "item"
             ]
-            node.collection = output_collection
+            setattr(node, "collection", output_collection)
         else:
             for edge in input_edges:
                 output_value = getattr(self.results[edge.source.node_id], edge.source.field)
@@ -1202,7 +1201,7 @@ class LibraryGraph(BaseModel):

     @field_validator("exposed_inputs", "exposed_outputs")
     def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
-        if len(v) != len({i.alias for i in v}):
+        if len(v) != len(set(i.alias for i in v)):
             raise ValueError("Duplicate exposed alias")
         return v

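The duplicate-id validation above relies on set deduplication; the brace comprehension and the `set([...])` form build the same set. A tiny sketch with hypothetical node ids:

```python
node_ids = ["noise", "denoise", "noise"]  # illustrative ids
duplicates = {n for n in node_ids if node_ids.count(n) >= 2}
assert duplicates == set([n for n in node_ids if node_ids.count(n) >= 2]) == {"noise"}
```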
@@ -1,5 +0,0 @@
-"""
-This module contains various classes, functions and models which are shared across the app, particularly by invocations.
-
-Lifting these classes, functions and models into this shared module helps to reduce circular imports.
-"""
@@ -1,66 +0,0 @@
-class FieldDescriptions:
-    denoising_start = "When to start denoising, expressed a percentage of total steps"
-    denoising_end = "When to stop denoising, expressed a percentage of total steps"
-    cfg_scale = "Classifier-Free Guidance scale"
-    scheduler = "Scheduler to use during inference"
-    positive_cond = "Positive conditioning tensor"
-    negative_cond = "Negative conditioning tensor"
-    noise = "Noise tensor"
-    clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
-    unet = "UNet (scheduler, LoRAs)"
-    vae = "VAE"
-    cond = "Conditioning tensor"
-    controlnet_model = "ControlNet model to load"
-    vae_model = "VAE model to load"
-    lora_model = "LoRA model to load"
-    main_model = "Main model (UNet, VAE, CLIP) to load"
-    sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
-    sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
-    onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
-    lora_weight = "The weight at which the LoRA is applied to each model"
-    compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
-    raw_prompt = "Raw prompt text (no parsing)"
-    sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
-    skipped_layers = "Number of layers to skip in text encoder"
-    seed = "Seed for random number generation"
-    steps = "Number of steps to run"
-    width = "Width of output (px)"
-    height = "Height of output (px)"
-    control = "ControlNet(s) to apply"
-    ip_adapter = "IP-Adapter to apply"
-    t2i_adapter = "T2I-Adapter(s) to apply"
-    denoised_latents = "Denoised latents tensor"
-    latents = "Latents tensor"
-    strength = "Strength of denoising (proportional to steps)"
-    metadata = "Optional metadata to be saved with the image"
-    metadata_collection = "Collection of Metadata"
-    metadata_item_polymorphic = "A single metadata item or collection of metadata items"
-    metadata_item_label = "Label for this metadata item"
-    metadata_item_value = "The value for this metadata item (may be any type)"
-    workflow = "Optional workflow to be saved with the image"
-    interp_mode = "Interpolation mode"
-    torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
-    fp32 = "Whether or not to use full float32 precision"
-    precision = "Precision to use"
-    tiled = "Processing using overlapping tiles (reduce memory consumption)"
-    detect_res = "Pixel resolution for detection"
-    image_res = "Pixel resolution for output image"
-    safe_mode = "Whether or not to use safe mode"
-    scribble_mode = "Whether or not to use scribble mode"
-    scale_factor = "The factor by which to scale"
-    blend_alpha = (
-        "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
-    )
-    num_1 = "The first number"
-    num_2 = "The second number"
-    mask = "The mask to use for the operation"
-    board = "The board to save the image to"
-    image = "The image to process"
-    tile_size = "Tile size"
-    inclusive_low = "The inclusive low value"
-    exclusive_high = "The exclusive high value"
-    decimal_places = "The number of decimal places to round to"
-    freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
-    freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
-    freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
-    freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
@@ -1,16 +0,0 @@
-from pydantic import BaseModel, Field
-
-from invokeai.app.shared.fields import FieldDescriptions
-
-
-class FreeUConfig(BaseModel):
-    """
-    Configuration for the FreeU hyperparameters.
-    - https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu
-    - https://github.com/ChenyangSi/FreeU
-    """
-
-    s1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s1)
-    s2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s2)
-    b1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b1)
-    b2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b2)
@@ -59,7 +59,7 @@ def thin_one_time(x, kernels):

 def lvmin_thin(x, prunings=True):
     y = x
-    for _i in range(32):
+    for i in range(32):
         y, is_done = thin_one_time(y, lvmin_kernels)
         if is_done:
             break
@@ -21,11 +21,11 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:

     # sanity check make sure the graph is at least reasonably shaped
     if (
-        not isinstance(graph, dict)
+        type(graph) is not dict
         or "nodes" not in graph
-        or not isinstance(graph["nodes"], dict)
+        or type(graph["nodes"]) is not dict
         or "edges" not in graph
-        or not isinstance(graph["edges"], list)
+        or type(graph["edges"]) is not list
     ):
         # something has gone terribly awry, return an empty dict
         return None
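For the sanity check above, `isinstance()` and `type(...) is ...` agree on the plain `dict`/`list` payloads produced by JSON parsing; `isinstance()` additionally accepts subclasses. A quick sketch:

```python
import json

graph = json.loads('{"nodes": {}, "edges": []}')
ok = (
    isinstance(graph, dict)
    and "nodes" in graph
    and isinstance(graph["nodes"], dict)
    and "edges" in graph
    and isinstance(graph["edges"], list)
)
assert ok
```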
@@ -88,7 +88,7 @@ class PromptFormatter:
         t2i = self.t2i
         opt = self.opt

-        switches = []
+        switches = list()
         switches.append(f'"{opt.prompt}"')
         switches.append(f"-s{opt.steps or t2i.steps}")
         switches.append(f"-W{opt.width or t2i.width}")
@@ -88,7 +88,7 @@ class Txt2Mask(object):
         provided image and returns a SegmentedGrayscale object in which the brighter
         pixels indicate where the object is inferred to be.
         """
-        if isinstance(image, str):
+        if type(image) is str:
             image = Image.open(image).convert("RGB")

         image = ImageOps.exif_transpose(image)
@@ -40,7 +40,7 @@ class InitImageResizer:
             (rw, rh) = (int(scale * im.width), int(scale * im.height))

         # round everything to multiples of 64
-        width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))
+        width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))

         # no resize necessary, but return a copy
         if im.width == width and im.height == height:
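`x - x % 64` rounds a dimension down to the nearest multiple of 64, which is what both spellings above compute. A quick check with illustrative sizes:

```python
def floor64(x: int) -> int:
    return x - x % 64

sizes = (513, 640, 1000, 64)
assert tuple(floor64(x) for x in sizes) == tuple(map(lambda x: x - x % 64, sizes)) == (512, 640, 960, 64)
```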
@@ -32,7 +32,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
 from huggingface_hub import HfFolder
 from huggingface_hub import login as hf_hub_login
 from omegaconf import OmegaConf
-from pydantic import ValidationError
+from pydantic.error_wrappers import ValidationError
 from tqdm import tqdm
 from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

@@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th

 def download_conversion_models():
     target_dir = config.models_path / "core/convert"
-    kwargs = {}  # for future use
+    kwargs = dict()  # for future use
     try:
         logger.info("Downloading core tokenizers and text encoders")

@@ -252,26 +252,26 @@ def download_conversion_models():
 def download_realesrgan():
     logger.info("Installing ESRGAN Upscaling models...")
     URLs = [
-        {
-            "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
-            "dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
-            "description": "RealESRGAN_x4plus.pth",
-        },
-        {
-            "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
-            "dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
-            "description": "RealESRGAN_x4plus_anime_6B.pth",
-        },
-        {
-            "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
-            "dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
-            "description": "ESRGAN_SRx4_DF2KOST_official.pth",
-        },
-        {
-            "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
-            "dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
-            "description": "RealESRGAN_x2plus.pth",
-        },
+        dict(
+            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+            dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
+            description="RealESRGAN_x4plus.pth",
+        ),
+        dict(
+            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+            dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
+            description="RealESRGAN_x4plus_anime_6B.pth",
+        ),
+        dict(
+            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+            dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+            description="ESRGAN_SRx4_DF2KOST_official.pth",
+        ),
+        dict(
+            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+            dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
+            description="RealESRGAN_x2plus.pth",
+        ),
     ]
     for model in URLs:
         download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
@@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
         if program_opts.default_only
         else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
         if program_opts.yes_to_all
-        else [],
+        else list(),
     )


@@ -38,7 +38,6 @@ SAMPLER_CHOICES = [
     "k_heun",
     "k_lms",
     "plms",
-    "lcm",
 ]
 
 PRECISION_CHOICES = [
@@ -123,6 +123,8 @@ class MigrateTo3(object):
                 logger.error(str(e))
             except KeyboardInterrupt:
                 raise
+            except Exception as e:
+                logger.error(str(e))
         for f in files:
             # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
             # let them be copied as part of a tree copy operation
@@ -141,6 +143,8 @@ class MigrateTo3(object):
                 logger.error(str(e))
             except KeyboardInterrupt:
                 raise
+            except Exception as e:
+                logger.error(str(e))
 
     def migrate_support_models(self):
         """
@@ -178,10 +182,10 @@ class MigrateTo3(object):
         """
 
         dest_directory = self.dest_models
-        kwargs = {
-            "cache_dir": self.root_directory / "models/hub",
+        kwargs = dict(
+            cache_dir=self.root_directory / "models/hub",
             # local_files_only = True
-        }
+        )
         try:
             logger.info("Migrating core tokenizers and text encoders")
             target_dir = dest_directory / "core" / "convert"
@@ -312,11 +316,11 @@ class MigrateTo3(object):
         dest_dir = self.dest_models
 
         cache = self.root_directory / "models/hub"
-        kwargs = {
-            "cache_dir": cache,
-            "safety_checker": None,
+        kwargs = dict(
+            cache_dir=cache,
+            safety_checker=None,
             # local_files_only = True,
-        }
+        )
 
         owner, repo_name = repo_id.split("/")
         model_name = model_name or repo_name
@@ -120,7 +120,7 @@ class ModelInstall(object):
         be treated uniformly. It also sorts the models alphabetically
         by their name, to improve the display somewhat.
         """
-        model_dict = {}
+        model_dict = dict()
 
         # first populate with the entries in INITIAL_MODELS.yaml
         for key, value in self.datasets.items():
@@ -134,7 +134,7 @@ class ModelInstall(object):
             model_dict[key] = model_info
 
         # supplement with entries in models.yaml
-        installed_models = list(self.mgr.list_models())
+        installed_models = [x for x in self.mgr.list_models()]
 
         for md in installed_models:
             base = md["base_model"]
@@ -176,7 +176,7 @@ class ModelInstall(object):
     # logic here a little reversed to maintain backward compatibility
     def starter_models(self, all_models: bool = False) -> Set[str]:
         models = set()
-        for key, _value in self.datasets.items():
+        for key, value in self.datasets.items():
             name, base, model_type = ModelManager.parse_key(key)
             if all_models or model_type in [ModelType.Main, ModelType.Vae]:
                 models.add(key)
@@ -184,7 +184,7 @@ class ModelInstall(object):
 
     def recommended_models(self) -> Set[str]:
         starters = self.starter_models(all_models=True)
-        return {x for x in starters if self.datasets[x].get("recommended", False)}
+        return set([x for x in starters if self.datasets[x].get("recommended", False)])
 
     def default_model(self) -> str:
         starters = self.starter_models()
@@ -234,7 +234,7 @@ class ModelInstall(object):
         """
 
         if not models_installed:
-            models_installed = {}
+            models_installed = dict()
 
         model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
 
@@ -252,14 +252,10 @@ class ModelInstall(object):
 
             # folders style or similar
             elif path.is_dir() and any(
-                (path / x).exists()
-                for x in {
-                    "config.json",
-                    "model_index.json",
-                    "learned_embeds.bin",
-                    "pytorch_lora_weights.bin",
-                    "pytorch_lora_weights.safetensors",
-                }
+                [
+                    (path / x).exists()
+                    for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
+                ]
             ):
                 models_installed.update({str(model_path_id_or_url): self._install_path(path)})
 
@@ -361,7 +357,7 @@ class ModelInstall(object):
         for suffix in ["safetensors", "bin"]:
             if f"{prefix}pytorch_lora_weights.{suffix}" in files:
                 location = self._download_hf_model(
-                    repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder
+                    repo_id, ["pytorch_lora_weights.bin"], staging, subfolder=subfolder
                 )  # LoRA
                 break
             elif (
@@ -431,17 +427,17 @@ class ModelInstall(object):
 
         rel_path = self.relative_to_root(path, self.config.models_path)
 
-        attributes = {
-            "path": str(rel_path),
-            "description": str(description),
-            "model_format": info.format,
-        }
+        attributes = dict(
+            path=str(rel_path),
+            description=str(description),
+            model_format=info.format,
+        )
         legacy_conf = None
         if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
             attributes.update(
-                {
-                    "variant": info.variant_type,
-                }
+                dict(
+                    variant=info.variant_type,
+                )
             )
             if info.format == "checkpoint":
                 try:
@@ -472,7 +468,7 @@ class ModelInstall(object):
             )
 
         if legacy_conf:
-            attributes.update({"config": str(legacy_conf)})
+            attributes.update(dict(config=str(legacy_conf)))
         return attributes
 
     def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
@@ -517,7 +513,7 @@ class ModelInstall(object):
     def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
         _, name = repo_id.split("/")
         location = staging / name
-        paths = []
+        paths = list()
        for filename in files:
             filePath = Path(filename)
             p = hf_download_with_resume(
@@ -130,9 +130,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
             assert ip_adapter_image_prompt_embeds is not None
             assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
 
-            for ipa_embed, ipa_weights, scale in zip(
-                ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
-            ):
+            for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
                 # The batch dimensions should match.
                 assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
                 # The token_len dimensions should match.
@@ -56,7 +56,7 @@ class PerceiverAttention(nn.Module):
         x = self.norm1(x)
         latents = self.norm2(latents)
 
-        b, L, _ = latents.shape
+        b, l, _ = latents.shape
 
         q = self.to_q(latents)
         kv_input = torch.cat((x, latents), dim=-2)
@@ -72,7 +72,7 @@ class PerceiverAttention(nn.Module):
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v
 
-        out = out.permute(0, 2, 1, 3).reshape(b, L, -1)
+        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
 
         return self.to_out(out)
 
@@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         resolution *= 2
 
     up_block_types = []
-    for _i in range(len(block_out_channels)):
+    for i in range(len(block_out_channels)):
         block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
         up_block_types.append(block_type)
         resolution //= 2
@@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt(
     # scan model
     scan_result = scan_file_path(checkpoint_path)
     if scan_result.infected_files != 0:
-        raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
+        raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
         checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt(
     # scan model
     scan_result = scan_file_path(checkpoint_path)
     if scan_result.infected_files != 0:
-        raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
+        raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
         checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -12,8 +12,6 @@ from diffusers.models import UNet2DConditionModel
 from safetensors.torch import load_file
 from transformers import CLIPTextModel, CLIPTokenizer
 
-from invokeai.app.shared.models import FreeUConfig
-
 from .models.lora import LoRAModel
 
 """
@@ -104,7 +102,7 @@ class ModelPatcher:
         loras: List[Tuple[LoRAModel, float]],
         prefix: str,
     ):
-        original_weights = {}
+        original_weights = dict()
         try:
             with torch.no_grad():
                 for lora, lora_weight in loras:
@@ -166,15 +164,6 @@ class ModelPatcher:
         init_tokens_count = None
         new_tokens_added = None
 
-        # TODO: This is required since Transformers 4.32 see
-        # https://github.com/huggingface/transformers/pull/25088
-        # More information by NVIDIA:
-        # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
-        # This value might need to be changed in the future and take the GPUs model into account as there seem
-        # to be ideal values for different GPUS. This value is temporary!
-        # For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
-        pad_to_multiple_of = 8
-
         try:
             # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
             # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
@@ -184,7 +173,7 @@ class ModelPatcher:
             # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
             ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
             ti_manager = TextualInversionManager(ti_tokenizer)
-            init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
+            init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
 
             def _get_trigger(ti_name, index):
                 trigger = ti_name
@@ -199,7 +188,7 @@ class ModelPatcher:
                     new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
 
             # modify text_encoder
-            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
+            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
             model_embeddings = text_encoder.get_input_embeddings()
 
             for ti_name, ti in ti_list:
@@ -231,7 +220,7 @@ class ModelPatcher:
 
         finally:
             if init_tokens_count and new_tokens_added:
-                text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
+                text_encoder.resize_token_embeddings(init_tokens_count)
 
     @classmethod
     @contextmanager
@@ -242,7 +231,7 @@ class ModelPatcher:
     ):
         skipped_layers = []
         try:
-            for _i in range(clip_skip):
+            for i in range(clip_skip):
                 skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
 
             yield
@@ -251,25 +240,6 @@ class ModelPatcher:
             while len(skipped_layers) > 0:
                 text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
 
-    @classmethod
-    @contextmanager
-    def apply_freeu(
-        cls,
-        unet: UNet2DConditionModel,
-        freeu_config: Optional[FreeUConfig] = None,
-    ):
-        did_apply_freeu = False
-        try:
-            if freeu_config is not None:
-                unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
-                did_apply_freeu = True
-
-            yield
-
-        finally:
-            if did_apply_freeu:
-                unet.disable_freeu()
-
 
 class TextualInversionModel:
     embedding: torch.Tensor  # [n, 768]|[n, 1280]
@@ -324,7 +294,7 @@ class TextualInversionManager(BaseTextualInversionManager):
     tokenizer: CLIPTokenizer
 
     def __init__(self, tokenizer: CLIPTokenizer):
-        self.pad_tokens = {}
+        self.pad_tokens = dict()
         self.tokenizer = tokenizer
 
     def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
@@ -385,10 +355,10 @@ class ONNXModelPatcher:
         if not isinstance(model, IAIOnnxRuntimeModel):
             raise Exception("Only IAIOnnxRuntimeModel models supported")
 
-        orig_weights = {}
+        orig_weights = dict()
 
         try:
-            blended_loras = {}
+            blended_loras = dict()
 
             for lora, lora_weight in loras:
                 for layer_key, layer in lora.layers.items():
@@ -404,7 +374,7 @@ class ONNXModelPatcher:
                     else:
                         blended_loras[layer_key] = layer_weight
 
-            node_names = {}
+            node_names = dict()
             for node in model.nodes.values():
                 node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
 
@@ -66,13 +66,11 @@ class CacheStats(object):
 
 class ModelLocker(object):
     "Forward declaration"
-
     pass
 
 
 class ModelCache(object):
     "Forward declaration"
-
     pass
 
 
@@ -134,7 +132,7 @@ class ModelCache(object):
         snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
         behaviour.
         """
-        self.model_infos: Dict[str, ModelBase] = {}
+        self.model_infos: Dict[str, ModelBase] = dict()
         # allow lazy offloading only when vram cache enabled
         self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self.precision: torch.dtype = precision
@@ -149,8 +147,8 @@ class ModelCache(object):
         # used for stats collection
         self.stats = None
 
-        self._cached_models = {}
-        self._cache_stack = []
+        self._cached_models = dict()
+        self._cache_stack = list()
 
     def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
         if self._log_memory_usage:
@@ -26,5 +26,5 @@ def skip_torch_weight_init():
 
         yield None
     finally:
-        for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
+        for torch_module, saved_function in zip(torch_modules, saved_functions):
             torch_module.reset_parameters = saved_function
@@ -363,7 +363,7 @@ class ModelManager(object):
         else:
             return
 
-        self.models = {}
+        self.models = dict()
         for model_key, model_config in config.items():
             if model_key.startswith("_"):
                 continue
@@ -374,7 +374,7 @@ class ModelManager(object):
             self.models[model_key] = model_class.create_config(**model_config)
 
         # check config version number and update on disk/RAM if necessary
-        self.cache_keys = {}
+        self.cache_keys = dict()
 
         # add controlnet, lora and textual_inversion models from disk
         self.scan_models_directory()
@@ -655,7 +655,7 @@ class ModelManager(object):
         """
         # TODO: redo
         for model_dict in self.list_models():
-            for _model_name, model_info in model_dict.items():
+            for model_name, model_info in model_dict.items():
                 line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
                 print(line)
 
@@ -902,7 +902,7 @@ class ModelManager(object):
         """
         Write current configuration out to the indicated file.
         """
-        data_to_save = {}
+        data_to_save = dict()
         data_to_save["__metadata__"] = self.config_meta.model_dump()
 
         for model_key, model_config in self.models.items():
@@ -1034,7 +1034,7 @@ class ModelManager(object):
         self.ignore = ignore
 
     def on_search_started(self):
-        self.new_models_found = {}
+        self.new_models_found = dict()
 
     def on_model_found(self, model: Path):
         if model not in self.ignore:
@@ -1106,7 +1106,7 @@ class ModelManager(object):
         # avoid circular import here
         from invokeai.backend.install.model_install_backend import ModelInstall
 
-        successfully_installed = {}
+        successfully_installed = dict()
 
         installer = ModelInstall(
             config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
@@ -92,7 +92,7 @@ class ModelMerger(object):
             **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
                  cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
         """
-        model_paths = []
+        model_paths = list()
         config = self.manager.app_config
         base_model = BaseModelType(base_model)
         vae = None
@@ -124,13 +124,13 @@ class ModelMerger(object):
             dump_path = (dump_path / merged_model_name).as_posix()
 
         merged_pipe.save_pretrained(dump_path, safe_serialization=True)
-        attributes = {
-            "path": dump_path,
-            "description": f"Merge of models {', '.join(model_names)}",
-            "model_format": "diffusers",
-            "variant": ModelVariantType.Normal.value,
-            "vae": vae,
-        }
+        attributes = dict(
+            path=dump_path,
+            description=f"Merge of models {', '.join(model_names)}",
+            model_format="diffusers",
+            variant=ModelVariantType.Normal.value,
+            vae=vae,
+        )
         return self.manager.add_model(
             merged_model_name,
             base_model=base_model,
@@ -183,13 +183,12 @@ class ModelProbe(object):
         if model:
             class_name = model.__class__.__name__
         else:
-            for suffix in ["bin", "safetensors"]:
-                if (folder_path / f"learned_embeds.{suffix}").exists():
-                    return ModelType.TextualInversion
-                if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
-                    return ModelType.Lora
             if (folder_path / "unet/model.onnx").exists():
                 return ModelType.ONNX
+            if (folder_path / "learned_embeds.bin").exists():
+                return ModelType.TextualInversion
+            if (folder_path / "pytorch_lora_weights.bin").exists():
+                return ModelType.Lora
             if (folder_path / "image_encoder.txt").exists():
                 return ModelType.IPAdapter
 
@@ -237,7 +236,7 @@ class ModelProbe(object):
         # scan model
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
-            raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
+            raise "The model {model_name} is potentially infected by malware. Aborting import."
 
 
 # ##################################################3
@@ -59,7 +59,7 @@ class ModelSearch(ABC):
         for root, dirs, files in os.walk(path, followlinks=True):
             if str(Path(root).name).startswith("."):
                 self._pruned_paths.add(root)
-            if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
+            if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
                 continue
 
             self._items_scanned += len(dirs) + len(files)
@@ -69,6 +69,7 @@ class ModelSearch(ABC):
                     self._scanned_dirs.add(path)
                     continue
                 if any(
+                    [
                         (path / x).exists()
                         for x in {
                             "config.json",
@@ -77,6 +78,7 @@ class ModelSearch(ABC):
                             "pytorch_lora_weights.bin",
                             "image_encoder.txt",
                         }
+                    ]
                 ):
                     try:
                         self.on_model_found(path)
@@ -97,8 +97,8 @@ MODEL_CLASSES = {
     # },
 }
 
-MODEL_CONFIGS = []
-OPENAPI_MODEL_CONFIGS = []
+MODEL_CONFIGS = list()
+OPENAPI_MODEL_CONFIGS = list()
 
 
 class OpenAPIModelInfoBase(BaseModel):
@@ -109,7 +109,7 @@ class OpenAPIModelInfoBase(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
 
-for _base_model, models in MODEL_CLASSES.items():
+for base_model, models in MODEL_CLASSES.items():
     for model_type, model_class in models.items():
         model_configs = set(model_class._get_configs().values())
         model_configs.discard(None)
@@ -133,7 +133,7 @@ for _base_model, models in MODEL_CLASSES.items():
 
 
 def get_model_config_enums():
-    enums = []
+    enums = list()
 
     for model_config in MODEL_CONFIGS:
         if hasattr(inspect, "get_annotations"):
@@ -153,7 +153,7 @@ class ModelBase(metaclass=ABCMeta):
 
         else:
             res_type = sys.modules["diffusers"]
-            res_type = res_type.pipelines
+            res_type = getattr(res_type, "pipelines")
 
         for subtype in subtypes:
             res_type = getattr(res_type, subtype)
@@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
         with suppress(Exception):
             return cls.__configs
 
-        configs = {}
+        configs = dict()
         for name in dir(cls):
             if name.startswith("__"):
                 continue
@@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         super().__init__(model_path, base_model, model_type)
 
-        self.child_types: Dict[str, Type] = {}
-        self.child_sizes: Dict[str, int] = {}
+        self.child_types: Dict[str, Type] = dict()
+        self.child_sizes: Dict[str, int] = dict()
 
         try:
             config_data = DiffusionPipeline.load_config(self.model_path)
@@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
     all_files = os.listdir(model_path)
     all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
 
-    fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
-    bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
+    fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
+    bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
     other_files = set(all_files) - fp16_files - bit8_files
 
     if variant is None:
@@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:
 
 
 def _fast_safetensors_reader(path: str):
-    checkpoint = {}
+    checkpoint = dict()
     device = torch.device("meta")
     with open(path, "rb") as f:
         definition_len = int.from_bytes(f.read(8), "little")
@@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
     class _tensor_access:
         def __init__(self, model):
             self.model = model
-            self.indexes = {}
+            self.indexes = dict()
             for idx, obj in enumerate(self.model.proto.graph.initializer):
                 self.indexes[obj.name] = idx
 
@@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
 
     class _access_helper:
         def __init__(self, raw_proto):
-            self.indexes = {}
+            self.indexes = dict()
             self.raw_proto = raw_proto
             for idx, obj in enumerate(raw_proto):
                 self.indexes[obj.name] = idx
@@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
             return self.indexes.keys()
 
         def values(self):
-            return list(self.raw_proto)
+            return [obj for obj in self.raw_proto]
 
     def __init__(self, model_path: str, provider: Optional[str]):
         self.path = model_path
@@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
             return ControlNetModelFormat.Diffusers
 
         if os.path.isfile(path):
-            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
+            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
                 return ControlNetModelFormat.Checkpoint
 
         raise InvalidModelException(f"Not a valid model: {path}")
@@ -68,12 +68,11 @@ class LoRAModel(ModelBase):
             raise ModelNotFoundException()
 
         if os.path.isdir(path):
-            for ext in ["safetensors", "bin"]:
-                if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")):
+            if os.path.exists(os.path.join(path, "pytorch_lora_weights.bin")):
                 return LoRAModelFormat.Diffusers
 
         if os.path.isfile(path):
-            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
+            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
                 return LoRAModelFormat.LyCORIS
 
         raise InvalidModelException(f"Not a valid model: {path}")
@@ -87,10 +86,8 @@ class LoRAModel(ModelBase):
         base_model: BaseModelType,
     ) -> str:
         if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
-            for ext in ["safetensors", "bin"]:  # return path to the safetensors file inside the folder
-                path = Path(model_path, f"pytorch_lora_weights.{ext}")
-                if path.exists():
-                    return path
+            # TODO: add diffusers lora when it stabilizes a bit
+            raise NotImplementedError("Diffusers lora not supported")
         else:
             return model_path
 
@@ -462,7 +459,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
         dtype: Optional[torch.dtype] = None,
     ):
         # TODO: try revert if exception?
-        for _key, layer in self.layers.items():
+        for key, layer in self.layers.items():
             layer.to(device=device, dtype=dtype)
 
     def calc_size(self) -> int:
@@ -499,7 +496,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
         stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
         stability_unet_keys.sort()
 
-        new_state_dict = {}
+        new_state_dict = dict()
         for full_key, value in state_dict.items():
             if full_key.startswith("lora_unet_"):
                 search_key = full_key.replace("lora_unet_", "")
@@ -545,7 +542,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
 
         model = cls(
             name=file_path.stem,  # TODO:
-            layers={},
+            layers=dict(),
         )
 
         if file_path.suffix == ".safetensors":
@@ -593,12 +590,12 @@ class LoRAModelRaw:  # (torch.nn.Module):
 
     @staticmethod
     def _group_state(state_dict: dict):
-        state_dict_groupped = {}
+        state_dict_groupped = dict()
 
         for key, value in state_dict.items():
             stem, leaf = key.split(".", 1)
             if stem not in state_dict_groupped:
-                state_dict_groupped[stem] = {}
+                state_dict_groupped[stem] = dict()
             state_dict_groupped[stem][leaf] = value
 
         return state_dict_groupped
@@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
             return StableDiffusion1ModelFormat.Diffusers
 
         if os.path.isfile(model_path):
-            if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
+            if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
                 return StableDiffusion1ModelFormat.Checkpoint
 
         raise InvalidModelException(f"Not a valid model: {model_path}")
@@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
             return StableDiffusion2ModelFormat.Diffusers
 
         if os.path.isfile(model_path):
-            if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
+            if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
                 return StableDiffusion2ModelFormat.Checkpoint
 
         raise InvalidModelException(f"Not a valid model: {model_path}")
@@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
             return None  # diffusers-ti
 
         if os.path.isfile(path):
-            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
+            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
                 return None
 
         raise InvalidModelException(f"Not a valid model: {path}")
@@ -89,7 +89,7 @@ class VaeModel(ModelBase):
            return VaeModelFormat.Diffusers
 
        if os.path.isfile(path):
-            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
+            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
                return VaeModelFormat.Checkpoint
 
        raise InvalidModelException(f"Not a valid model: {path}")
@@ -1,323 +0,0 @@
-# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
-"""
-Configuration definitions for image generation models.
-
-Typical usage:
-
-  from invokeai.backend.model_manager import ModelConfigFactory
-  raw = dict(path='models/sd-1/main/foo.ckpt',
-             name='foo',
-             base='sd-1',
-             type='main',
-             config='configs/stable-diffusion/v1-inference.yaml',
-             variant='normal',
-             format='checkpoint'
-            )
-  config = ModelConfigFactory.make_config(raw)
-  print(config.name)
-
-Validation errors will raise an InvalidModelConfigException error.
-
-"""
-from enum import Enum
-from typing import Literal, Optional, Type, Union
-
-from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
-from typing_extensions import Annotated
-
-
-class InvalidModelConfigException(Exception):
-    """Exception for when config parser doesn't recognized this combination of model type and format."""
-
-
-class BaseModelType(str, Enum):
-    """Base model type."""
-
-    Any = "any"
-    StableDiffusion1 = "sd-1"
-    StableDiffusion2 = "sd-2"
-    StableDiffusionXL = "sdxl"
-    StableDiffusionXLRefiner = "sdxl-refiner"
-    # Kandinsky2_1 = "kandinsky-2.1"
-
-
-class ModelType(str, Enum):
-    """Model type."""
-
-    ONNX = "onnx"
-    Main = "main"
-    Vae = "vae"
-    Lora = "lora"
-    ControlNet = "controlnet"  # used by model_probe
-    TextualInversion = "embedding"
-    IPAdapter = "ip_adapter"
-    CLIPVision = "clip_vision"
-    T2IAdapter = "t2i_adapter"
-
-
-class SubModelType(str, Enum):
-    """Submodel type."""
-
-    UNet = "unet"
-    TextEncoder = "text_encoder"
-    TextEncoder2 = "text_encoder_2"
-    Tokenizer = "tokenizer"
-    Tokenizer2 = "tokenizer_2"
-    Vae = "vae"
-    VaeDecoder = "vae_decoder"
-    VaeEncoder = "vae_encoder"
-    Scheduler = "scheduler"
-    SafetyChecker = "safety_checker"
-
-
-class ModelVariantType(str, Enum):
-    """Variant type."""
-
-    Normal = "normal"
-    Inpaint = "inpaint"
-    Depth = "depth"
-
-
-class ModelFormat(str, Enum):
-    """Storage format of model."""
-
-    Diffusers = "diffusers"
-    Checkpoint = "checkpoint"
-    Lycoris = "lycoris"
-    Onnx = "onnx"
-    Olive = "olive"
-    EmbeddingFile = "embedding_file"
-    EmbeddingFolder = "embedding_folder"
-    InvokeAI = "invokeai"
-
-
-class SchedulerPredictionType(str, Enum):
-    """Scheduler prediction type."""
-
-    Epsilon = "epsilon"
-    VPrediction = "v_prediction"
-    Sample = "sample"
-
-
-class ModelConfigBase(BaseModel):
-    """Base class for model configuration information."""
-
-    path: str
-    name: str
-    base: BaseModelType
-    type: ModelType
-    format: ModelFormat
-    key: str = Field(description="unique key for model", default="<NOKEY>")
-    original_hash: Optional[str] = Field(
-        description="original fasthash of model contents", default=None
-    )  # this is assigned at install time and will not change
-    current_hash: Optional[str] = Field(
-        description="current fasthash of model contents", default=None
-    )  # if model is converted or otherwise modified, this will hold updated hash
-    description: Optional[str] = Field(default=None)
-    source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
-
-    model_config = ConfigDict(
-        use_enum_values=False,
-        validate_assignment=True,
-    )
-
-    def update(self, attributes: dict):
-        """Update the object with fields in dict."""
-        for key, value in attributes.items():
-            setattr(self, key, value)  # may raise a validation error
-
-
-class _CheckpointConfig(ModelConfigBase):
-    """Model config for checkpoint-style models."""
-
-    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
-    config: str = Field(description="path to the checkpoint model config file")
-
-
-class _DiffusersConfig(ModelConfigBase):
-    """Model config for diffusers-style models."""
-
-    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
-
-
-class LoRAConfig(ModelConfigBase):
-    """Model config for LoRA/Lycoris models."""
-
-    type: Literal[ModelType.Lora] = ModelType.Lora
-    format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
-
-
-class VaeCheckpointConfig(ModelConfigBase):
-    """Model config for standalone VAE models."""
-
-    type: Literal[ModelType.Vae] = ModelType.Vae
-    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
-
-
-class VaeDiffusersConfig(ModelConfigBase):
-    """Model config for standalone VAE models (diffusers version)."""
-
-    type: Literal[ModelType.Vae] = ModelType.Vae
-    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
-
-
-class ControlNetDiffusersConfig(_DiffusersConfig):
-    """Model config for ControlNet models (diffusers version)."""
-
-    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
-    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
-
-
-class ControlNetCheckpointConfig(_CheckpointConfig):
-    """Model config for ControlNet models (diffusers version)."""
-
-    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
-    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
-
-
-class TextualInversionConfig(ModelConfigBase):
-    """Model config for textual inversion embeddings."""
-
-    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
-    format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
-
-
-class _MainConfig(ModelConfigBase):
-    """Model config for main models."""
-
-    vae: Optional[str] = Field(default=None)
-    variant: ModelVariantType = ModelVariantType.Normal
-    ztsnr_training: bool = False
-
-
-class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
-    """Model config for main checkpoint models."""
-
-    type: Literal[ModelType.Main] = ModelType.Main
-    # Note that we do not need prediction_type or upcast_attention here
-    # because they are provided in the checkpoint's own config file.
-
-
-class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
-    """Model config for main diffusers models."""
-
-    type: Literal[ModelType.Main] = ModelType.Main
-    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
-    upcast_attention: bool = False
-
-
-class ONNXSD1Config(_MainConfig):
-    """Model config for ONNX format models based on sd-1."""
-
-    type: Literal[ModelType.ONNX] = ModelType.ONNX
-    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
-    base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
-    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
-    upcast_attention: bool = False
-
-
-class ONNXSD2Config(_MainConfig):
-    """Model config for ONNX format models based on sd-2."""
-
-    type: Literal[ModelType.ONNX] = ModelType.ONNX
-    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
-    # No yaml config file for ONNX, so these are part of config
-    base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
-    prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
-    upcast_attention: bool = True
-
-
-class IPAdapterConfig(ModelConfigBase):
-    """Model config for IP Adaptor format models."""
-
-    type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
-    format: Literal[ModelFormat.InvokeAI]
-
-
-class CLIPVisionDiffusersConfig(ModelConfigBase):
-    """Model config for ClipVision."""
-
-    type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
-    format: Literal[ModelFormat.Diffusers]
-
-
-class T2IConfig(ModelConfigBase):
-    """Model config for T2I."""
-
-    type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
-    format: Literal[ModelFormat.Diffusers]
-
-
-_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
-_ControlNetConfig = Annotated[
-    Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
-    Field(discriminator="format"),
-]
-_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
-_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
-
-AnyModelConfig = Union[
-    _MainModelConfig,
-    _ONNXConfig,
-    _VaeConfig,
-    _ControlNetConfig,
-    LoRAConfig,
-    TextualInversionConfig,
-    IPAdapterConfig,
-    CLIPVisionDiffusersConfig,
-    T2IConfig,
-]
-
-AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
-
-# IMPLEMENTATION NOTE:
-# The preferred alternative to the above is a discriminated Union as shown
-# below. However, it breaks FastAPI when used as the input Body parameter in a route.
-# This is a known issue. Please see:
-#   https://github.com/tiangolo/fastapi/discussions/9761 and
-#   https://github.com/tiangolo/fastapi/discussions/9287
-# AnyModelConfig = Annotated[
-#     Union[
-#         _MainModelConfig,
-#         _ONNXConfig,
-#         _VaeConfig,
-#         _ControlNetConfig,
-#         LoRAConfig,
-#         TextualInversionConfig,
-#         IPAdapterConfig,
-#         CLIPVisionDiffusersConfig,
-#         T2IConfig,
-#     ],
-#     Field(discriminator="type"),
-# ]
-
-
-class ModelConfigFactory(object):
-    """Class for parsing config dicts into StableDiffusion Config obects."""
-
-    @classmethod
-    def make_config(
-        cls,
-        model_data: Union[dict, AnyModelConfig],
-        key: Optional[str] = None,
-        dest_class: Optional[Type] = None,
-    ) -> AnyModelConfig:
-        """
-        Return the appropriate config object from raw dict values.
-
-        :param model_data: A raw dict corresponding the obect fields to be
-        parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
-        object, which will be passed through unchanged.
-        :param dest_class: The config class to be returned. If not provided, will
-        be selected automatically.
-        """
-        if isinstance(model_data, ModelConfigBase):
-            model = model_data
-        elif dest_class:
-            model = dest_class.validate_python(model_data)
-        else:
-            model = AnyModelConfigValidator.validate_python(model_data)
-        if key:
-            model.key = key
-        return model
@ -1,66 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.

Usage:
from invokeai.backend.model_manager.hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""

import hashlib
import os
from pathlib import Path
from typing import Dict, Union

from imohash import hashfile


class FastModelHash(object):
    """FastModelHash object provides one public class method, hash()."""

    @classmethod
    def hash(cls, model_location: Union[str, Path]) -> str:
        """
        Return hexdigest string for model located at model_location.

        :param model_location: Path to the model
        """
        model_location = Path(model_location)
        if model_location.is_file():
            return cls._hash_file(model_location)
        elif model_location.is_dir():
            return cls._hash_dir(model_location)
        else:
            raise OSError(f"Not a valid file or directory: {model_location}")

    @classmethod
    def _hash_file(cls, model_location: Union[str, Path]) -> str:
        """
        Fasthash a single file and return its hexdigest.

        :param model_location: Path to the model file
        """
        # we return md5 hash of the filehash to make it shorter
        # cryptographic security not needed here
        return hashlib.md5(hashfile(model_location)).hexdigest()

    @classmethod
    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
        components: Dict[str, str] = {}

        for root, _dirs, files in os.walk(model_location):
            for file in files:
                # only tally tensor files because diffusers config files change slightly
                # depending on how the model was downloaded/converted.
                if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
                    continue
                path = (Path(root) / file).as_posix()
                fast_hash = cls._hash_file(path)
                components.update({path: fast_hash})

        # hash all the model hashes together, using alphabetic file order
        md5 = hashlib.md5()
        for _path, fast_hash in sorted(components.items()):
            md5.update(fast_hash.encode("utf-8"))
        return md5.hexdigest()
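A hypothetical usage sketch for the `FastModelHash` helper shown above; the paths are placeholders, and `imohash` must be installed for the module to import:

```python
from invokeai.backend.model_manager.hash import FastModelHash

# Single checkpoint file: md5 of imohash's sampled file digest.
print(FastModelHash.hash("/path/to/model.safetensors"))

# Diffusers-style directory: md5 over the per-file hashes of the tensor files,
# combined in sorted path order so the result does not depend on os.walk order.
print(FastModelHash.hash("/path/to/stable-diffusion-v1-5"))
```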
@ -1,93 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""

from hashlib import sha1

from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
    DuplicateModelException,
    ModelRecordServiceSQL,
)
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
    ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger

ModelsValidator = TypeAdapter(AnyModelConfig)


class MigrateModelYamlToDb:
    """
    Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQLite3 database format (VERSION 3.2.0).

    The class has one externally useful method, migrate(), which scans the
    current models.yaml file and imports all its entries into invokeai.db.

    Use this way:

      from invokeai.backend.model_manager.migrate_to_db import MigrateModelYamlToDb
      MigrateModelYamlToDb().migrate()

    """

    config: InvokeAIAppConfig
    logger: InvokeAILogger

    def __init__(self):
        self.config = InvokeAIAppConfig.get_config()
        self.config.parse_args()
        self.logger = InvokeAILogger.get_logger()

    def get_db(self) -> ModelRecordServiceSQL:
        """Fetch the sqlite3 database for this installation."""
        db = SqliteDatabase(self.config, self.logger)
        return ModelRecordServiceSQL(db)

    def get_yaml(self) -> DictConfig:
        """Fetch the models.yaml DictConfig for this installation."""
        yaml_path = self.config.model_conf_path
        return OmegaConf.load(yaml_path)

    def migrate(self):
        """Do the migration from models.yaml to invokeai.db."""
        db = self.get_db()
        yaml = self.get_yaml()

        for model_key, stanza in yaml.items():
            if model_key == "__metadata__":
                assert (
                    stanza["version"] == "3.0.0"
                ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
                continue

            base_type, model_type, model_name = str(model_key).split("/")
            hash = FastModelHash.hash(self.config.models_path / stanza.path)
            new_key = sha1(model_key.encode("utf-8")).hexdigest()

            stanza["base"] = BaseModelType(base_type)
            stanza["type"] = ModelType(model_type)
            stanza["name"] = model_name
            stanza["original_hash"] = hash
            stanza["current_hash"] = hash

            new_config = ModelsValidator.validate_python(stanza)
            self.logger.info(f"Adding model {model_name} with key {model_key}")
            try:
                db.add_model(new_key, new_config)
            except DuplicateModelException:
                self.logger.warning(f"Model {model_name} is already in the database")


def main():
    MigrateModelYamlToDb().migrate()


if __name__ == "__main__":
    main()
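A short, illustrative walk-through (with a fabricated yaml key) of how `migrate()` derives the new database key and the split model attributes from a v3.0.0 `models.yaml` entry:

```python
from hashlib import sha1

# Hypothetical v3.0.0 models.yaml key, in base/type/name form.
model_key = "sd-1/main/stable-diffusion-v1-5"
base_type, model_type, model_name = model_key.split("/")

# The new database key is the sha1 of the old yaml key, not of the model weights;
# content hashes go into the original_hash/current_hash fields instead.
new_key = sha1(model_key.encode("utf-8")).hexdigest()
print(base_type, model_type, model_name, new_key)
```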
@ -193,7 +193,6 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
            after generation completes. Optional.
    """

    attention_map_saver: Optional[AttentionMapSaver]

@ -547,13 +546,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        # Handle ControlNet(s) and T2I-Adapter(s)
        down_block_additional_residuals = None
        mid_block_additional_residual = None
-        down_intrablock_additional_residuals = None
-        # if control_data is not None and t2i_adapter_data is not None:
+        if control_data is not None and t2i_adapter_data is not None:
            # TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
            # between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
-        # raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
-        # elif control_data is not None:
-        if control_data is not None:
+            raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
+        elif control_data is not None:
            down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
                control_data=control_data,
                sample=latent_model_input,
@ -562,8 +559,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                total_step_count=total_step_count,
                conditioning_data=conditioning_data,
            )
-        # elif t2i_adapter_data is not None:
-        if t2i_adapter_data is not None:
+        elif t2i_adapter_data is not None:
            accum_adapter_state = None
            for single_t2i_adapter_data in t2i_adapter_data:
                # Determine the T2I-Adapter weights for the current denoising step.
@ -588,8 +584,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                    for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
                        accum_adapter_state[idx] += value * t2i_adapter_weight

-            # down_block_additional_residuals = accum_adapter_state
-            down_intrablock_additional_residuals = accum_adapter_state
+            down_block_additional_residuals = accum_adapter_state

        uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
            sample=latent_model_input,
@ -598,9 +593,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            total_step_count=total_step_count,
            conditioning_data=conditioning_data,
            # extra:
-            down_block_additional_residuals=down_block_additional_residuals,  # for ControlNet
-            mid_block_additional_residual=mid_block_additional_residual,  # for ControlNet
-            down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # for T2I-Adapter
+            down_block_additional_residuals=down_block_additional_residuals,
+            mid_block_additional_residual=mid_block_additional_residual,
        )

        guidance_scale = conditioning_data.guidance_scale
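Taken together, these hunks appear to go back to treating ControlNet and T2I-Adapter conditioning as mutually exclusive, feeding the accumulated adapter state through `down_block_additional_residuals` rather than a separate `down_intrablock_additional_residuals` argument. A stand-alone sketch of that dispatch pattern; the function and parameter names below are illustrative, not the pipeline's real API:

```python
from typing import Any, Optional


def select_unet_residual_kwargs(
    control_residuals: Optional[tuple] = None,
    mid_block_residual: Optional[Any] = None,
    t2i_adapter_state: Optional[list] = None,
) -> dict:
    """Pick the extra UNet kwargs for one denoising step; the two mechanisms are exclusive."""
    if control_residuals is not None and t2i_adapter_state is not None:
        raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
    if control_residuals is not None:
        return {
            "down_block_additional_residuals": control_residuals,
            "mid_block_additional_residual": mid_block_residual,
        }
    if t2i_adapter_state is not None:
        # With this change, adapter state rides in the same kwarg that ControlNet uses.
        return {"down_block_additional_residuals": t2i_adapter_state}
    return {}
```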
@ -54,13 +54,13 @@ class Context:
        self.clear_requests(cleanup=True)

    def register_cross_attention_modules(self, model):
-        for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
+        for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
            if name in self.self_cross_attention_module_identifiers:
-                raise AssertionError(f"name {name} cannot appear more than once")
+                assert False, f"name {name} cannot appear more than once"
            self.self_cross_attention_module_identifiers.append(name)
-        for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
+        for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
            if name in self.tokens_cross_attention_module_identifiers:
-                raise AssertionError(f"name {name} cannot appear more than once")
+                assert False, f"name {name} cannot appear more than once"
            self.tokens_cross_attention_module_identifiers.append(name)

    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@ -170,7 +170,7 @@ class Context:
        self.saved_cross_attention_maps = {}

    def offload_saved_attention_slices_to_cpu(self):
-        for _key, map_dict in self.saved_cross_attention_maps.items():
+        for key, map_dict in self.saved_cross_attention_maps.items():
            for offset, slice in map_dict["slices"].items():
                map_dict[offset] = slice.to("cpu")

@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
        module.identifier = identifier
        try:
            module.set_attention_slice_wrangler(attention_slice_wrangler)
-            module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))  # noqa: B023
+            module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
        except AttributeError as e:
            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")  # TODO
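The `# noqa: B023` dropped above suppresses flake8-bugbear's warning about loop variables captured by a lambda. A small, generic demonstration of the pitfall that warning targets (illustrative only, unrelated to how the pipeline binds `identifier`):

```python
# Lambdas created in a loop capture the variable, not its value at definition time.
funcs = [lambda: i for i in range(3)]
print([f() for f in funcs])  # [2, 2, 2] - every lambda sees the final value of i

# Binding the current value through a default argument freezes it per iteration.
funcs = [lambda i=i: i for i in range(3)]
print([f() for f in funcs])  # [0, 1, 2]
```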
@ -445,7 +445,7 @@ def remove_attention_function(unet):
    cross_attention_modules = get_cross_attention_modules(
        unet, CrossAttentionType.TOKENS
    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
-    for _identifier, module in cross_attention_modules:
+    for identifier, module in cross_attention_modules:
        try:
            # clear wrangler callback
            module.set_attention_slice_wrangler(None)
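These last hunks swap the explicit `raise AssertionError(...)` back to `assert False, ...` and drop the underscore prefixes on unused loop variables. One practical difference worth noting: `assert` statements are stripped entirely when Python runs with `-O`, so an explicit raise is always enforced. A minimal sketch of that distinction:

```python
def register(name: str, seen: set) -> None:
    if name in seen:
        raise AssertionError(f"name {name} cannot appear more than once")  # enforced even under -O
    # assert name not in seen, f"name {name} cannot appear more than once"  # skipped under python -O
    seen.add(name)
```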
Some files were not shown because too many files have changed in this diff.