Merge branch 'main' into fix/controlnet_cfg_inj_cond

This commit is contained in:
blessedcoolant 2023-07-14 01:11:04 +12:00 committed by GitHub
commit 9348dc8e0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
112 changed files with 3279 additions and 1951 deletions

View File

@ -1,25 +1,9 @@
# use this file as a whitelist
* *
!invokeai !invokeai
!ldm
!pyproject.toml !pyproject.toml
!docker/docker-entrypoint.sh
!LICENSE
# ignore frontend/web but whitelist dist **/node_modules
invokeai/frontend/web/ **/__pycache__
!invokeai/frontend/web/dist/ **/*.egg-info
# ignore invokeai/assets but whitelist invokeai/assets/web
invokeai/assets/
!invokeai/assets/web/
# Guard against pulling in any models that might exist in the directory tree
**/*.pt*
**/*.ckpt
# Byte-compiled / optimized / DLL files
**/__pycache__/
**/*.py[cod]
# Distribution / packaging
**/*.egg-info/
**/*.egg

View File

@ -3,17 +3,15 @@ on:
push: push:
branches: branches:
- 'main' - 'main'
- 'update/ci/docker/*'
- 'update/docker/*'
- 'dev/ci/docker/*'
- 'dev/docker/*'
paths: paths:
- 'pyproject.toml' - 'pyproject.toml'
- '.dockerignore' - '.dockerignore'
- 'invokeai/**' - 'invokeai/**'
- 'docker/Dockerfile' - 'docker/Dockerfile'
- 'docker/docker-entrypoint.sh'
- 'workflows/build-container.yml'
tags: tags:
- 'v*.*.*' - 'v*'
workflow_dispatch: workflow_dispatch:
permissions: permissions:
@ -26,23 +24,27 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
flavor: gpu-driver:
- rocm
- cuda - cuda
- cpu - cpu
include: - rocm
- flavor: rocm
pip-extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
- flavor: cuda
pip-extra-index-url: ''
- flavor: cpu
pip-extra-index-url: 'https://download.pytorch.org/whl/cpu'
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: ${{ matrix.flavor }} name: ${{ matrix.gpu-driver }}
env: env:
PLATFORMS: 'linux/amd64,linux/arm64' # torch/arm64 does not support GPU currently, so arm64 builds
DOCKERFILE: 'docker/Dockerfile' # would not be GPU-accelerated.
# re-enable arm64 if there is sufficient demand.
# PLATFORMS: 'linux/amd64,linux/arm64'
PLATFORMS: 'linux/amd64'
steps: steps:
- name: Free up more disk space on the runner
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3
@ -53,7 +55,7 @@ jobs:
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
images: | images: |
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
${{ vars.DOCKERHUB_REPOSITORY }} ${{ env.DOCKERHUB_REPOSITORY }}
tags: | tags: |
type=ref,event=branch type=ref,event=branch
type=ref,event=tag type=ref,event=tag
@ -62,8 +64,8 @@ jobs:
type=pep440,pattern={{major}} type=pep440,pattern={{major}}
type=sha,enable=true,prefix=sha-,format=short type=sha,enable=true,prefix=sha-,format=short
flavor: | flavor: |
latest=${{ matrix.flavor == 'cuda' && github.ref == 'refs/heads/main' }} latest=${{ matrix.gpu-driver == 'cuda' && github.ref == 'refs/heads/main' }}
suffix=-${{ matrix.flavor }},onlatest=false suffix=-${{ matrix.gpu-driver }},onlatest=false
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v2
@ -81,34 +83,33 @@ jobs:
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Login to Docker Hub # - name: Login to Docker Hub
if: github.event_name != 'pull_request' && vars.DOCKERHUB_REPOSITORY != '' # if: github.event_name != 'pull_request' && vars.DOCKERHUB_REPOSITORY != ''
uses: docker/login-action@v2 # uses: docker/login-action@v2
with: # with:
username: ${{ secrets.DOCKERHUB_USERNAME }} # username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} # password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build container - name: Build container
id: docker_build id: docker_build
uses: docker/build-push-action@v4 uses: docker/build-push-action@v4
with: with:
context: . context: .
file: ${{ env.DOCKERFILE }} file: docker/Dockerfile
platforms: ${{ env.PLATFORMS }} platforms: ${{ env.PLATFORMS }}
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }} push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }}
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
build-args: PIP_EXTRA_INDEX_URL=${{ matrix.pip-extra-index-url }}
cache-from: | cache-from: |
type=gha,scope=${{ github.ref_name }}-${{ matrix.flavor }} type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
type=gha,scope=main-${{ matrix.flavor }} type=gha,scope=main-${{ matrix.gpu-driver }}
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.flavor }} cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
- name: Docker Hub Description # - name: Docker Hub Description
if: github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' && vars.DOCKERHUB_REPOSITORY != '' # if: github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' && vars.DOCKERHUB_REPOSITORY != ''
uses: peter-evans/dockerhub-description@v3 # uses: peter-evans/dockerhub-description@v3
with: # with:
username: ${{ secrets.DOCKERHUB_USERNAME }} # username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} # password: ${{ secrets.DOCKERHUB_TOKEN }}
repository: ${{ vars.DOCKERHUB_REPOSITORY }} # repository: ${{ vars.DOCKERHUB_REPOSITORY }}
short-description: ${{ github.event.repository.description }} # short-description: ${{ github.event.repository.description }}

13
docker/.env.sample Normal file
View File

@ -0,0 +1,13 @@
## Make a copy of this file named `.env` and fill in the values below.
## Any environment variables supported by InvokeAI can be specified here.
# INVOKEAI_ROOT is the path to a path on the local filesystem where InvokeAI will store data.
# Outputs will also be stored here by default.
# This **must** be an absolute path.
INVOKEAI_ROOT=
HUGGINGFACE_TOKEN=
## optional variables specific to the docker setup
# GPU_DRIVER=cuda
# CONTAINER_UID=1000

View File

@ -1,107 +1,129 @@
# syntax=docker/dockerfile:1 # syntax=docker/dockerfile:1.4
ARG PYTHON_VERSION=3.9 ## Builder stage
##################
## base image ##
##################
FROM --platform=${TARGETPLATFORM} python:${PYTHON_VERSION}-slim AS python-base
LABEL org.opencontainers.image.authors="mauwii@outlook.de" FROM library/ubuntu:22.04 AS builder
# Prepare apt for buildkit cache ARG DEBIAN_FRONTEND=noninteractive
RUN rm -f /etc/apt/apt.conf.d/docker-clean \ RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
&& echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
# Install dependencies
RUN \
--mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update \ apt update && apt-get install -y \
&& apt-get install -y \ git \
--no-install-recommends \ python3.10-venv \
libgl1-mesa-glx=20.3.* \ python3-pip \
libglib2.0-0=2.66.* \ build-essential
libopencv-dev=4.5.*
# Set working directory and env ENV INVOKEAI_SRC=/opt/invokeai
ARG APPDIR=/usr/src ENV VIRTUAL_ENV=/opt/venv/invokeai
ARG APPNAME=InvokeAI
WORKDIR ${APPDIR}
ENV PATH ${APPDIR}/${APPNAME}/bin:$PATH
# Keeps Python from generating .pyc files in the container
ENV PYTHONDONTWRITEBYTECODE 1
# Turns off buffering for easier container logging
ENV PYTHONUNBUFFERED 1
# Don't fall back to legacy build system
ENV PIP_USE_PEP517=1
####################### ENV PATH="$VIRTUAL_ENV/bin:$PATH"
## build pyproject ## ARG TORCH_VERSION=2.0.1
####################### ARG TORCHVISION_VERSION=0.15.2
FROM python-base AS pyproject-builder ARG GPU_DRIVER=cuda
ARG TARGETPLATFORM="linux/amd64"
# unused but available
ARG BUILDPLATFORM
# Install build dependencies WORKDIR ${INVOKEAI_SRC}
RUN \
--mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update \
&& apt-get install -y \
--no-install-recommends \
build-essential=12.9 \
gcc=4:10.2.* \
python3-dev=3.9.*
# Prepare pip for buildkit cache # Install pytorch before all other pip packages
ARG PIP_CACHE_DIR=/var/cache/buildkit/pip # NOTE: there are no pytorch builds for arm64 + cuda, only cpu
ENV PIP_CACHE_DIR ${PIP_CACHE_DIR} # x86_64/CUDA is default
RUN mkdir -p ${PIP_CACHE_DIR} RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m venv ${VIRTUAL_ENV} &&\
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; \
else \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu118"; \
fi &&\
pip install $extra_index_url_arg \
torch==$TORCH_VERSION \
torchvision==$TORCHVISION_VERSION
# Create virtual environment # Install the local package.
RUN --mount=type=cache,target=${PIP_CACHE_DIR} \ # Editable mode helps use the same image for development:
python3 -m venv "${APPNAME}" \ # the local working copy can be bind-mounted into the image
--upgrade-deps # at path defined by ${INVOKEAI_SRC}
COPY invokeai ./invokeai
COPY pyproject.toml ./
RUN --mount=type=cache,target=/root/.cache/pip \
# xformers + triton fails to install on arm64
if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
pip install -e ".[xformers]"; \
else \
pip install -e "."; \
fi
# Install requirements # #### Build the Web UI ------------------------------------
COPY --link pyproject.toml .
COPY --link invokeai/version/invokeai_version.py invokeai/version/__init__.py invokeai/version/
ARG PIP_EXTRA_INDEX_URL
ENV PIP_EXTRA_INDEX_URL ${PIP_EXTRA_INDEX_URL}
RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
"${APPNAME}"/bin/pip install .
# Install pyproject.toml FROM node:18 AS web-builder
COPY --link . . WORKDIR /build
RUN --mount=type=cache,target=${PIP_CACHE_DIR} \ COPY invokeai/frontend/web/ ./
"${APPNAME}/bin/pip" install . RUN --mount=type=cache,target=/usr/lib/node_modules \
npm install --include dev
RUN --mount=type=cache,target=/usr/lib/node_modules \
yarn vite build
# Build patchmatch
#### Runtime stage ---------------------------------------
FROM library/ubuntu:22.04 AS runtime
ARG DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
RUN apt update && apt install -y --no-install-recommends \
git \
curl \
vim \
tmux \
ncdu \
iotop \
bzip2 \
gosu \
libglib2.0-0 \
libgl1-mesa-glx \
python3-venv \
python3-pip \
build-essential \
libopencv-dev \
libstdc++-10-dev &&\
apt-get clean && apt-get autoclean
# globally add magic-wormhole
# for ease of transferring data to and from the container
# when running in sandboxed cloud environments; e.g. Runpod etc.
RUN pip install magic-wormhole
ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv/invokeai
ENV INVOKEAI_ROOT=/invokeai
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
# --link requires buldkit w/ dockerfile syntax 1.4
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
WORKDIR ${INVOKEAI_SRC}
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python3 -c "from patchmatch import patch_match" RUN python3 -c "from patchmatch import patch_match"
##################### # Create unprivileged user and make the local dir
## runtime image ## RUN useradd --create-home --shell /bin/bash -u 1000 --comment "container local user" invoke
##################### RUN mkdir -p ${INVOKEAI_ROOT} && chown -R invoke:invoke ${INVOKEAI_ROOT}
FROM python-base AS runtime
# Create a new user COPY docker/docker-entrypoint.sh ./
ARG UNAME=appuser ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
RUN useradd \ CMD ["invokeai-web", "--host", "0.0.0.0"]
--no-log-init \
-m \
-U \
"${UNAME}"
# Create volume directory
ARG VOLUME_DIR=/data
RUN mkdir -p "${VOLUME_DIR}" \
&& chown -hR "${UNAME}:${UNAME}" "${VOLUME_DIR}"
# Setup runtime environment
USER ${UNAME}:${UNAME}
COPY --chown=${UNAME}:${UNAME} --from=pyproject-builder ${APPDIR}/${APPNAME} ${APPNAME}
ENV INVOKEAI_ROOT ${VOLUME_DIR}
ENV TRANSFORMERS_CACHE ${VOLUME_DIR}/.cache
ENV INVOKE_MODEL_RECONFIGURE "--yes --default_only"
EXPOSE 9090
ENTRYPOINT [ "invokeai" ]
CMD [ "--web", "--host", "0.0.0.0", "--port", "9090" ]
VOLUME [ "${VOLUME_DIR}" ]

77
docker/README.md Normal file
View File

@ -0,0 +1,77 @@
# InvokeAI Containerized
All commands are to be run from the `docker` directory: `cd docker`
#### Linux
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://www.digitalocean.com/community/tutorials/how-to-install-and-use-docker-compose-on-ubuntu-22-04).
- The deprecated `docker-compose` (hyphenated) CLI continues to work for now.
3. Ensure docker daemon is able to access the GPU.
- You may need to install [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
#### macOS
1. Ensure Docker has at least 16GB RAM
2. Enable VirtioFS for file sharing
3. Enable `docker compose` V2 support
This is done via Docker Desktop preferences
## Quickstart
1. Make a copy of `env.sample` and name it `.env` (`cp env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
a. the desired location of the InvokeAI runtime directory, or
b. an existing, v3.0.0 compatible runtime directory.
1. `docker compose up`
The image will be built automatically if needed.
The runtime directory (holding models and outputs) will be created in the location specified by `INVOKEAI_ROOT`. The default location is `~/invokeai`. The runtime directory will be populated with the base configs and models necessary to start generating.
### Use a GPU
- Linux is *recommended* for GPU support in Docker.
- WSL2 is *required* for Windows.
- only `x86_64` architecture is supported.
The Docker daemon on the system must be already set up to use the GPU. In case of Linux, this involves installing `nvidia-docker-runtime` and configuring the `nvidia` runtime as default. Steps will be different for AMD. Please see Docker documentation for the most up-to-date instructions for using your GPU with Docker.
## Customize
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `docker compose up`, your custom values will be used.
You can also set these values in `docker compose.yml` directly, but `.env` will help avoid conflicts when code is updated.
Example (most values are optional):
```
INVOKEAI_ROOT=/Volumes/WorkDrive/invokeai
HUGGINGFACE_TOKEN=the_actual_token
CONTAINER_UID=1000
GPU_DRIVER=cuda
```
## Even Moar Customizing!
See the `docker compose.yaml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
### Reconfigure the runtime directory
Can be used to download additional models from the supported model list
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
```
command:
- invokeai-configure
- --yes
```
Or install models:
```
command:
- invokeai-model-install
```

View File

@ -1,51 +1,11 @@
#!/usr/bin/env bash #!/usr/bin/env bash
set -e set -e
# If you want to build a specific flavor, set the CONTAINER_FLAVOR environment variable build_args=""
# e.g. CONTAINER_FLAVOR=cpu ./build.sh
# Possible Values are:
# - cpu
# - cuda
# - rocm
# Don't forget to also set it when executing run.sh
# if it is not set, the script will try to detect the flavor by itself.
#
# Doc can be found here:
# https://invoke-ai.github.io/InvokeAI/installation/040_INSTALL_DOCKER/
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") [[ -f ".env" ]] && build_args=$(awk '$1 ~ /\=[^$]/ {print "--build-arg " $0 " "}' .env)
cd "$SCRIPTDIR" || exit 1
source ./env.sh echo "docker-compose build args:"
echo $build_args
DOCKERFILE=${INVOKE_DOCKERFILE:-./Dockerfile} docker-compose build $build_args
# print the settings
echo -e "You are using these values:\n"
echo -e "Dockerfile:\t\t${DOCKERFILE}"
echo -e "index-url:\t\t${PIP_EXTRA_INDEX_URL:-none}"
echo -e "Volumename:\t\t${VOLUMENAME}"
echo -e "Platform:\t\t${PLATFORM}"
echo -e "Container Registry:\t${CONTAINER_REGISTRY}"
echo -e "Container Repository:\t${CONTAINER_REPOSITORY}"
echo -e "Container Tag:\t\t${CONTAINER_TAG}"
echo -e "Container Flavor:\t${CONTAINER_FLAVOR}"
echo -e "Container Image:\t${CONTAINER_IMAGE}\n"
# Create docker volume
if [[ -n "$(docker volume ls -f name="${VOLUMENAME}" -q)" ]]; then
echo -e "Volume already exists\n"
else
echo -n "creating docker volume "
docker volume create "${VOLUMENAME}"
fi
# Build Container
docker build \
--platform="${PLATFORM:-linux/amd64}" \
--tag="${CONTAINER_IMAGE:-invokeai}" \
${CONTAINER_FLAVOR:+--build-arg="CONTAINER_FLAVOR=${CONTAINER_FLAVOR}"} \
${PIP_EXTRA_INDEX_URL:+--build-arg="PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}"} \
${PIP_PACKAGE:+--build-arg="PIP_PACKAGE=${PIP_PACKAGE}"} \
--file="${DOCKERFILE}" \
..

48
docker/docker-compose.yml Normal file
View File

@ -0,0 +1,48 @@
# Copyright (c) 2023 Eugene Brodsky https://github.com/ebr
version: '3.8'
services:
invokeai:
image: "local/invokeai:latest"
# edit below to run on a container runtime other than nvidia-container-runtime.
# not yet tested with rocm/AMD GPUs
# Comment out the "deploy" section to run on CPU only
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
build:
context: ..
dockerfile: docker/Dockerfile
# variables without a default will automatically inherit from the host environment
environment:
- INVOKEAI_ROOT
- HF_HOME
# Create a .env file in the same directory as this docker-compose.yml file
# and populate it with environment variables. See .env.sample
env_file:
- .env
ports:
- "${INVOKEAI_PORT:-9090}:9090"
volumes:
- ${INVOKEAI_ROOT:-~/invokeai}:${INVOKEAI_ROOT:-/invokeai}
- ${HF_HOME:-~/.cache/huggingface}:${HF_HOME:-/invokeai/.cache/huggingface}
# - ${INVOKEAI_MODELS_DIR:-${INVOKEAI_ROOT:-/invokeai/models}}
# - ${INVOKEAI_MODELS_CONFIG_PATH:-${INVOKEAI_ROOT:-/invokeai/configs/models.yaml}}
tty: true
stdin_open: true
# # Example of running alternative commands/scripts in the container
# command:
# - bash
# - -c
# - |
# invokeai-model-install --yes --default-only --config_file ${INVOKEAI_ROOT}/config_custom.yaml
# invokeai-nodes-web --host 0.0.0.0

65
docker/docker-entrypoint.sh Executable file
View File

@ -0,0 +1,65 @@
#!/bin/bash
set -e -o pipefail
### Container entrypoint
# Runs the CMD as defined by the Dockerfile or passed to `docker run`
# Can be used to configure the runtime dir
# Bypass by using ENTRYPOINT or `--entrypoint`
### Set INVOKEAI_ROOT pointing to a valid runtime directory
# Otherwise configure the runtime dir first.
### Configure the InvokeAI runtime directory (done by default)):
# docker run --rm -it <this image> --configure
# or skip with --no-configure
### Set the CONTAINER_UID envvar to match your user.
# Ensures files created in the container are owned by you:
# docker run --rm -it -v /some/path:/invokeai -e CONTAINER_UID=$(id -u) <this image>
# Default UID: 1000 chosen due to popularity on Linux systems. Possibly 501 on MacOS.
USER_ID=${CONTAINER_UID:-1000}
USER=invoke
usermod -u ${USER_ID} ${USER} 1>/dev/null
configure() {
# Configure the runtime directory
if [[ -f ${INVOKEAI_ROOT}/invokeai.yaml ]]; then
echo "${INVOKEAI_ROOT}/invokeai.yaml exists. InvokeAI is already configured."
echo "To reconfigure InvokeAI, delete the above file."
echo "======================================================================"
else
mkdir -p ${INVOKEAI_ROOT}
chown --recursive ${USER} ${INVOKEAI_ROOT}
gosu ${USER} invokeai-configure --yes --default_only
fi
}
## Skip attempting to configure.
## Must be passed first, before any other args.
if [[ $1 != "--no-configure" ]]; then
configure
else
shift
fi
### Set the $PUBLIC_KEY env var to enable SSH access.
# We do not install openssh-server in the image by default to avoid bloat.
# but it is useful to have the full SSH server e.g. on Runpod.
# (use SCP to copy files to/from the image, etc)
if [[ -v "PUBLIC_KEY" ]] && [[ ! -d "${HOME}/.ssh" ]]; then
apt-get update
apt-get install -y openssh-server
pushd $HOME
mkdir -p .ssh
echo ${PUBLIC_KEY} > .ssh/authorized_keys
chmod -R 700 .ssh
popd
service ssh start
fi
cd ${INVOKEAI_ROOT}
# Run the CMD as the Container User (not root).
exec gosu ${USER} "$@"

View File

@ -1,54 +0,0 @@
#!/usr/bin/env bash
# This file is used to set environment variables for the build.sh and run.sh scripts.
# Try to detect the container flavor if no PIP_EXTRA_INDEX_URL got specified
if [[ -z "$PIP_EXTRA_INDEX_URL" ]]; then
# Activate virtual environment if not already activated and exists
if [[ -z $VIRTUAL_ENV ]]; then
[[ -e "$(dirname "${BASH_SOURCE[0]}")/../.venv/bin/activate" ]] \
&& source "$(dirname "${BASH_SOURCE[0]}")/../.venv/bin/activate" \
&& echo "Activated virtual environment: $VIRTUAL_ENV"
fi
# Decide which container flavor to build if not specified
if [[ -z "$CONTAINER_FLAVOR" ]] && python -c "import torch" &>/dev/null; then
# Check for CUDA and ROCm
CUDA_AVAILABLE=$(python -c "import torch;print(torch.cuda.is_available())")
ROCM_AVAILABLE=$(python -c "import torch;print(torch.version.hip is not None)")
if [[ "${CUDA_AVAILABLE}" == "True" ]]; then
CONTAINER_FLAVOR="cuda"
elif [[ "${ROCM_AVAILABLE}" == "True" ]]; then
CONTAINER_FLAVOR="rocm"
else
CONTAINER_FLAVOR="cpu"
fi
fi
# Set PIP_EXTRA_INDEX_URL based on container flavor
if [[ "$CONTAINER_FLAVOR" == "rocm" ]]; then
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/rocm"
elif [[ "$CONTAINER_FLAVOR" == "cpu" ]]; then
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
# elif [[ -z "$CONTAINER_FLAVOR" || "$CONTAINER_FLAVOR" == "cuda" ]]; then
# PIP_PACKAGE=${PIP_PACKAGE-".[xformers]"}
fi
fi
# Variables shared by build.sh and run.sh
REPOSITORY_NAME="${REPOSITORY_NAME-$(basename "$(git rev-parse --show-toplevel)")}"
REPOSITORY_NAME="${REPOSITORY_NAME,,}"
VOLUMENAME="${VOLUMENAME-"${REPOSITORY_NAME}_data"}"
ARCH="${ARCH-$(uname -m)}"
PLATFORM="${PLATFORM-linux/${ARCH}}"
INVOKEAI_BRANCH="${INVOKEAI_BRANCH-$(git branch --show)}"
CONTAINER_REGISTRY="${CONTAINER_REGISTRY-"ghcr.io"}"
CONTAINER_REPOSITORY="${CONTAINER_REPOSITORY-"$(whoami)/${REPOSITORY_NAME}"}"
CONTAINER_FLAVOR="${CONTAINER_FLAVOR-cuda}"
CONTAINER_TAG="${CONTAINER_TAG-"${INVOKEAI_BRANCH##*/}-${CONTAINER_FLAVOR}"}"
CONTAINER_IMAGE="${CONTAINER_REGISTRY}/${CONTAINER_REPOSITORY}:${CONTAINER_TAG}"
CONTAINER_IMAGE="${CONTAINER_IMAGE,,}"
# enable docker buildkit
export DOCKER_BUILDKIT=1

View File

@ -1,41 +1,8 @@
#!/usr/bin/env bash #!/usr/bin/env bash
set -e set -e
# How to use: https://invoke-ai.github.io/InvokeAI/installation/040_INSTALL_DOCKER/
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
cd "$SCRIPTDIR" || exit 1 cd "$SCRIPTDIR" || exit 1
source ./env.sh docker-compose up --build -d
docker-compose logs -f
# Create outputs directory if it does not exist
[[ -d ./outputs ]] || mkdir ./outputs
echo -e "You are using these values:\n"
echo -e "Volumename:\t${VOLUMENAME}"
echo -e "Invokeai_tag:\t${CONTAINER_IMAGE}"
echo -e "local Models:\t${MODELSPATH:-unset}\n"
docker run \
--interactive \
--tty \
--rm \
--platform="${PLATFORM}" \
--name="${REPOSITORY_NAME}" \
--hostname="${REPOSITORY_NAME}" \
--mount type=volume,volume-driver=local,source="${VOLUMENAME}",target=/data \
--mount type=bind,source="$(pwd)"/outputs/,target=/data/outputs/ \
${MODELSPATH:+--mount="type=bind,source=${MODELSPATH},target=/data/models"} \
${HUGGING_FACE_HUB_TOKEN:+--env="HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN}"} \
--publish=9090:9090 \
--cap-add=sys_nice \
${GPU_FLAGS:+--gpus="${GPU_FLAGS}"} \
"${CONTAINER_IMAGE}" ${@:+$@}
echo -e "\nCleaning trash folder ..."
for f in outputs/.Trash*; do
if [ -e "$f" ]; then
rm -Rf "$f"
break
fi
done

60
docker/runpod-readme.md Normal file
View File

@ -0,0 +1,60 @@
# InvokeAI - A Stable Diffusion Toolkit
Stable Diffusion distribution by InvokeAI: https://github.com/invoke-ai
The Docker image tracks the `main` branch of the InvokeAI project, which means it includes the latest features, but may contain some bugs.
Your working directory is mounted under the `/workspace` path inside the pod. The models are in `/workspace/invokeai/models`, and outputs are in `/workspace/invokeai/outputs`.
> **Only the /workspace directory will persist between pod restarts!**
> **If you _terminate_ (not just _stop_) the pod, the /workspace will be lost.**
## Quickstart
1. Launch a pod from this template. **It will take about 5-10 minutes to run through the initial setup**. Be patient.
1. Wait for the application to load.
- TIP: you know it's ready when the CPU usage goes idle
- You can also check the logs for a line that says "_Point your browser at..._"
1. Open the Invoke AI web UI: click the `Connect` => `connect over HTTP` button.
1. Generate some art!
## Other things you can do
At any point you may edit the pod configuration and set an arbitrary Docker command. For example, you could run a command to downloads some models using `curl`, or fetch some images and place them into your outputs to continue a working session.
If you need to run *multiple commands*, define them in the Docker Command field like this:
`bash -c "cd ${INVOKEAI_ROOT}/outputs; wormhole receive 2-foo-bar; invoke.py --web --host 0.0.0.0"`
### Copying your data in and out of the pod
This image includes a couple of handy tools to help you get the data into the pod (such as your custom models or embeddings), and out of the pod (such as downloading your outputs). Here are your options for getting your data in and out of the pod:
- **SSH server**:
1. Make sure to create and set your Public Key in the RunPod settings (follow the official instructions)
1. Add an exposed port 22 (TCP) in the pod settings!
1. When your pod restarts, you will see a new entry in the `Connect` dialog. Use this SSH server to `scp` or `sftp` your files as necessary, or SSH into the pod using the fully fledged SSH server.
- [**Magic Wormhole**](https://magic-wormhole.readthedocs.io/en/latest/welcome.html):
1. On your computer, `pip install magic-wormhole` (see above instructions for details)
1. Connect to the command line **using the "light" SSH client** or the browser-based console. _Currently there's a bug where `wormhole` isn't available when connected to "full" SSH server, as described above_.
1. `wormhole send /workspace/invokeai/outputs` will send the entire `outputs` directory. You can also send individual files.
1. Once packaged, you will see a `wormhole receive <123-some-words>` command. Copy it
1. Paste this command into the terminal on your local machine to securely download the payload.
1. It works the same in reverse: you can `wormhole send` some models from your computer to the pod. Again, save your files somewhere in `/workspace` or they will be lost when the pod is stopped.
- **RunPod's Cloud Sync feature** may be used to sync the persistent volume to cloud storage. You could, for example, copy the entire `/workspace` to S3, add some custom models to it, and copy it back from S3 when launching new pod configurations. Follow the Cloud Sync instructions.
### Disable the NSFW checker
The NSFW checker is enabled by default. To disable it, edit the pod configuration and set the following command:
```
invoke --web --host 0.0.0.0 --no-nsfw_checker
```
---
Template ©2023 Eugene Brodsky [ebr](https://github.com/ebr)

View File

@ -248,6 +248,7 @@ class InvokeAiInstance:
"install", "install",
"--require-virtualenv", "--require-virtualenv",
"torch~=2.0.0", "torch~=2.0.0",
"torchmetrics==0.11.4",
"torchvision>=0.14.1", "torchvision>=0.14.1",
"--force-reinstall", "--force-reinstall",
"--find-links" if find_links is not None else None, "--find-links" if find_links is not None else None,

View File

@ -13,7 +13,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -75,7 +74,6 @@ class ApiDependencies:
) )
urls = LocalUrlService() urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
@ -111,7 +109,6 @@ class ApiDependencies:
board_image_record_storage=board_image_record_storage, board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
image_file_storage=image_file_storage, image_file_storage=image_file_storage,
metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names, names=names,

View File

@ -1,20 +1,19 @@
import io import io
from typing import Optional from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter from fastapi import (Body, HTTPException, Path, Query, Request, Response,
UploadFile)
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from invokeai.app.models.image import (
ImageCategory, from invokeai.app.invocations.metadata import ImageMetadata
ResourceOrigin, from invokeai.app.models.image import ImageCategory, ResourceOrigin
)
from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
)
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.models.image_record import (ImageDTO,
ImageRecordChanges,
ImageUrlsDTO)
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -103,23 +102,38 @@ async def update_image(
@images_router.get( @images_router.get(
"/{image_name}/metadata", "/{image_name}",
operation_id="get_image_metadata", operation_id="get_image_dto",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def get_image_metadata( async def get_image_dto(
image_name: str = Path(description="The name of image to get"), image_name: str = Path(description="The name of image to get"),
) -> ImageDTO: ) -> ImageDTO:
"""Gets an image's metadata""" """Gets an image's DTO"""
try: try:
return ApiDependencies.invoker.services.images.get_dto(image_name) return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.get(
"/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageMetadata,
)
async def get_image_metadata(
image_name: str = Path(description="The name of image to get"),
) -> ImageMetadata:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_metadata(image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get( @images_router.get(
"/{image_name}", "/{image_name}/full",
operation_id="get_image_full", operation_id="get_image_full",
response_class=Response, response_class=Response,
responses={ responses={
@ -208,10 +222,10 @@ async def get_image_urls(
@images_router.get( @images_router.get(
"/", "/",
operation_id="list_images_with_metadata", operation_id="list_image_dtos",
response_model=OffsetPaginatedResults[ImageDTO], response_model=OffsetPaginatedResults[ImageDTO],
) )
async def list_images_with_metadata( async def list_image_dtos(
image_origin: Optional[ResourceOrigin] = Query( image_origin: Optional[ResourceOrigin] = Query(
default=None, description="The origin of images to list" default=None, description="The origin of images to list"
), ),
@ -227,7 +241,7 @@ async def list_images_with_metadata(
offset: int = Query(default=0, description="The page offset"), offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"), limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images""" """Gets a list of image DTOs"""
image_dtos = ApiDependencies.invoker.services.images.get_many( image_dtos = ApiDependencies.invoker.services.images.get_many(
offset, offset,

View File

@ -34,7 +34,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import (default_text_to_image_graph_id, from .services.default_graphs import (default_text_to_image_graph_id,
@ -244,7 +243,6 @@ def invoke_cli():
) )
urls = LocalUrlService() urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
@ -277,7 +275,6 @@ def invoke_cli():
board_image_record_storage=board_image_record_storage, board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
image_file_storage=image_file_storage, image_file_storage=image_file_storage,
metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names, names=names,

View File

@ -154,18 +154,20 @@ class InpaintInvocation(BaseInvocation):
@contextmanager @contextmanager
def load_model_old_way(self, context, scheduler): def load_model_old_way(self, context, scheduler):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
#unet = unet_info.context.model
#vae = vae_info.context.model
with ExitStack() as stack:
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
with vae_info as vae,\ with vae_info as vae,\
unet_info as unet,\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
ModelPatcher.apply_lora_unet(unet, loras): unet_info as unet:
device = context.services.model_manager.mgr.cache.execution_device device = context.services.model_manager.mgr.cache.execution_device
dtype = context.services.model_manager.mgr.cache.precision dtype = context.services.model_manager.mgr.cache.precision

View File

@ -9,9 +9,9 @@ from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
@ -21,6 +21,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
from .compel import ConditioningField from .compel import ConditioningField
@ -449,6 +450,8 @@ class LatentsToImageInvocation(BaseInvocation):
tiled: bool = Field( tiled: bool = Field(
default=False, default=False,
description="Decode latents by overlaping tiles(less memory consumption)") description="Decode latents by overlaping tiles(less memory consumption)")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -493,7 +496,8 @@ class LatentsToImageInvocation(BaseInvocation):
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
) )
return ImageOutput( return ImageOutput(

View File

@ -0,0 +1,124 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
BaseInvocationOutput,
InvocationContext)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
VAEModelField)
class LoRAMetadataField(BaseModel):
"""LoRA metadata for an image generated in InvokeAI."""
lora: LoRAModelField = Field(description="The LoRA model")
weight: float = Field(description="The weight of the LoRA model")
class CoreMetadata(BaseModel):
"""Core generation metadata for an image generated in InvokeAI."""
generation_mode: str = Field(description="The generation mode that output this image",)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
height: int = Field(description="The height parameter")
seed: int = Field(description="The seed used for noise generation")
rand_device: str = Field(description="The device used for random number generation")
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
class ImageMetadata(BaseModel):
"""An image's generation metadata"""
metadata: Optional[dict] = Field(
default=None,
description="The image's core metadata, if it was created in the Linear or Canvas UI",
)
graph: Optional[dict] = Field(
default=None, description="The graph that created the image"
)
class MetadataAccumulatorOutput(BaseInvocationOutput):
"""The output of the MetadataAccumulator node"""
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
metadata: CoreMetadata = Field(description="The core metadata for the image")
class MetadataAccumulatorInvocation(BaseInvocation):
"""Outputs a Core Metadata Object"""
type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = Field(description="The generation mode that output this image",)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
height: int = Field(description="The height parameter")
seed: int = Field(description="The seed used for noise generation")
rand_device: str = Field(description="The device used for random number generation")
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object"""
return MetadataAccumulatorOutput(
metadata=CoreMetadata(
generation_mode=self.generation_mode,
positive_prompt=self.positive_prompt,
negative_prompt=self.negative_prompt,
width=self.width,
height=self.height,
seed=self.seed,
rand_device=self.rand_device,
cfg_scale=self.cfg_scale,
steps=self.steps,
scheduler=self.scheduler,
model=self.model,
strength=self.strength,
init_image=self.init_image,
vae=self.vae,
controlnets=self.controlnets,
loras=self.loras,
clip_skip=self.clip_skip,
)
)

View File

@ -1,93 +0,0 @@
from typing import Optional, Union, List
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
class ImageMetadata(BaseModel):
"""
Core generation metadata for an image/tensor generated in InvokeAI.
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
class Config:
extra = Extra.allow
"""
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
won't add any fields that are not already defined, but other a different metadata service
implementation might.
"""
type: Optional[StrictStr] = Field(
default=None,
description="The type of the ancestor node of the image output node.",
)
"""The type of the ancestor node of the image output node."""
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
"""The positive conditioning"""
negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning."
)
"""The negative conditioning"""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/latents in pixels."
)
"""Width of the image/latents in pixels"""
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/latents in pixels."
)
"""Height of the image/latents in pixels"""
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
)
"""The seed used for noise generation"""
# cfg_scale: Optional[StrictFloat] = Field(
# cfg_scale: Union[float, list[float]] = Field(
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
default=None, description="The classifier-free guidance scale."
)
"""The classifier-free guidance scale"""
steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference."
)
"""The number of steps used for inference"""
scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference."
)
"""The scheduler used for inference"""
model: Optional[StrictStr] = Field(
default=None, description="The model used for inference."
)
"""The model used for inference"""
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/latents-to-latents.",
)
"""The strength used for image-to-image/latents-to-latents."""
latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial latents."
)
"""The ID of the initial latents"""
vae: Optional[StrictStr] = Field(
default=None, description="The VAE used for decoding."
)
"""The VAE used for decoding"""
unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference."
)
"""The UNet used dor inference"""
clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning."
)
"""The CLIP Encoder used for conditioning"""
extra: Optional[StrictStr] = Field(
default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
)
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""

View File

@ -1,14 +1,14 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -59,7 +59,8 @@ class ImageFileStorageBase(ABC):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[dict] = None,
graph: Optional[dict] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -110,20 +111,22 @@ class DiskImageFileStorage(ImageFileStorageBase):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[dict] = None,
graph: Optional[dict] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
self.__validate_storage_folders() self.__validate_storage_folders()
image_path = self.get_path(image_name) image_path = self.get_path(image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", metadata.json())
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path, "PNG")
if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if graph is not None:
pnginfo.add_text("invokeai_graph", json.dumps(graph))
image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image = make_thumbnail(image, thumbnail_size)

View File

@ -1,3 +1,4 @@
import json
import sqlite3 import sqlite3
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -8,7 +9,6 @@ from pydantic import BaseModel, Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecordChanges, deserialize_image_record) ImageRecord, ImageRecordChanges, deserialize_image_record)
@ -48,6 +48,28 @@ class ImageRecordDeleteException(Exception):
super().__init__(message) super().__init__(message)
IMAGE_DTO_COLS = ", ".join(
list(
map(
lambda c: "images." + c,
[
"image_name",
"image_origin",
"image_category",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
],
)
)
)
class ImageRecordStorageBase(ABC): class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store.""" """Low-level service responsible for interfacing with the image record store."""
@ -58,6 +80,11 @@ class ImageRecordStorageBase(ABC):
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod
def get_metadata(self, image_name: str) -> Optional[dict]:
"""Gets an image's metadata'."""
pass
@abstractmethod @abstractmethod
def update( def update(
self, self,
@ -102,7 +129,7 @@ class ImageRecordStorageBase(ABC):
height: int, height: int,
session_id: Optional[str], session_id: Optional[str],
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[dict],
is_intermediate: bool = False, is_intermediate: bool = False,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
@ -206,7 +233,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
SELECT * FROM images SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?; WHERE image_name = ?;
""", """,
(image_name,), (image_name,),
@ -224,6 +251,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result)) return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[dict]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT images.metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
if not result or not result[0]:
return None
return json.loads(result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
def update( def update(
self, self,
image_name: str, image_name: str,
@ -291,8 +340,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
WHERE 1=1 WHERE 1=1
""" """
images_query = """--sql images_query = f"""--sql
SELECT images.* SELECT {IMAGE_DTO_COLS}
FROM images FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1 WHERE 1=1
@ -410,12 +459,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
width: int, width: int,
height: int, height: int,
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[dict],
is_intermediate: bool = False, is_intermediate: bool = False,
) -> datetime: ) -> datetime:
try: try:
metadata_json = ( metadata_json = (
None if metadata is None else metadata.json(exclude_none=True) None if metadata is None else json.dumps(metadata)
) )
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -465,9 +514,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally: finally:
self._lock.release() self._lock.release()
def get_most_recent_image_for_board( def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
self, board_id: str
) -> Optional[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -1,39 +1,30 @@
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import Optional, TYPE_CHECKING, Union from typing import TYPE_CHECKING, Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ( from invokeai.app.invocations.metadata import ImageMetadata
ImageCategory, from invokeai.app.models.image import (ImageCategory,
ResourceOrigin,
InvalidImageCategoryException, InvalidImageCategoryException,
InvalidOriginException, InvalidOriginException, ResourceOrigin)
) from invokeai.app.services.board_image_record_storage import \
from invokeai.app.models.metadata import ImageMetadata BoardImageRecordStorageBase
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.graph import Graph
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
ImageRecordChanges,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import ( from invokeai.app.services.image_file_storage import (
ImageFileDeleteException, ImageFileDeleteException, ImageFileNotFoundException,
ImageFileNotFoundException, ImageFileSaveException, ImageFileStorageBase)
ImageFileSaveException, from invokeai.app.services.image_record_storage import (
ImageFileStorageBase, ImageRecordDeleteException, ImageRecordNotFoundException,
) ImageRecordSaveException, ImageRecordStorageBase, OffsetPaginatedResults)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.models.image_record import (ImageDTO, ImageRecord,
ImageRecordChanges,
image_record_to_dto)
from invokeai.app.services.resource_name import NameServiceBase from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState from invokeai.app.services.graph import GraphExecutionState
@ -51,6 +42,7 @@ class ImageServiceABC(ABC):
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
metadata: Optional[dict] = None,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@ -79,6 +71,11 @@ class ImageServiceABC(ABC):
"""Gets an image DTO.""" """Gets an image DTO."""
pass pass
@abstractmethod
def get_metadata(self, image_name: str) -> ImageMetadata:
"""Gets an image's metadata."""
pass
@abstractmethod @abstractmethod
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path.""" """Gets an image's path."""
@ -124,7 +121,6 @@ class ImageServiceDependencies:
image_records: ImageRecordStorageBase image_records: ImageRecordStorageBase
image_files: ImageFileStorageBase image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
logger: Logger logger: Logger
names: NameServiceBase names: NameServiceBase
@ -135,7 +131,6 @@ class ImageServiceDependencies:
image_record_storage: ImageRecordStorageBase, image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase, image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase, board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase, names: NameServiceBase,
@ -144,7 +139,6 @@ class ImageServiceDependencies:
self.image_records = image_record_storage self.image_records = image_record_storage
self.image_files = image_file_storage self.image_files = image_file_storage
self.board_image_records = board_image_record_storage self.board_image_records = board_image_record_storage
self.metadata = metadata
self.urls = url self.urls = url
self.logger = logger self.logger = logger
self.names = names self.names = names
@ -165,6 +159,7 @@ class ImageService(ImageServiceABC):
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
metadata: Optional[dict] = None,
) -> ImageDTO: ) -> ImageDTO:
if image_origin not in ResourceOrigin: if image_origin not in ResourceOrigin:
raise InvalidOriginException raise InvalidOriginException
@ -174,7 +169,16 @@ class ImageService(ImageServiceABC):
image_name = self._services.names.create_image_name() image_name = self._services.names.create_image_name()
metadata = self._get_metadata(session_id, node_id) graph = None
if session_id is not None:
session_raw = self._services.graph_execution_manager.get_raw(session_id)
if session_raw is not None:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
(width, height) = image.size (width, height) = image.size
@ -191,14 +195,12 @@ class ImageService(ImageServiceABC):
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
# Nullable fields # Nullable fields
node_id=node_id, node_id=node_id,
session_id=session_id,
metadata=metadata, metadata=metadata,
session_id=session_id,
) )
self._services.image_files.save( self._services.image_files.save(
image_name=image_name, image_name=image_name, image=image, metadata=metadata, graph=graph
image=image,
metadata=metadata,
) )
image_dto = self.get_dto(image_name) image_dto = self.get_dto(image_name)
@ -268,6 +270,34 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO") self._services.logger.error("Problem getting image DTO")
raise e raise e
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
try:
image_record = self._services.image_records.get(image_name)
if not image_record.session_id:
return ImageMetadata()
session_raw = self._services.graph_execution_manager.get_raw(
image_record.session_id
)
graph = None
if session_raw:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
metadata = self._services.image_records.get_metadata(image_name)
return ImageMetadata(graph=graph, metadata=metadata)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
return self._services.image_files.get_path(image_name, thumbnail) return self._services.image_files.get_path(image_name, thumbnail)
@ -367,15 +397,3 @@ class ImageService(ImageServiceABC):
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image records and files") self._services.logger.error("Problem deleting image records and files")
raise e raise e
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Optional[ImageMetadata]:
"""Get the metadata for a node."""
metadata = None
if node_id is not None and session_id is not None:
session = self._services.graph_execution_manager.get(session_id)
metadata = self._services.metadata.create_image_metadata(session, node_id)
return metadata

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar from typing import Callable, Generic, Optional, TypeVar
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
@ -29,14 +29,22 @@ class ItemStorageABC(ABC, Generic[T]):
@abstractmethod @abstractmethod
def get(self, item_id: str) -> T: def get(self, item_id: str) -> T:
"""Gets the item, parsing it into a Pydantic model"""
pass
@abstractmethod
def get_raw(self, item_id: str) -> Optional[str]:
"""Gets the raw item as a string, skipping Pydantic parsing"""
pass pass
@abstractmethod @abstractmethod
def set(self, item: T) -> None: def set(self, item: T) -> None:
"""Sets the item"""
pass pass
@abstractmethod @abstractmethod
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
"""Gets a paginated list of items"""
pass pass
@abstractmethod @abstractmethod

View File

@ -1,142 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
import networkx as nx
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.graph import Graph, GraphExecutionState
class MetadataServiceBase(ABC):
"""Handles building metadata for nodes, images, and outputs."""
@abstractmethod
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""Builds an ImageMetadata object for a node."""
pass
class CoreMetadataService(MetadataServiceBase):
_ANCESTOR_TYPES = ["t2l", "l2l"]
"""The ancestor types that contain the core metadata"""
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
"""The core metadata parameters in the ancestor types"""
_NOISE_FIELDS = ["seed", "width", "height"]
"""The core metadata parameters in the noise node"""
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
metadata = self._build_metadata_from_graph(session, node_id)
return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Optional[str]:
"""
Finds the id of the nearest ancestor (of a valid type) of a given node.
Parameters:
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
have the same data as the execution graph.
node_id (str): The ID of the node.
Returns:
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
"""
# Retrieve the node from the graph
node = G.nodes[node_id]
# If the node type is one of the core metadata node types, return its id
if node.get("type") in self._ANCESTOR_TYPES:
return node.get("id")
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = self._find_nearest_ancestor(G, predecessor)
if result:
return result
# If there are no valid ancestors, return None
return None
def _get_additional_metadata(
self, graph: Graph, node_id: str
) -> Optional[dict[str, Any]]:
"""
Returns additional metadata for a given node.
Parameters:
graph (Graph): The execution graph.
node_id (str): The ID of the node.
Returns:
dict[str, Any] | None: A dictionary of additional metadata.
"""
metadata = {}
# Iterate over all edges in the graph
for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node_dict = graph.nodes[edge.source.node_id].dict()
# If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id:
# Prompt
if dest_field == "positive_conditioning":
metadata["positive_conditioning"] = source_node_dict.get("prompt")
# Negative prompt
if dest_field == "negative_conditioning":
metadata["negative_conditioning"] = source_node_dict.get("prompt")
# Seed, width and height
if dest_field == "noise":
for field in self._NOISE_FIELDS:
metadata[field] = source_node_dict.get(field)
return metadata
def _build_metadata_from_graph(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""
Builds an ImageMetadata object for a node.
Parameters:
session (GraphExecutionState): The session.
node_id (str): The ID of the node.
Returns:
ImageMetadata: The metadata for the node.
"""
# We need to do all the traversal on the execution graph
graph = session.execution_graph
# Find the nearest `t2l`/`l2l` ancestor of the given node
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
# If no ancestor was found, return an empty ImageMetadata object
if ancestor_id is None:
return ImageMetadata()
ancestor_node = graph.get_node(ancestor_id)
# Grab all the core metadata from the ancestor node
ancestor_metadata = {
param: val
for param, val in ancestor_node.dict().items()
if param in self._ANCESTOR_PARAMS
}
# Get this image's prompts and noise parameters
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
# If additional metadata was found, add it to the main metadata
if addl_metadata is not None:
ancestor_metadata.update(addl_metadata)
return ImageMetadata(**ancestor_metadata)

View File

@ -1,13 +1,14 @@
import datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel): class ImageRecord(BaseModel):
"""Deserialized image record.""" """Deserialized image record without metadata."""
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
@ -43,11 +44,6 @@ class ImageRecord(BaseModel):
description="The node ID that generated this image, if it is a generated image.", description="The node ID that generated this image, if it is a generated image.",
) )
"""The node ID that generated this image, if it is a generated image.""" """The node ID that generated this image, if it is a generated image."""
metadata: Optional[ImageMetadata] = Field(
default=None,
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
)
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid): class ImageRecordChanges(BaseModel, extra=Extra.forbid):
@ -112,6 +108,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# Retrieve all the values, setting "reasonable" defaults if they are not present. # Retrieve all the values, setting "reasonable" defaults if they are not present.
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
image_name = image_dict.get("image_name", "unknown") image_name = image_dict.get("image_name", "unknown")
image_origin = ResourceOrigin( image_origin = ResourceOrigin(
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value) image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
@ -128,13 +125,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False) is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata")
if raw_metadata is not None:
metadata = ImageMetadata.parse_raw(raw_metadata)
else:
metadata = None
return ImageRecord( return ImageRecord(
image_name=image_name, image_name=image_name,
image_origin=image_origin, image_origin=image_origin,
@ -143,7 +133,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
height=height, height=height,
session_id=session_id, session_id=session_id,
node_id=node_id, node_id=node_id,
metadata=metadata,
created_at=created_at, created_at=created_at,
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,

View File

@ -1,6 +1,6 @@
import sqlite3 import sqlite3
from threading import Lock from threading import Lock
from typing import Generic, TypeVar, Optional, Union, get_args from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as from pydantic import BaseModel, parse_raw_as
@ -78,6 +78,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
return self._parse_item(result[0]) return self._parse_item(result[0])
def get_raw(self, id: str) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return result[0]
def delete(self, id: str): def delete(self, id: str):
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -22,4 +22,4 @@ class LocalUrlService(UrlServiceBase):
if thumbnail: if thumbnail:
return f"{self._base_url}/images/{image_basename}/thumbnail" return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_basename}" return f"{self._base_url}/images/{image_basename}/full"

View File

@ -0,0 +1,55 @@
import json
from typing import Optional
from pydantic import ValidationError
from invokeai.app.services.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
"""
Parses raw session string, returning a dict of the graph.
Only the general graph shape is validated; none of the fields are validated.
Any `metadata_accumulator` nodes and edges are removed.
Any validation failure will return None.
"""
graph = json.loads(session_raw).get("graph", None)
# sanity check make sure the graph is at least reasonably shaped
if (
type(graph) is not dict
or "nodes" not in graph
or type(graph["nodes"]) is not dict
or "edges" not in graph
or type(graph["edges"]) is not list
):
# something has gone terribly awry, return an empty dict
return None
try:
# delete the `metadata_accumulator` node
del graph["nodes"]["metadata_accumulator"]
except KeyError:
# no accumulator node, all good
pass
# delete any edges to or from it
for i, edge in enumerate(graph["edges"]):
try:
# try to parse the edge
Edge(**edge)
except ValidationError:
# something has gone terribly awry, return an empty dict
return None
if (
edge["source"]["node_id"] == "metadata_accumulator"
or edge["destination"]["node_id"] == "metadata_accumulator"
):
del graph["edges"][i]
return graph

View File

@ -121,8 +121,8 @@ class ModelInstall(object):
installed_models = self.mgr.list_models() installed_models = self.mgr.list_models()
for md in installed_models: for md in installed_models:
base = md['base_model'] base = md['base_model']
model_type = md['type'] model_type = md['model_type']
name = md['name'] name = md['model_name']
key = ModelManager.create_key(name, base, model_type) key = ModelManager.create_key(name, base, model_type)
if key in model_dict: if key in model_dict:
model_dict[key].installed = True model_dict[key].installed = True

View File

@ -250,7 +250,7 @@ from .model_cache import ModelCache, ModelLocker
from .models import ( from .models import (
BaseModelType, ModelType, SubModelType, BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES, ModelError, SchedulerPredictionType, MODEL_CLASSES,
ModelConfigBase, ModelNotFoundException, ModelConfigBase, ModelNotFoundException, InvalidModelException,
) )
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
@ -275,10 +275,6 @@ class ModelInfo():
def __exit__(self,*args, **kwargs): def __exit__(self,*args, **kwargs):
self.context.__exit__(*args, **kwargs) self.context.__exit__(*args, **kwargs)
class InvalidModelError(Exception):
"Raised when an invalid model is requested"
pass
class AddModelResult(BaseModel): class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after installation") name: str = Field(description="The name of the model after installation")
model_type: ModelType = Field(description="The type of model") model_type: ModelType = Field(description="The type of model")
@ -542,9 +538,9 @@ class ModelManager(object):
model_dict = dict( model_dict = dict(
**model_config.dict(exclude_defaults=True), **model_config.dict(exclude_defaults=True),
# OpenAPIModelInfoBase # OpenAPIModelInfoBase
name=cur_model_name, model_name=cur_model_name,
base_model=cur_base_model, base_model=cur_base_model,
type=cur_model_type, model_type=cur_model_type,
) )
models.append(model_dict) models.append(model_dict)
@ -817,6 +813,8 @@ class ModelManager(object):
model_config: ModelConfigBase = model_class.probe_config(str(model_path)) model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config self.models[model_key] = model_config
new_models_found = True new_models_found = True
except InvalidModelException:
self.logger.warning(f"Not a valid model: {model_path}")
except NotImplementedError as e: except NotImplementedError as e:
self.logger.warning(e) self.logger.warning(e)

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel
from typing import Literal, get_origin from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException, InvalidModelException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
from .lora import LoRAModel from .lora import LoRAModel
@ -37,9 +37,9 @@ MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list() OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel): class OpenAPIModelInfoBase(BaseModel):
name: str model_name: str
base_model: BaseModelType base_model: BaseModelType
type: ModelType model_type: ModelType
for base_model, models in MODEL_CLASSES.items(): for base_model, models in MODEL_CLASSES.items():
@ -56,7 +56,7 @@ for base_model, models in MODEL_CLASSES.items():
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict( api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict( __annotations__ = dict(
type=Literal[model_type.value], model_type=Literal[model_type.value],
), ),
)) ))

View File

@ -15,6 +15,9 @@ from contextlib import suppress
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class InvalidModelException(Exception):
pass
class ModelNotFoundException(Exception): class ModelNotFoundException(Exception):
pass pass

View File

@ -13,6 +13,7 @@ from .base import (
calc_model_size_by_fs, calc_model_size_by_fs,
calc_model_size_by_data, calc_model_size_by_data,
classproperty, classproperty,
InvalidModelException,
) )
class ControlNetModelFormat(str, Enum): class ControlNetModelFormat(str, Enum):
@ -73,11 +74,19 @@ class ControlNetModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path): if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return ControlNetModelFormat.Diffusers return ControlNetModelFormat.Diffusers
else:
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
return ControlNetModelFormat.Checkpoint return ControlNetModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod @classmethod
def convert_if_required( def convert_if_required(
cls, cls,

View File

@ -9,6 +9,7 @@ from .base import (
ModelType, ModelType,
SubModelType, SubModelType,
classproperty, classproperty,
InvalidModelException,
) )
# TODO: naming # TODO: naming
from ..lora import LoRAModel as LoRAModelRaw from ..lora import LoRAModel as LoRAModelRaw
@ -56,11 +57,19 @@ class LoRAModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path): if os.path.isdir(path):
if os.path.exists(os.path.join(path, "pytorch_lora_weights.bin")):
return LoRAModelFormat.Diffusers return LoRAModelFormat.Diffusers
else:
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
return LoRAModelFormat.LyCORIS return LoRAModelFormat.LyCORIS
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod @classmethod
def convert_if_required( def convert_if_required(
cls, cls,

View File

@ -16,6 +16,7 @@ from .base import (
SilenceWarnings, SilenceWarnings,
read_checkpoint_meta, read_checkpoint_meta,
classproperty, classproperty,
InvalidModelException,
) )
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -98,11 +99,19 @@ class StableDiffusion1Model(DiffusersModel):
@classmethod @classmethod
def detect_format(cls, model_path: str): def detect_format(cls, model_path: str):
if not os.path.exists(model_path):
raise ModelNotFoundException()
if os.path.isdir(model_path): if os.path.isdir(model_path):
if os.path.exists(os.path.join(model_path, "model_index.json")):
return StableDiffusion1ModelFormat.Diffusers return StableDiffusion1ModelFormat.Diffusers
else:
if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
return StableDiffusion1ModelFormat.Checkpoint return StableDiffusion1ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")
@classmethod @classmethod
def convert_if_required( def convert_if_required(
cls, cls,
@ -200,11 +209,19 @@ class StableDiffusion2Model(DiffusersModel):
@classmethod @classmethod
def detect_format(cls, model_path: str): def detect_format(cls, model_path: str):
if not os.path.exists(model_path):
raise ModelNotFoundException()
if os.path.isdir(model_path): if os.path.isdir(model_path):
if os.path.exists(os.path.join(model_path, "model_index.json")):
return StableDiffusion2ModelFormat.Diffusers return StableDiffusion2ModelFormat.Diffusers
else:
if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
return StableDiffusion2ModelFormat.Checkpoint return StableDiffusion2ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")
@classmethod @classmethod
def convert_if_required( def convert_if_required(
cls, cls,

View File

@ -9,6 +9,7 @@ from .base import (
SubModelType, SubModelType,
classproperty, classproperty,
ModelNotFoundException, ModelNotFoundException,
InvalidModelException,
) )
# TODO: naming # TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw from ..lora import TextualInversionModel as TextualInversionModelRaw
@ -59,8 +60,19 @@ class TextualInversionModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
return None # diffusers-ti
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
return None return None
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod @classmethod
def convert_if_required( def convert_if_required(
cls, cls,

View File

@ -15,6 +15,7 @@ from .base import (
calc_model_size_by_fs, calc_model_size_by_fs,
calc_model_size_by_data, calc_model_size_by_data,
classproperty, classproperty,
InvalidModelException,
) )
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available from diffusers.utils import is_safetensors_available
@ -75,11 +76,19 @@ class VaeModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path): if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return VaeModelFormat.Diffusers return VaeModelFormat.Diffusers
else:
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
return VaeModelFormat.Checkpoint return VaeModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod @classmethod
def convert_if_required( def convert_if_required(
cls, cls,

View File

@ -127,7 +127,7 @@ class AddsMaskGuidance:
def _t_for_field(self, field_name: str, t): def _t_for_field(self, field_name: str, t):
if field_name == "pred_original_sample": if field_name == "pred_original_sample":
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0 return self.scheduler.timesteps[-1]
return t return t
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor: def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:

View File

@ -108,6 +108,7 @@
"roarr": "^7.15.0", "roarr": "^7.15.0",
"serialize-error": "^11.0.0", "serialize-error": "^11.0.0",
"socket.io-client": "^4.7.0", "socket.io-client": "^4.7.0",
"use-debounce": "^9.0.4",
"use-image": "^1.1.1", "use-image": "^1.1.1",
"uuid": "^9.0.0", "uuid": "^9.0.0",
"zod": "^3.21.4" "zod": "^3.21.4"

View File

@ -102,7 +102,8 @@
"openInNewTab": "Open in New Tab", "openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again", "dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?", "areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt" "imagePrompt": "Image Prompt",
"clearNodes": "Are you sure you want to clear all nodes?"
}, },
"gallery": { "gallery": {
"generations": "Generations", "generations": "Generations",
@ -528,7 +529,7 @@
"hidePreview": "Hide Preview", "hidePreview": "Hide Preview",
"showPreview": "Show Preview", "showPreview": "Show Preview",
"controlNetControlMode": "Control Mode", "controlNetControlMode": "Control Mode",
"clipSkip": "Clip Skip", "clipSkip": "CLIP Skip",
"aspectRatio": "Ratio" "aspectRatio": "Ratio"
}, },
"settings": { "settings": {
@ -593,7 +594,12 @@
"metadataLoadFailed": "Failed to load metadata", "metadataLoadFailed": "Failed to load metadata",
"initialImageSet": "Initial Image Set", "initialImageSet": "Initial Image Set",
"initialImageNotSet": "Initial Image Not Set", "initialImageNotSet": "Initial Image Not Set",
"initialImageNotSetDesc": "Could not load initial image" "initialImageNotSetDesc": "Could not load initial image",
"nodesSaved": "Nodes Saved",
"nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {
@ -674,5 +680,11 @@
"showProgressImages": "Show Progress Images", "showProgressImages": "Show Progress Images",
"hideProgressImages": "Hide Progress Images", "hideProgressImages": "Hide Progress Images",
"swapSizes": "Swap Sizes" "swapSizes": "Swap Sizes"
},
"nodes": {
"reloadSchema": "Reload Schema",
"saveNodes": "Save Nodes",
"loadNodes": "Load Nodes",
"clearNodes": "Clear Nodes"
} }
} }

View File

@ -51,6 +51,7 @@ import {
} from './listeners/imageUrlsReceived'; } from './listeners/imageUrlsReceived';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addModelSelectedListener } from './listeners/modelSelected'; import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema'; import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import { import {
addReceivedPageOfImagesFulfilledListener, addReceivedPageOfImagesFulfilledListener,
@ -224,3 +225,4 @@ addModelSelectedListener();
// app startup // app startup
addAppStartedListener(); addAppStartedListener();
addModelsLoadedListener();

View File

@ -1,13 +1,13 @@
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/api/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { controlNetImageProcessed } from 'features/controlNet/store/actions'; import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import { Graph } from 'services/api/types';
import { sessionCreated } from 'services/api/thunks/session';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { socketInvocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/api/guards';
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { isImageOutput } from 'services/api/guards';
import { imageDTOReceived } from 'services/api/thunks/image';
import { sessionCreated } from 'services/api/thunks/session';
import { Graph } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'controlNet' }); const moduleLog = log.child({ namespace: 'controlNet' });
@ -63,10 +63,8 @@ export const addControlNetImageProcessedListener = () => {
// Wait for the ImageDTO to be received // Wait for the ImageDTO to be received
const [imageMetadataReceivedAction] = await take( const [imageMetadataReceivedAction] = await take(
( (action): action is ReturnType<typeof imageDTOReceived.fulfilled> =>
action imageDTOReceived.fulfilled.match(action) &&
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name action.payload.image_name === image_name
); );
const processedControlImage = imageMetadataReceivedAction.payload; const processedControlImage = imageMetadataReceivedAction.payload;

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/api/thunks/image';
import { boardImagesApi } from 'services/api/endpoints/boardImages'; import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { imageDTOReceived } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'boards' }); const moduleLog = log.child({ namespace: 'boards' });
@ -17,7 +17,7 @@ export const addImageAddedToBoardFulfilledListener = () => {
); );
dispatch( dispatch(
imageMetadataReceived({ imageDTOReceived({
image_name, image_name,
}) })
); );

View File

@ -1,13 +1,13 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image';
import { imageUpserted } from 'features/gallery/store/gallerySlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
import { imageDTOReceived, imageUpdated } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
export const addImageMetadataReceivedFulfilledListener = () => { export const addImageMetadataReceivedFulfilledListener = () => {
startAppListening({ startAppListening({
actionCreator: imageMetadataReceived.fulfilled, actionCreator: imageDTOReceived.fulfilled,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const image = action.payload; const image = action.payload;
@ -40,7 +40,7 @@ export const addImageMetadataReceivedFulfilledListener = () => {
export const addImageMetadataReceivedRejectedListener = () => { export const addImageMetadataReceivedRejectedListener = () => {
startAppListening({ startAppListening({
actionCreator: imageMetadataReceived.rejected, actionCreator: imageDTOReceived.rejected,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
moduleLog.debug( moduleLog.debug(
{ data: { image: action.meta.arg } }, { data: { image: action.meta.arg } },

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/api/thunks/image';
import { boardImagesApi } from 'services/api/endpoints/boardImages'; import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { imageDTOReceived } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'boards' }); const moduleLog = log.child({ namespace: 'boards' });
@ -17,7 +17,7 @@ export const addImageRemovedFromBoardFulfilledListener = () => {
); );
dispatch( dispatch(
imageMetadataReceived({ imageDTOReceived({
image_name, image_name,
}) })
); );

View File

@ -14,7 +14,7 @@ export const addModelSelectedListener = () => {
actionCreator: modelSelected, actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const state = getState(); const state = getState();
const [base_model, type, name] = action.payload.split('/'); const { base_model, model_name } = action.payload;
if (state.generation.model?.base_model !== base_model) { if (state.generation.model?.base_model !== base_model) {
dispatch( dispatch(
@ -30,11 +30,7 @@ export const addModelSelectedListener = () => {
// TODO: controlnet cleared // TODO: controlnet cleared
} }
const newModel = zMainModel.parse({ const newModel = zMainModel.parse(action.payload);
id: action.payload,
base_model,
name,
});
dispatch(modelChanged(newModel)); dispatch(modelChanged(newModel));
}, },

View File

@ -0,0 +1,42 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
startAppListening({
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const currentModel = getState().generation.model;
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
dispatch(
modelChanged({
base_model: firstModel.base_model,
model_name: firstModel.model_name,
})
);
},
});
};

View File

@ -30,6 +30,7 @@ export const addSessionCreatedRejectedListener = () => {
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
if (action.payload) { if (action.payload) {
const { arg, error } = action.payload; const { arg, error } = action.payload;
const stringifiedError = JSON.stringify(error);
moduleLog.error( moduleLog.error(
{ {
data: { data: {
@ -37,7 +38,7 @@ export const addSessionCreatedRejectedListener = () => {
error: serializeError(error), error: serializeError(error),
}, },
}, },
`Problem creating session` `Problem creating session: ${stringifiedError}`
); );
} }
}, },

View File

@ -33,6 +33,7 @@ export const addSessionInvokedRejectedListener = () => {
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
if (action.payload) { if (action.payload) {
const { arg, error } = action.payload; const { arg, error } = action.payload;
const stringifiedError = JSON.stringify(error);
moduleLog.error( moduleLog.error(
{ {
data: { data: {
@ -40,7 +41,7 @@ export const addSessionInvokedRejectedListener = () => {
error: serializeError(error), error: serializeError(error),
}, },
}, },
`Problem invoking session` `Problem invoking session: ${stringifiedError}`
); );
} }
}, },

View File

@ -1,15 +1,15 @@
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { progressImageSet } from 'features/system/store/systemSlice';
import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { isImageOutput } from 'services/api/guards';
import { imageDTOReceived } from 'services/api/thunks/image';
import { sessionCanceled } from 'services/api/thunks/session';
import { import {
appSocketInvocationComplete, appSocketInvocationComplete,
socketInvocationComplete, socketInvocationComplete,
} from 'services/events/actions'; } from 'services/events/actions';
import { imageMetadataReceived } from 'services/api/thunks/image'; import { startAppListening } from '../..';
import { sessionCanceled } from 'services/api/thunks/session';
import { isImageOutput } from 'services/api/guards';
import { progressImageSet } from 'features/system/store/systemSlice';
import { boardImagesApi } from 'services/api/endpoints/boardImages';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image']; const nodeDenylist = ['dataURL_image'];
@ -42,13 +42,13 @@ export const addInvocationCompleteEventListener = () => {
// Get its metadata // Get its metadata
dispatch( dispatch(
imageMetadataReceived({ imageDTOReceived({
image_name, image_name,
}) })
); );
const [{ payload: imageDTO }] = await take( const [{ payload: imageDTO }] = await take(
imageMetadataReceived.fulfilled.match imageDTOReceived.fulfilled.match
); );
// Handle canvas image // Handle canvas image

View File

@ -13,7 +13,7 @@ export const addInvocationErrorEventListener = () => {
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
moduleLog.error( moduleLog.error(
action.payload, action.payload,
`Invocation error (${action.payload.data.node.type})` `Invocation error (${action.payload.data.node.type}): ${action.payload.data.error}`
); );
dispatch(appSocketInvocationError(action.payload)); dispatch(appSocketInvocationError(action.payload));
}, },

View File

@ -1,6 +1,7 @@
import { import {
AnyAction, AnyAction,
ThunkDispatch, ThunkDispatch,
autoBatchEnhancer,
combineReducers, combineReducers,
configureStore, configureStore,
} from '@reduxjs/toolkit'; } from '@reduxjs/toolkit';
@ -79,14 +80,18 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
export const store = configureStore({ export const store = configureStore({
reducer: rememberedRootReducer, reducer: rememberedRootReducer,
enhancers: [ enhancers: (existingEnhancers) => {
return existingEnhancers
.concat(
rememberEnhancer(window.localStorage, rememberedKeys, { rememberEnhancer(window.localStorage, rememberedKeys, {
persistDebounce: 300, persistDebounce: 300,
serialize, serialize,
unserialize, unserialize,
prefix: LOCALSTORAGE_PREFIX, prefix: LOCALSTORAGE_PREFIX,
}), })
], )
.concat(autoBatchEnhancer());
},
middleware: (getDefaultMiddleware) => middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({ getDefaultMiddleware({
immutableCheck: false, immutableCheck: false,

View File

@ -102,6 +102,8 @@ export type AppFeature =
export type SDFeature = export type SDFeature =
| 'controlNet' | 'controlNet'
| 'noise' | 'noise'
| 'perlinNoise'
| 'noiseThreshold'
| 'variation' | 'variation'
| 'symmetry' | 'symmetry'
| 'seamless' | 'seamless'

View File

@ -12,6 +12,7 @@ import {
setIsMovingBoundingBox, setIsMovingBoundingBox,
setIsTransformingBoundingBox, setIsTransformingBoundingBox,
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import Konva from 'konva'; import Konva from 'konva';
import { GroupConfig } from 'konva/lib/Group'; import { GroupConfig } from 'konva/lib/Group';
import { KonvaEventObject } from 'konva/lib/Node'; import { KonvaEventObject } from 'konva/lib/Node';
@ -22,8 +23,8 @@ import { useCallback, useEffect, useRef, useState } from 'react';
import { Group, Rect, Transformer } from 'react-konva'; import { Group, Rect, Transformer } from 'react-konva';
const boundingBoxPreviewSelector = createSelector( const boundingBoxPreviewSelector = createSelector(
canvasSelector, [canvasSelector, uiSelector],
(canvas) => { (canvas, ui) => {
const { const {
boundingBoxCoordinates, boundingBoxCoordinates,
boundingBoxDimensions, boundingBoxDimensions,
@ -35,6 +36,8 @@ const boundingBoxPreviewSelector = createSelector(
shouldSnapToGrid, shouldSnapToGrid,
} = canvas; } = canvas;
const { aspectRatio } = ui;
return { return {
boundingBoxCoordinates, boundingBoxCoordinates,
boundingBoxDimensions, boundingBoxDimensions,
@ -45,6 +48,7 @@ const boundingBoxPreviewSelector = createSelector(
shouldSnapToGrid, shouldSnapToGrid,
tool, tool,
hitStrokeWidth: 20 / stageScale, hitStrokeWidth: 20 / stageScale,
aspectRatio,
}; };
}, },
{ {
@ -70,6 +74,7 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
shouldSnapToGrid, shouldSnapToGrid,
tool, tool,
hitStrokeWidth, hitStrokeWidth,
aspectRatio,
} = useAppSelector(boundingBoxPreviewSelector); } = useAppSelector(boundingBoxPreviewSelector);
const transformerRef = useRef<Konva.Transformer>(null); const transformerRef = useRef<Konva.Transformer>(null);
@ -137,12 +142,22 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
const x = Math.round(rect.x()); const x = Math.round(rect.x());
const y = Math.round(rect.y()); const y = Math.round(rect.y());
if (aspectRatio) {
const newHeight = roundToMultiple(width / aspectRatio, 64);
dispatch(
setBoundingBoxDimensions({
width: width,
height: newHeight,
})
);
} else {
dispatch( dispatch(
setBoundingBoxDimensions({ setBoundingBoxDimensions({
width, width,
height, height,
}) })
); );
}
dispatch( dispatch(
setBoundingBoxCoordinates({ setBoundingBoxCoordinates({
@ -154,7 +169,7 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
// Reset the scale now that the coords/dimensions have been un-scaled // Reset the scale now that the coords/dimensions have been un-scaled
rect.scaleX(1); rect.scaleX(1);
rect.scaleY(1); rect.scaleY(1);
}, [dispatch, shouldSnapToGrid]); }, [dispatch, shouldSnapToGrid, aspectRatio]);
const anchorDragBoundFunc = useCallback( const anchorDragBoundFunc = useCallback(
( (

View File

@ -7,7 +7,14 @@ import {
import { IRect, Vector2d } from 'konva/lib/types'; import { IRect, Vector2d } from 'konva/lib/types';
import { clamp, cloneDeep } from 'lodash-es'; import { clamp, cloneDeep } from 'lodash-es';
// //
import {
setActiveTab,
setAspectRatio,
setShouldUseCanvasBetaLayout,
} from 'features/ui/store/uiSlice';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { sessionCanceled } from 'services/api/thunks/session';
import { ImageDTO } from 'services/api/types';
import calculateCoordinates from '../util/calculateCoordinates'; import calculateCoordinates from '../util/calculateCoordinates';
import calculateScale from '../util/calculateScale'; import calculateScale from '../util/calculateScale';
import { STAGE_PADDING_PERCENTAGE } from '../util/constants'; import { STAGE_PADDING_PERCENTAGE } from '../util/constants';
@ -28,13 +35,6 @@ import {
isCanvasBaseImage, isCanvasBaseImage,
isCanvasMaskLine, isCanvasMaskLine,
} from './canvasTypes'; } from './canvasTypes';
import { ImageDTO } from 'services/api/types';
import { sessionCanceled } from 'services/api/thunks/session';
import {
setActiveTab,
setShouldUseCanvasBetaLayout,
} from 'features/ui/store/uiSlice';
import { imageUrlsReceived } from 'services/api/thunks/image';
export const initialLayerState: CanvasLayerState = { export const initialLayerState: CanvasLayerState = {
objects: [], objects: [],
@ -240,6 +240,16 @@ export const canvasSlice = createSlice({
state.scaledBoundingBoxDimensions = scaledDimensions; state.scaledBoundingBoxDimensions = scaledDimensions;
} }
}, },
flipBoundingBoxAxes: (state) => {
const [currWidth, currHeight] = [
state.boundingBoxDimensions.width,
state.boundingBoxDimensions.height,
];
state.boundingBoxDimensions = {
width: currHeight,
height: currWidth,
};
},
setBoundingBoxCoordinates: (state, action: PayloadAction<Vector2d>) => { setBoundingBoxCoordinates: (state, action: PayloadAction<Vector2d>) => {
state.boundingBoxCoordinates = floorCoordinates(action.payload); state.boundingBoxCoordinates = floorCoordinates(action.payload);
}, },
@ -864,6 +874,15 @@ export const canvasSlice = createSlice({
builder.addCase(setActiveTab, (state, action) => { builder.addCase(setActiveTab, (state, action) => {
state.doesCanvasNeedScaling = true; state.doesCanvasNeedScaling = true;
}); });
builder.addCase(setAspectRatio, (state, action) => {
const ratio = action.payload;
if (ratio) {
state.boundingBoxDimensions.height = roundToMultiple(
state.boundingBoxDimensions.width / ratio,
64
);
}
});
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload; // const { image_name, image_url, thumbnail_url } = action.payload;
@ -912,6 +931,7 @@ export const {
setBoundingBoxDimensions, setBoundingBoxDimensions,
setBoundingBoxPreviewFill, setBoundingBoxPreviewFill,
setBoundingBoxScaleMethod, setBoundingBoxScaleMethod,
flipBoundingBoxAxes,
setBrushColor, setBrushColor,
setBrushSize, setBrushSize,
setCanvasContainerDimensions, setCanvasContainerDimensions,

View File

@ -47,8 +47,8 @@ const ParamEmbeddingPopover = (props: Props) => {
const disabled = currentMainModel?.base_model !== embedding.base_model; const disabled = currentMainModel?.base_model !== embedding.base_model;
data.push({ data.push({
value: embedding.name, value: embedding.model_name,
label: embedding.name, label: embedding.model_name,
group: MODEL_TYPE_MAP[embedding.base_model], group: MODEL_TYPE_MAP[embedding.base_model],
disabled, disabled,
tooltip: disabled tooltip: disabled

View File

@ -45,7 +45,11 @@ import {
FaShare, FaShare,
FaShareAlt, FaShareAlt,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import {
useGetImageDTOQuery,
useGetImageMetadataQuery,
} from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions'; import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
const currentImageButtonsSelector = createSelector( const currentImageButtonsSelector = createSelector(
@ -128,10 +132,23 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const { recallBothPrompts, recallSeed, recallAllParameters } = const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters(); useRecallParameters();
const { currentData: image } = useGetImageDTOQuery( const [debouncedMetadataQueryArg, debounceState] = useDebounce(
lastSelectedImage,
500
);
const { currentData: image, isFetching } = useGetImageDTOQuery(
lastSelectedImage ?? skipToken lastSelectedImage ?? skipToken
); );
const { currentData: metadataData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
);
const metadata = metadataData?.metadata;
// const handleCopyImage = useCallback(async () => { // const handleCopyImage = useCallback(async () => {
// if (!image?.url) { // if (!image?.url) {
// return; // return;
@ -193,29 +210,26 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}, [toaster, t, image]); }, [toaster, t, image]);
const handleClickUseAllParameters = useCallback(() => { const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(image); recallAllParameters(metadata);
}, [image, recallAllParameters]); }, [metadata, recallAllParameters]);
useHotkeys( useHotkeys(
'a', 'a',
() => { () => {
handleClickUseAllParameters; handleClickUseAllParameters;
}, },
[image, recallAllParameters] [metadata, recallAllParameters]
); );
const handleUseSeed = useCallback(() => { const handleUseSeed = useCallback(() => {
recallSeed(image?.metadata?.seed); recallSeed(metadata?.seed);
}, [image, recallSeed]); }, [metadata?.seed, recallSeed]);
useHotkeys('s', handleUseSeed, [image]); useHotkeys('s', handleUseSeed, [image]);
const handleUsePrompt = useCallback(() => { const handleUsePrompt = useCallback(() => {
recallBothPrompts( recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
image?.metadata?.positive_conditioning, }, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
image?.metadata?.negative_conditioning
);
}, [image, recallBothPrompts]);
useHotkeys('p', handleUsePrompt, [image]); useHotkeys('p', handleUsePrompt, [image]);
@ -440,7 +454,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaQuoteRight />} icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`} tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`} aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!image?.metadata?.positive_conditioning} isDisabled={!metadata?.positive_prompt}
onClick={handleUsePrompt} onClick={handleUsePrompt}
/> />
@ -448,7 +462,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaSeedling />} icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`} tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`} aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!image?.metadata?.seed} isDisabled={!metadata?.seed}
onClick={handleUseSeed} onClick={handleUseSeed}
/> />
@ -456,10 +470,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaAsterisk />} icon={<FaAsterisk />}
tooltip={`${t('parameters.useAll')} (A)`} tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`} aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={ isDisabled={!metadata}
// not sure what this list should be
!['t2l', 'l2l', 'inpaint'].includes(String(image?.metadata?.type))
}
onClick={handleClickUseAllParameters} onClick={handleClickUseAllParameters}
/> />
</ButtonGroup> </ButtonGroup>

View File

@ -11,7 +11,9 @@ import IAIDndImage from 'common/components/IAIDndImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useNextPrevImage } from '../hooks/useNextPrevImage';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons'; import NextPrevImageButtons from './NextPrevImageButtons';
@ -49,6 +51,45 @@ const CurrentImagePreview = () => {
shouldAntialiasProgressImage, shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector); } = useAppSelector(imagesSelector);
const {
handlePrevImage,
handleNextImage,
prevImageId,
nextImageId,
isOnLastImage,
handleLoadMoreImages,
areMoreImagesAvailable,
isFetching,
} = useNextPrevImage();
useHotkeys(
'left',
() => {
handlePrevImage();
},
[prevImageId]
);
useHotkeys(
'right',
() => {
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
if (!isOnLastImage) {
handleNextImage();
}
},
[
nextImageId,
isOnLastImage,
areMoreImagesAvailable,
handleLoadMoreImages,
isFetching,
]
);
const { const {
currentData: imageDTO, currentData: imageDTO,
isLoading, isLoading,
@ -118,7 +159,6 @@ const CurrentImagePreview = () => {
width: 'full', width: 'full',
height: 'full', height: 'full',
borderRadius: 'base', borderRadius: 'base',
overflow: 'scroll',
}} }}
> >
<ImageMetadataViewer image={imageDTO} /> <ImageMetadataViewer image={imageDTO} />

View File

@ -6,10 +6,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import { import { imagesAddedToBatch } from 'features/batch/store/batchSlice';
imagesAddedToBatch,
selectionAddedToBatch,
} from 'features/batch/store/batchSlice';
import { import {
resizeAndScaleCanvas, resizeAndScaleCanvas,
setInitialCanvasImage, setInitialCanvasImage,
@ -24,6 +21,7 @@ import { useTranslation } from 'react-i18next';
import { FaExpand, FaFolder, FaShare, FaTrash } from 'react-icons/fa'; import { FaExpand, FaFolder, FaShare, FaTrash } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5'; import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages'; import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext'; import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions'; import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
@ -38,24 +36,17 @@ const ImageContextMenu = ({ image, children }: Props) => {
() => () =>
createSelector( createSelector(
[stateSelector], [stateSelector],
({ gallery, batch }) => { ({ gallery }) => {
const selectionCount = gallery.selection.length; const selectionCount = gallery.selection.length;
const isInBatch = batch.imageNames.includes(image.image_name);
return { selectionCount, isInBatch }; return { selectionCount };
}, },
defaultSelectorOptions defaultSelectorOptions
), ),
[image.image_name] []
); );
const { selectionCount, isInBatch } = useAppSelector(selector); const { selectionCount } = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onClickAddToBoard } = useContext(AddImageToBoardContext); const { onClickAddToBoard } = useContext(AddImageToBoardContext);
@ -66,178 +57,17 @@ const ImageContextMenu = ({ image, children }: Props) => {
dispatch(imageToDeleteSelected(image)); dispatch(imageToDeleteSelected(image));
}, [dispatch, image]); }, [dispatch, image]);
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallBothPrompts(
image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning
);
}, [
image.metadata?.negative_conditioning,
image.metadata?.positive_conditioning,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata?.seed);
}, [image, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image));
}, [dispatch, image]);
// const handleRecallInitialImage = useCallback(() => {
// recallInitialImage(image.metadata.invokeai?.node?.image);
// }, [image, recallInitialImage]);
const handleSendToCanvas = () => {
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
dispatch(setActiveTab('unifiedCanvas'));
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseAllParameters = useCallback(() => {
recallAllParameters(image);
}, [image, recallAllParameters]);
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
const handleAddToBoard = useCallback(() => { const handleAddToBoard = useCallback(() => {
onClickAddToBoard(image); onClickAddToBoard(image);
}, [image, onClickAddToBoard]); }, [image, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!image.board_id) {
return;
}
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
}, [image.board_id, image.image_name, removeFromBoard]);
const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank');
};
const handleAddSelectionToBatch = useCallback(() => {
dispatch(selectionAddedToBatch());
}, [dispatch]);
const handleAddToBatch = useCallback(() => {
dispatch(imagesAddedToBatch([image.image_name]));
}, [dispatch, image.image_name]);
return ( return (
<ContextMenu<HTMLDivElement> <ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }} menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => ( renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}> <MenuList sx={{ visibility: 'visible !important' }}>
{selectionCount === 1 ? ( {selectionCount === 1 ? (
<> <SingleSelectionMenuItems image={image} />
<MenuItem
icon={<ExternalLinkIcon />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
{isLightboxEnabled && (
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
{t('parameters.openInViewer')}
</MenuItem>
)}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallPrompt}
isDisabled={
image?.metadata?.positive_conditioning === undefined
}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallSeed}
isDisabled={image?.metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
{/* <MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallInitialImage}
isDisabled={image?.metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem> */}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={
// what should these be
!['t2l', 'l2l', 'inpaint'].includes(
String(image?.metadata?.type)
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToImageToImage}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</MenuItem>
{isCanvasEnabled && (
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToCanvas}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
{/* <MenuItem
icon={<FaFolder />}
isDisabled={isInBatch}
onClickCapture={handleAddToBatch}
>
Add to Batch
</MenuItem> */}
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{image.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{image.board_id && (
<MenuItem
icon={<FaFolder />}
onClickCapture={handleRemoveFromBoard}
>
Remove from Board
</MenuItem>
)}
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
{t('gallery.deleteImage')}
</MenuItem>
</>
) : ( ) : (
<> <>
<MenuItem <MenuItem
@ -271,3 +101,185 @@ const ImageContextMenu = ({ image, children }: Props) => {
}; };
export default memo(ImageContextMenu); export default memo(ImageContextMenu);
type SingleSelectionMenuItemsProps = {
image: ImageDTO;
};
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { image } = props;
const selector = useMemo(
() =>
createSelector(
[stateSelector],
({ batch }) => {
const isInBatch = batch.imageNames.includes(image.image_name);
return { isInBatch };
},
defaultSelectorOptions
),
[image.image_name]
);
const { isInBatch } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const { currentData } = useGetImageMetadataQuery(image.image_name);
const metadata = currentData?.metadata;
const handleDelete = useCallback(() => {
if (!image) {
return;
}
dispatch(imageToDeleteSelected(image));
}, [dispatch, image]);
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
const handleRecallSeed = useCallback(() => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image));
}, [dispatch, image]);
const handleSendToCanvas = () => {
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
dispatch(setActiveTab('unifiedCanvas'));
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseAllParameters = useCallback(() => {
console.log(metadata);
recallAllParameters(metadata);
}, [metadata, recallAllParameters]);
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
const handleAddToBoard = useCallback(() => {
onClickAddToBoard(image);
}, [image, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!image.board_id) {
return;
}
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
}, [image.board_id, image.image_name, removeFromBoard]);
const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank');
};
const handleAddToBatch = useCallback(() => {
dispatch(imagesAddedToBatch([image.image_name]));
}, [dispatch, image.image_name]);
return (
<>
<MenuItem icon={<ExternalLinkIcon />} onClickCapture={handleOpenInNewTab}>
{t('common.openInNewTab')}
</MenuItem>
{isLightboxEnabled && (
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
{t('parameters.openInViewer')}
</MenuItem>
)}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallPrompt}
isDisabled={
metadata?.positive_prompt === undefined &&
metadata?.negative_prompt === undefined
}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={!metadata}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToImageToImage}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</MenuItem>
{isCanvasEnabled && (
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToCanvas}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
<MenuItem
icon={<FaFolder />}
isDisabled={isInBatch}
onClickCapture={handleAddToBatch}
>
Add to Batch
</MenuItem>
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{image.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{image.board_id && (
<MenuItem icon={<FaFolder />} onClickCapture={handleRemoveFromBoard}>
Remove from Board
</MenuItem>
)}
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
{t('gallery.deleteImage')}
</MenuItem>
</>
);
};

View File

@ -0,0 +1,212 @@
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { useCallback } from 'react';
import { UnsafeImageMetadata } from 'services/api/endpoints/images';
import MetadataItem from './MetadataItem';
type Props = {
metadata?: UnsafeImageMetadata['metadata'];
};
const ImageMetadataActions = (props: Props) => {
const { metadata } = props;
const {
recallBothPrompts,
recallPositivePrompt,
recallNegativePrompt,
recallSeed,
recallInitialImage,
recallCfgScale,
recallModel,
recallScheduler,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallAllParameters,
} = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => {
recallPositivePrompt(metadata?.positive_prompt);
}, [metadata?.positive_prompt, recallPositivePrompt]);
const handleRecallNegativePrompt = useCallback(() => {
recallNegativePrompt(metadata?.negative_prompt);
}, [metadata?.negative_prompt, recallNegativePrompt]);
const handleRecallSeed = useCallback(() => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleRecallModel = useCallback(() => {
recallModel(metadata?.model);
}, [metadata?.model, recallModel]);
const handleRecallWidth = useCallback(() => {
recallWidth(metadata?.width);
}, [metadata?.width, recallWidth]);
const handleRecallHeight = useCallback(() => {
recallHeight(metadata?.height);
}, [metadata?.height, recallHeight]);
const handleRecallScheduler = useCallback(() => {
recallScheduler(metadata?.scheduler);
}, [metadata?.scheduler, recallScheduler]);
const handleRecallSteps = useCallback(() => {
recallSteps(metadata?.steps);
}, [metadata?.steps, recallSteps]);
const handleRecallCfgScale = useCallback(() => {
recallCfgScale(metadata?.cfg_scale);
}, [metadata?.cfg_scale, recallCfgScale]);
const handleRecallStrength = useCallback(() => {
recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]);
if (!metadata || Object.keys(metadata).length === 0) {
return null;
}
return (
<>
{metadata.generation_mode && (
<MetadataItem
label="Generation Mode"
value={metadata.generation_mode}
/>
)}
{metadata.positive_prompt && (
<MetadataItem
label="Positive Prompt"
labelPosition="top"
value={metadata.positive_prompt}
onClick={handleRecallPositivePrompt}
/>
)}
{metadata.negative_prompt && (
<MetadataItem
label="Negative Prompt"
labelPosition="top"
value={metadata.negative_prompt}
onClick={handleRecallNegativePrompt}
/>
)}
{metadata.seed !== undefined && (
<MetadataItem
label="Seed"
value={metadata.seed}
onClick={handleRecallSeed}
/>
)}
{metadata.model !== undefined && (
<MetadataItem
label="Model"
value={metadata.model.model_name}
onClick={handleRecallModel}
/>
)}
{metadata.width && (
<MetadataItem
label="Width"
value={metadata.width}
onClick={handleRecallWidth}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={handleRecallHeight}
/>
)}
{/* {metadata.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={metadata.threshold}
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
/>
)}
{metadata.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={metadata.perlin}
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
/>
)} */}
{metadata.scheduler && (
<MetadataItem
label="Scheduler"
value={metadata.scheduler}
onClick={handleRecallScheduler}
/>
)}
{metadata.steps && (
<MetadataItem
label="Steps"
value={metadata.steps}
onClick={handleRecallSteps}
/>
)}
{metadata.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={metadata.cfg_scale}
onClick={handleRecallCfgScale}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(metadata.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(metadata.variations))
)
}
/>
)}
{metadata.seamless && (
<MetadataItem
label="Seamless"
value={metadata.seamless}
onClick={() => dispatch(setSeamless(metadata.seamless))}
/>
)}
{metadata.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={metadata.hires_fix}
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
/>
)} */}
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{metadata.strength && (
<MetadataItem
label="Image to image strength"
value={metadata.strength}
onClick={handleRecallStrength}
/>
)}
{/* {metadata.fit && (
<MetadataItem
label="Image to image fit"
value={metadata.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
/>
)} */}
</>
);
};
export default ImageMetadataActions;

View File

@ -1,131 +1,74 @@
import { ExternalLinkIcon } from '@chakra-ui/icons'; import { ExternalLinkIcon } from '@chakra-ui/icons';
import { import {
Box,
Center,
Flex, Flex,
IconButton,
Link, Link,
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
Text, Text,
Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { memo, useMemo } from 'react';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice'; import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
type MetadataItemProps = { import ImageMetadataActions from './ImageMetadataActions';
isLink?: boolean; import MetadataJSONViewer from './MetadataJSONViewer';
label: string;
onClick?: () => void;
value: number | string | boolean;
labelPosition?: string;
withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({
label,
value,
onClick,
isLink,
labelPosition,
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
if (!value) {
return null;
}
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label={t('accessibility.useThisParameter')}
icon={<IoArrowUndoCircleOutline />}
size="xs"
variant="ghost"
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}
<Flex direction={labelPosition ? 'column' : 'row'}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak="break-all">
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text overflowY="scroll" wordBreak="break-all">
{value.toString()}
</Text>
)}
</Flex>
</Flex>
);
};
type ImageMetadataViewerProps = { type ImageMetadataViewerProps = {
image: ImageDTO; image: ImageDTO;
}; };
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => { const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch(); // TODO: fix hotkeys
const { // const dispatch = useAppDispatch();
recallBothPrompts, // useHotkeys('esc', () => {
recallPositivePrompt, // dispatch(setShouldShowImageDetails(false));
recallNegativePrompt, // });
recallSeed,
recallInitialImage,
recallCfgScale,
recallModel,
recallScheduler,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallAllParameters,
} = useRecallParameters();
useHotkeys('esc', () => { const [debouncedMetadataQueryArg, debounceState] = useDebounce(
dispatch(setShouldShowImageDetails(false)); image.image_name,
500
);
const { currentData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
);
const metadata = currentData?.metadata;
const graph = currentData?.graph;
const tabData = useMemo(() => {
const _tabData: { label: string; data: object; copyTooltip: string }[] = [];
if (metadata) {
_tabData.push({
label: 'Core Metadata',
data: metadata,
copyTooltip: 'Copy Core Metadata JSON',
}); });
}
const sessionId = image?.session_id; if (image) {
_tabData.push({
label: 'Image Details',
data: image,
copyTooltip: 'Copy Image Details JSON',
});
}
const metadata = image?.metadata; if (graph) {
_tabData.push({
const { t } = useTranslation(); label: 'Graph',
data: graph,
const metadataJSON = JSON.stringify(image, null, 2); copyTooltip: 'Copy Graph JSON',
});
}
return _tabData;
}, [metadata, graph, image]);
return ( return (
<Flex <Flex
@ -136,11 +79,13 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
width: 'full', width: 'full',
height: 'full', height: 'full',
backdropFilter: 'blur(20px)', backdropFilter: 'blur(20px)',
bg: 'whiteAlpha.600', bg: 'baseAlpha.200',
_dark: { _dark: {
bg: 'blackAlpha.600', bg: 'blackAlpha.600',
}, },
overflow: 'scroll', borderRadius: 'base',
position: 'absolute',
overflow: 'hidden',
}} }}
> >
<Flex gap={2}> <Flex gap={2}>
@ -150,179 +95,42 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<ExternalLinkIcon mx="2px" /> <ExternalLinkIcon mx="2px" />
</Link> </Link>
</Flex> </Flex>
{metadata && Object.keys(metadata).length > 0 ? (
<>
{metadata.type && (
<MetadataItem label="Invocation type" value={metadata.type} />
)}
{sessionId && <MetadataItem label="Session ID" value={sessionId} />}
{metadata.positive_conditioning && (
<MetadataItem
label="Positive Prompt"
labelPosition="top"
value={metadata.positive_conditioning}
onClick={() =>
recallPositivePrompt(metadata.positive_conditioning)
}
/>
)}
{metadata.negative_conditioning && (
<MetadataItem
label="Negative Prompt"
labelPosition="top"
value={metadata.negative_conditioning}
onClick={() =>
recallNegativePrompt(metadata.negative_conditioning)
}
/>
)}
{metadata.seed !== undefined && (
<MetadataItem
label="Seed"
value={metadata.seed}
onClick={() => recallSeed(metadata.seed)}
/>
)}
{metadata.model !== undefined && (
<MetadataItem
label="Model"
value={metadata.model}
onClick={() => recallModel(metadata.model)}
/>
)}
{metadata.width && (
<MetadataItem
label="Width"
value={metadata.width}
onClick={() => recallWidth(metadata.width)}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={() => recallHeight(metadata.height)}
/>
)}
{/* {metadata.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={metadata.threshold}
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
/>
)}
{metadata.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={metadata.perlin}
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
/>
)} */}
{metadata.scheduler && (
<MetadataItem
label="Scheduler"
value={metadata.scheduler}
onClick={() => recallScheduler(metadata.scheduler)}
/>
)}
{metadata.steps && (
<MetadataItem
label="Steps"
value={metadata.steps}
onClick={() => recallSteps(metadata.steps)}
/>
)}
{metadata.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={metadata.cfg_scale}
onClick={() => recallCfgScale(metadata.cfg_scale)}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(metadata.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(metadata.variations))
)
}
/>
)}
{metadata.seamless && (
<MetadataItem
label="Seamless"
value={metadata.seamless}
onClick={() => dispatch(setSeamless(metadata.seamless))}
/>
)}
{metadata.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={metadata.hires_fix}
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
/>
)} */}
{/* {init_image_path && ( <ImageMetadataActions metadata={metadata} />
<MetadataItem
label="Initial image" <Tabs
value={init_image_path} variant="line"
isLink sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
onClick={() => dispatch(setInitialImage(init_image_path))} >
/> <TabList>
)} */} {tabData.map((tab) => (
{metadata.strength && ( <Tab
<MetadataItem key={tab.label}
label="Image to image strength"
value={metadata.strength}
onClick={() => recallStrength(metadata.strength)}
/>
)}
{/* {metadata.fit && (
<MetadataItem
label="Image to image fit"
value={metadata.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
/>
)} */}
</>
) : (
<Center width="100%" pt={10}>
<Text fontSize="lg" fontWeight="semibold">
No metadata available
</Text>
</Center>
)}
<Flex gap={2} direction="column" overflow="auto">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<OverlayScrollbarsComponent defer>
<Box
sx={{ sx={{
padding: 4, borderTopRadius: 'base',
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
w: 'full',
}} }}
> >
<pre>{metadataJSON}</pre> <Text sx={{ color: 'base.700', _dark: { color: 'base.300' } }}>
</Box> {tab.label}
</OverlayScrollbarsComponent> </Text>
</Flex> </Tab>
))}
</TabList>
<TabPanels sx={{ w: 'full', h: 'full' }}>
{tabData.map((tab) => (
<TabPanel
key={tab.label}
sx={{ w: 'full', h: 'full', p: 0, pt: 4 }}
>
<MetadataJSONViewer
jsonObject={tab.data}
copyTooltip={tab.copyTooltip}
/>
</TabPanel>
))}
</TabPanels>
</Tabs>
</Flex> </Flex>
); );
}; };

View File

@ -0,0 +1,77 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { Flex, IconButton, Link, Text, Tooltip } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
type MetadataItemProps = {
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
labelPosition?: string;
withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({
label,
value,
onClick,
isLink,
labelPosition,
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
if (!value) {
return null;
}
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label={t('accessibility.useThisParameter')}
icon={<IoArrowUndoCircleOutline />}
size="xs"
variant="ghost"
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}
<Flex direction={labelPosition ? 'column' : 'row'}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak="break-all">
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text overflowY="scroll" wordBreak="break-all">
{value.toString()}
</Text>
)}
</Flex>
</Flex>
);
};
export default MetadataItem;

View File

@ -0,0 +1,70 @@
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { useMemo } from 'react';
import { FaCopy } from 'react-icons/fa';
type Props = {
copyTooltip: string;
jsonObject: object;
};
const MetadataJSONViewer = (props: Props) => {
const { copyTooltip, jsonObject } = props;
const jsonString = useMemo(
() => JSON.stringify(jsonObject, null, 2),
[jsonObject]
);
return (
<Flex
sx={{
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
flexGrow: 1,
w: 'full',
h: 'full',
position: 'relative',
}}
>
<Box
sx={{
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
overflow: 'auto',
p: 4,
}}
>
<OverlayScrollbarsComponent
defer
style={{ height: '100%', width: '100%' }}
options={{
scrollbars: {
visibility: 'auto',
autoHide: 'move',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
}}
>
<pre>{jsonString}</pre>
</OverlayScrollbarsComponent>
</Box>
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
<Tooltip label={copyTooltip}>
<IconButton
aria-label={copyTooltip}
icon={<FaCopy />}
variant="ghost"
onClick={() => navigator.clipboard.writeText(jsonString)}
/>
</Tooltip>
</Flex>
</Flex>
);
};
export default MetadataJSONViewer;

View File

@ -1,18 +1,8 @@
import { ChakraProps, Flex, Grid, IconButton, Spinner } from '@chakra-ui/react'; import { ChakraProps, Flex, Grid, IconButton, Spinner } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
imageSelected,
selectFilteredImages,
selectImagesById,
} from 'features/gallery/store/gallerySlice';
import { clamp, isEqual } from 'lodash-es';
import { memo, useCallback, useState } from 'react'; import { memo, useCallback, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaAngleDoubleRight, FaAngleLeft, FaAngleRight } from 'react-icons/fa'; import { FaAngleDoubleRight, FaAngleLeft, FaAngleRight } from 'react-icons/fa';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { useNextPrevImage } from '../hooks/useNextPrevImage';
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = { const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
height: '100%', height: '100%',
@ -24,74 +14,18 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
color: 'base.100', color: 'base.100',
}; };
export const nextPrevImageButtonsSelector = createSelector(
[stateSelector, selectFilteredImages],
(state, filteredImages) => {
const { total, isFetching } = state.gallery;
const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1];
if (!lastSelectedImage || filteredImages.length === 0) {
return {
isOnFirstImage: true,
isOnLastImage: true,
};
}
const currentImageIndex = filteredImages.findIndex(
(i) => i.image_name === lastSelectedImage
);
const nextImageIndex = clamp(
currentImageIndex + 1,
0,
filteredImages.length - 1
);
const prevImageIndex = clamp(
currentImageIndex - 1,
0,
filteredImages.length - 1
);
const nextImageId = filteredImages[nextImageIndex].image_name;
const prevImageId = filteredImages[prevImageIndex].image_name;
const nextImage = selectImagesById(state, nextImageId);
const prevImage = selectImagesById(state, prevImageId);
const imagesLength = filteredImages.length;
return {
isOnFirstImage: currentImageIndex === 0,
isOnLastImage:
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
areMoreImagesAvailable: total > imagesLength,
isFetching,
nextImage,
prevImage,
nextImageId,
prevImageId,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const NextPrevImageButtons = () => { const NextPrevImageButtons = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { const {
handlePrevImage,
handleNextImage,
isOnFirstImage, isOnFirstImage,
isOnLastImage, isOnLastImage,
nextImageId, handleLoadMoreImages,
prevImageId,
areMoreImagesAvailable, areMoreImagesAvailable,
isFetching, isFetching,
} = useAppSelector(nextPrevImageButtonsSelector); } = useNextPrevImage();
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
useState<boolean>(false); useState<boolean>(false);
@ -104,50 +38,6 @@ const NextPrevImageButtons = () => {
setShouldShowNextPrevButtons(false); setShouldShowNextPrevButtons(false);
}, []); }, []);
const handlePrevImage = useCallback(() => {
prevImageId && dispatch(imageSelected(prevImageId));
}, [dispatch, prevImageId]);
const handleNextImage = useCallback(() => {
nextImageId && dispatch(imageSelected(nextImageId));
}, [dispatch, nextImageId]);
const handleLoadMoreImages = useCallback(() => {
dispatch(
receivedPageOfImages({
is_intermediate: false,
})
);
}, [dispatch]);
useHotkeys(
'left',
() => {
handlePrevImage();
},
[prevImageId]
);
useHotkeys(
'right',
() => {
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
if (!isOnLastImage) {
handleNextImage();
}
},
[
nextImageId,
isOnLastImage,
areMoreImagesAvailable,
handleLoadMoreImages,
isFetching,
]
);
return ( return (
<Flex <Flex
sx={{ sx={{

View File

@ -0,0 +1,108 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
imageSelected,
selectFilteredImages,
selectImagesById,
} from 'features/gallery/store/gallerySlice';
import { clamp, isEqual } from 'lodash-es';
import { useCallback } from 'react';
import { receivedPageOfImages } from 'services/api/thunks/image';
export const nextPrevImageButtonsSelector = createSelector(
[stateSelector, selectFilteredImages],
(state, filteredImages) => {
const { total, isFetching } = state.gallery;
const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1];
if (!lastSelectedImage || filteredImages.length === 0) {
return {
isOnFirstImage: true,
isOnLastImage: true,
};
}
const currentImageIndex = filteredImages.findIndex(
(i) => i.image_name === lastSelectedImage
);
const nextImageIndex = clamp(
currentImageIndex + 1,
0,
filteredImages.length - 1
);
const prevImageIndex = clamp(
currentImageIndex - 1,
0,
filteredImages.length - 1
);
const nextImageId = filteredImages[nextImageIndex].image_name;
const prevImageId = filteredImages[prevImageIndex].image_name;
const nextImage = selectImagesById(state, nextImageId);
const prevImage = selectImagesById(state, prevImageId);
const imagesLength = filteredImages.length;
return {
isOnFirstImage: currentImageIndex === 0,
isOnLastImage:
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
areMoreImagesAvailable: total > imagesLength,
isFetching,
nextImage,
prevImage,
nextImageId,
prevImageId,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const useNextPrevImage = () => {
const dispatch = useAppDispatch();
const {
isOnFirstImage,
isOnLastImage,
nextImageId,
prevImageId,
areMoreImagesAvailable,
isFetching,
} = useAppSelector(nextPrevImageButtonsSelector);
const handlePrevImage = useCallback(() => {
prevImageId && dispatch(imageSelected(prevImageId));
}, [dispatch, prevImageId]);
const handleNextImage = useCallback(() => {
nextImageId && dispatch(imageSelected(nextImageId));
}, [dispatch, nextImageId]);
const handleLoadMoreImages = useCallback(() => {
dispatch(
receivedPageOfImages({
is_intermediate: false,
})
);
}, [dispatch]);
return {
handlePrevImage,
handleNextImage,
isOnFirstImage,
isOnLastImage,
nextImageId,
prevImageId,
areMoreImagesAvailable,
handleLoadMoreImages,
isFetching,
};
};

View File

@ -37,7 +37,7 @@ const ParamLora = (props: Props) => {
return ( return (
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}> <Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
<IAISlider <IAISlider
label={lora.name} label={lora.model_name}
value={lora.weight} value={lora.weight}
onChange={handleChange} onChange={handleChange}
min={-1} min={-1}

View File

@ -18,7 +18,7 @@ const selector = createSelector(
const ParamLoraList = () => { const ParamLoraList = () => {
const { loras } = useAppSelector(selector); const { loras } = useAppSelector(selector);
return map(loras, (lora) => <ParamLora key={lora.name} lora={lora} />); return map(loras, (lora) => <ParamLora key={lora.model_name} lora={lora} />);
}; };
export default ParamLoraList; export default ParamLoraList;

View File

@ -45,7 +45,7 @@ const ParamLoraSelect = () => {
data.push({ data.push({
value: id, value: id,
label: lora.name, label: lora.model_name,
disabled, disabled,
group: MODEL_TYPE_MAP[lora.base_model], group: MODEL_TYPE_MAP[lora.base_model],
tooltip: disabled tooltip: disabled

View File

@ -1,12 +1,8 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas'; import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import { BaseModelType } from 'services/api/types';
export type Lora = { export type Lora = LoRAModelParam & {
id: string;
base_model: BaseModelType;
name: string;
weight: number; weight: number;
}; };
@ -27,8 +23,8 @@ export const loraSlice = createSlice({
initialState: intialLoraState, initialState: intialLoraState,
reducers: { reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => { loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { name, id, base_model } = action.payload; const { model_name, id, base_model } = action.payload;
state.loras[id] = { id, name, base_model, ...defaultLoRAConfig }; state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
}, },
loraRemoved: (state, action: PayloadAction<string>) => { loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload; const id = action.payload;

View File

@ -7,6 +7,7 @@ import {
OnConnectEnd, OnConnectEnd,
OnConnectStart, OnConnectStart,
OnEdgesChange, OnEdgesChange,
OnInit,
OnNodesChange, OnNodesChange,
ReactFlow, ReactFlow,
} from 'reactflow'; } from 'reactflow';
@ -16,6 +17,7 @@ import {
connectionStarted, connectionStarted,
edgesChanged, edgesChanged,
nodesChanged, nodesChanged,
setEditorInstance,
} from '../store/nodesSlice'; } from '../store/nodesSlice';
import { InvocationComponent } from './InvocationComponent'; import { InvocationComponent } from './InvocationComponent';
import ProgressImageNode from './ProgressImageNode'; import ProgressImageNode from './ProgressImageNode';
@ -67,6 +69,14 @@ export const Flow = () => {
dispatch(connectionEnded()); dispatch(connectionEnded());
}, [dispatch]); }, [dispatch]);
const onInit: OnInit = useCallback(
(v) => {
dispatch(setEditorInstance(v));
if (v) v.fitView();
},
[dispatch]
);
return ( return (
<ReactFlow <ReactFlow
nodeTypes={nodeTypes} nodeTypes={nodeTypes}
@ -77,6 +87,7 @@ export const Flow = () => {
onConnectStart={onConnectStart} onConnectStart={onConnectStart}
onConnect={onConnect} onConnect={onConnect}
onConnectEnd={onConnectEnd} onConnectEnd={onConnectEnd}
onInit={onInit}
defaultEdgeOptions={{ defaultEdgeOptions={{
style: { strokeWidth: 2 }, style: { strokeWidth: 2 },
}} }}

View File

@ -45,7 +45,7 @@ const LoRAModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: BASE_MODEL_NAME_MAP[model.base_model],
}); });
}); });

View File

@ -38,7 +38,7 @@ const ModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: BASE_MODEL_NAME_MAP[model.base_model],
}); });
}); });

View File

@ -45,7 +45,7 @@ const VaeModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: BASE_MODEL_NAME_MAP[model.base_model],
}); });
}); });

View File

@ -1,25 +1,23 @@
import { HStack } from '@chakra-ui/react'; import { HStack } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { memo, useCallback } from 'react';
import { Panel } from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import NodeInvokeButton from '../ui/NodeInvokeButton';
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton'; import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
import { memo } from 'react';
import { Panel } from 'reactflow';
import LoadNodesButton from '../ui/LoadNodesButton';
import NodeInvokeButton from '../ui/NodeInvokeButton';
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
import SaveNodesButton from '../ui/SaveNodesButton';
import ClearNodesButton from '../ui/ClearNodesButton';
const TopCenterPanel = () => { const TopCenterPanel = () => {
const dispatch = useAppDispatch();
const handleReloadSchema = useCallback(() => {
dispatch(receivedOpenAPISchema());
}, [dispatch]);
return ( return (
<Panel position="top-center"> <Panel position="top-center">
<HStack> <HStack>
<NodeInvokeButton /> <NodeInvokeButton />
<CancelButton /> <CancelButton />
<IAIButton onClick={handleReloadSchema}>Reload Schema</IAIButton> <ReloadSchemaButton />
<SaveNodesButton />
<LoadNodesButton />
<ClearNodesButton />
</HStack> </HStack>
</Panel> </Panel>
); );

View File

@ -0,0 +1,86 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Button,
Text,
useDisclosure,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice';
import { memo, useRef, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
const ClearNodesButton = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement | null>(null);
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const handleConfirmClear = useCallback(() => {
dispatch(nodeEditorReset());
dispatch(
addToast(
makeToast({
title: t('toast.nodesCleared'),
status: 'success',
})
)
);
onClose();
}, [dispatch, t, onClose]);
return (
<>
<IAIIconButton
icon={<FaTrash />}
tooltip={t('nodes.clearNodes')}
aria-label={t('nodes.clearNodes')}
onClick={onOpen}
isDisabled={nodes.length === 0}
/>
<AlertDialog
isOpen={isOpen}
onClose={onClose}
leastDestructiveRef={cancelRef}
isCentered
>
<AlertDialogOverlay />
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{t('nodes.clearNodes')}
</AlertDialogHeader>
<AlertDialogBody>
<Text>{t('common.clearNodes')}</Text>
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
{t('common.cancel')}
</Button>
<Button colorScheme="red" ml={3} onClick={handleConfirmClear}>
{t('common.accept')}
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
</>
);
};
export default memo(ClearNodesButton);

View File

@ -0,0 +1,79 @@
import { FileButton } from '@mantine/core';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUpload } from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
const LoadNodesButton = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { fitView } = useReactFlow();
const uploadedFileRef = useRef<() => void>(null);
const restoreJSONToEditor = useCallback(
(v: File | null) => {
if (!v) return;
const reader = new FileReader();
reader.onload = async () => {
const json = reader.result;
const retrievedNodeTree = await JSON.parse(String(json));
if (!retrievedNodeTree) {
dispatch(
addToast(
makeToast({
title: t('toast.nodesLoadedFailed'),
status: 'error',
})
)
);
}
if (retrievedNodeTree) {
dispatch(loadFileNodes(retrievedNodeTree.nodes));
dispatch(loadFileEdges(retrievedNodeTree.edges));
fitView();
dispatch(
addToast(
makeToast({ title: t('toast.nodesLoaded'), status: 'success' })
)
);
}
// Cleanup
reader.abort();
};
reader.readAsText(v);
// Cleanup
uploadedFileRef.current?.();
},
[fitView, dispatch, t]
);
return (
<FileButton
resetRef={uploadedFileRef}
accept="application/json"
onChange={restoreJSONToEditor}
>
{(props) => (
<IAIIconButton
icon={<FaUpload />}
tooltip={t('nodes.loadNodes')}
aria-label={t('nodes.loadNodes')}
{...props}
/>
)}
</FileButton>
);
};
export default memo(LoadNodesButton);

View File

@ -0,0 +1,24 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSyncAlt } from 'react-icons/fa';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
export default function ReloadSchemaButton() {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleReloadSchema = useCallback(() => {
dispatch(receivedOpenAPISchema());
}, [dispatch]);
return (
<IAIIconButton
icon={<FaSyncAlt />}
tooltip={t('nodes.reloadSchema')}
aria-label={t('nodes.reloadSchema')}
onClick={handleReloadSchema}
/>
);
}

View File

@ -0,0 +1,48 @@
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { map, omit } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSave } from 'react-icons/fa';
const SaveNodesButton = () => {
const { t } = useTranslation();
const editorInstance = useAppSelector(
(state: RootState) => state.nodes.editorInstance
);
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const saveEditorToJSON = useCallback(() => {
if (editorInstance) {
const editorState = editorInstance.toObject();
editorState.edges = map(editorState.edges, (edge) => {
return omit(edge, ['style']);
});
const nodeSetupJSON = new Blob([JSON.stringify(editorState)]);
const nodeDownloadElement = document.createElement('a');
nodeDownloadElement.href = URL.createObjectURL(nodeSetupJSON);
nodeDownloadElement.download = 'MyNodes.json';
document.body.appendChild(nodeDownloadElement);
nodeDownloadElement.click();
// Cleanup
nodeDownloadElement.remove();
}
}, [editorInstance]);
return (
<IAIIconButton
icon={<FaSave />}
fontSize={18}
tooltip={t('nodes.saveNodes')}
aria-label={t('nodes.saveNodes')}
onClick={saveEditorToJSON}
isDisabled={nodes.length === 0}
/>
);
};
export default memo(SaveNodesButton);

View File

@ -13,6 +13,7 @@ import {
Node, Node,
NodeChange, NodeChange,
OnConnectStartParams, OnConnectStartParams,
ReactFlowInstance,
} from 'reactflow'; } from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { ImageField } from 'services/api/types'; import { ImageField } from 'services/api/types';
@ -25,6 +26,7 @@ export type NodesState = {
invocationTemplates: Record<string, InvocationTemplate>; invocationTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null; connectionStartParams: OnConnectStartParams | null;
shouldShowGraphOverlay: boolean; shouldShowGraphOverlay: boolean;
editorInstance: ReactFlowInstance | undefined;
}; };
export const initialNodesState: NodesState = { export const initialNodesState: NodesState = {
@ -34,6 +36,7 @@ export const initialNodesState: NodesState = {
invocationTemplates: {}, invocationTemplates: {},
connectionStartParams: null, connectionStartParams: null,
shouldShowGraphOverlay: false, shouldShowGraphOverlay: false,
editorInstance: undefined,
}; };
const nodesSlice = createSlice({ const nodesSlice = createSlice({
@ -118,8 +121,18 @@ const nodesSlice = createSlice({
) => { ) => {
state.invocationTemplates = action.payload; state.invocationTemplates = action.payload;
}, },
nodeEditorReset: () => { nodeEditorReset: (state) => {
return { ...initialNodesState }; state.nodes = [];
state.edges = [];
},
setEditorInstance: (state, action) => {
state.editorInstance = action.payload;
},
loadFileNodes: (state, action: PayloadAction<Node<InvocationValue>[]>) => {
state.nodes = action.payload;
},
loadFileEdges: (state, action: PayloadAction<Edge[]>) => {
state.edges = action.payload;
}, },
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
@ -141,6 +154,9 @@ export const {
nodeTemplatesBuilt, nodeTemplatesBuilt,
nodeEditorReset, nodeEditorReset,
imageCollectionFieldValueChanged, imageCollectionFieldValueChanged,
setEditorInstance,
loadFileNodes,
loadFileEdges,
} = nodesSlice.actions; } = nodesSlice.actions;
export default nodesSlice.reducer; export default nodesSlice.reducer;

View File

@ -1,94 +0,0 @@
import { RootState } from 'app/store/store';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
export const addControlNetToLinearGraph = (
graph: NonNullableGraph,
baseNodeId: string,
state: RootState
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = getValidControlNets(controlNets);
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length > 1) {
// We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
};
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
graph.edges.push({
source: { node_id: controlNetIterateNode.id, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
}
validControlNets.forEach((controlNet) => {
const {
controlNetId,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
model,
processorType,
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
if (validControlNets.length > 1) {
// if we have multiple controlnets, link to the collector
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
} else {
// otherwise, link directly to the base node
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
}
});
}
};

View File

@ -1,40 +0,0 @@
import {
Edge,
ImageToImageInvocation,
InpaintInvocation,
IterateInvocation,
RandomRangeInvocation,
RangeInvocation,
TextToImageInvocation,
} from 'services/api/types';
export const buildEdges = (
baseNode: TextToImageInvocation | ImageToImageInvocation | InpaintInvocation,
rangeNode: RangeInvocation | RandomRangeInvocation,
iterateNode: IterateInvocation
): Edge[] => {
const edges: Edge[] = [
{
source: {
node_id: rangeNode.id,
field: 'collection',
},
destination: {
node_id: iterateNode.id,
field: 'collection',
},
},
{
source: {
node_id: iterateNode.id,
field: 'item',
},
destination: {
node_id: baseNode.id,
field: 'seed',
},
},
];
return edges;
};

View File

@ -0,0 +1,102 @@
import { RootState } from 'app/store/store';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { omit } from 'lodash-es';
import {
CollectInvocation,
ControlField,
ControlNetInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { CONTROL_NET_COLLECT, METADATA_ACCUMULATOR } from './constants';
export const addControlNetToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = getValidControlNets(controlNets);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length) {
// We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
};
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
graph.edges.push({
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
validControlNets.forEach((controlNet) => {
const {
controlNetId,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
model,
processorType,
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
if (metadataAccumulator) {
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const controlField = omit(controlNetNode, [
'id',
'type',
]) as ControlField;
metadataAccumulator.controlnets.push(controlField);
}
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
});
}
}
};

View File

@ -1,8 +1,10 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { unset } from 'lodash-es';
import { import {
DynamicPromptInvocation, DynamicPromptInvocation,
IterateInvocation, IterateInvocation,
MetadataAccumulatorInvocation,
NoiseInvocation, NoiseInvocation,
RandomIntInvocation, RandomIntInvocation,
RangeOfSizeInvocation, RangeOfSizeInvocation,
@ -10,16 +12,16 @@ import {
import { import {
DYNAMIC_PROMPT, DYNAMIC_PROMPT,
ITERATE, ITERATE,
METADATA_ACCUMULATOR,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
} from './constants'; } from './constants';
import { unset } from 'lodash-es';
export const addDynamicPromptsToGraph = ( export const addDynamicPromptsToGraph = (
graph: NonNullableGraph, state: RootState,
state: RootState graph: NonNullableGraph
): void => { ): void => {
const { positivePrompt, iterations, seed, shouldRandomizeSeed } = const { positivePrompt, iterations, seed, shouldRandomizeSeed } =
state.generation; state.generation;
@ -30,6 +32,10 @@ export const addDynamicPromptsToGraph = (
maxPrompts, maxPrompts,
} = state.dynamicPrompts; } = state.dynamicPrompts;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isDynamicPromptsEnabled) { if (isDynamicPromptsEnabled) {
// iteration is handled via dynamic prompts // iteration is handled via dynamic prompts
unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt'); unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt');
@ -74,6 +80,18 @@ export const addDynamicPromptsToGraph = (
} }
); );
// hook up positive prompt to metadata
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: METADATA_ACCUMULATOR,
field: 'positive_prompt',
},
});
if (shouldRandomizeSeed) { if (shouldRandomizeSeed) {
// Random int node to generate the starting seed // Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = { const randomIntNode: RandomIntInvocation = {
@ -88,11 +106,26 @@ export const addDynamicPromptsToGraph = (
source: { node_id: RANDOM_INT, field: 'a' }, source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: NOISE, field: 'seed' }, destination: { node_id: NOISE, field: 'seed' },
}); });
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: METADATA_ACCUMULATOR, field: 'seed' },
});
} else { } else {
// User specified seed, so set the start of the range of size to the seed // User specified seed, so set the start of the range of size to the seed
(graph.nodes[NOISE] as NoiseInvocation).seed = seed; (graph.nodes[NOISE] as NoiseInvocation).seed = seed;
// hook up seed to metadata
if (metadataAccumulator) {
metadataAccumulator.seed = seed;
}
} }
} else { } else {
// no dynamic prompt - hook up positive prompt
if (metadataAccumulator) {
metadataAccumulator.positive_prompt = positivePrompt;
}
const rangeOfSizeNode: RangeOfSizeInvocation = { const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE, id: RANGE_OF_SIZE,
type: 'range_of_size', type: 'range_of_size',
@ -130,6 +163,18 @@ export const addDynamicPromptsToGraph = (
}, },
}); });
// hook up seed to metadata
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: METADATA_ACCUMULATOR,
field: 'seed',
},
});
// handle seed // handle seed
if (shouldRandomizeSeed) { if (shouldRandomizeSeed) {
// Random int node to generate the starting seed // Random int node to generate the starting seed

View File

@ -1,19 +1,23 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es'; import { forEach, size } from 'lodash-es';
import { LoraLoaderInvocation } from 'services/api/types'; import {
LoraLoaderInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import { import {
CLIP_SKIP, CLIP_SKIP,
LORA_LOADER, LORA_LOADER,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
} from './constants'; } from './constants';
export const addLoRAsToGraph = ( export const addLoRAsToGraph = (
graph: NonNullableGraph,
state: RootState, state: RootState,
graph: NonNullableGraph,
baseNodeId: string baseNodeId: string
): void => { ): void => {
/** /**
@ -26,6 +30,9 @@ export const addLoRAsToGraph = (
const { loras } = state.lora; const { loras } = state.lora;
const loraCount = size(loras); const loraCount = size(loras);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (loraCount > 0) { if (loraCount > 0) {
// Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs // Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
@ -62,6 +69,12 @@ export const addLoRAsToGraph = (
weight, weight,
}; };
// add the lora to the metadata accumulator
if (metadataAccumulator) {
metadataAccumulator.loras.push({ lora: loraField, weight });
}
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode; graph.nodes[currentLoraNodeId] = loraLoaderNode;
if (currentLoraIndex === 0) { if (currentLoraIndex === 0) {

View File

@ -1,5 +1,6 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
import { import {
IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_IMAGE_GRAPH,
@ -8,18 +9,22 @@ import {
INPAINT_GRAPH, INPAINT_GRAPH,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
VAE_LOADER, VAE_LOADER,
} from './constants'; } from './constants';
export const addVAEToGraph = ( export const addVAEToGraph = (
graph: NonNullableGraph, state: RootState,
state: RootState graph: NonNullableGraph
): void => { ): void => {
const { vae } = state.generation; const { vae } = state.generation;
const vae_model = modelIdToVAEModelField(vae?.id || ''); const vae_model = modelIdToVAEModelField(vae?.id || '');
const isAutoVae = !vae; const isAutoVae = !vae;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!isAutoVae) { if (!isAutoVae) {
graph.nodes[VAE_LOADER] = { graph.nodes[VAE_LOADER] = {
@ -67,4 +72,8 @@ export const addVAEToGraph = (
}, },
}); });
} }
if (vae && metadataAccumulator) {
metadataAccumulator.vae = vae_model;
}
}; };

View File

@ -7,8 +7,7 @@ import {
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation, ImageToLatentsInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
@ -19,6 +18,7 @@ import {
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
LATENTS_TO_LATENTS, LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -37,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -50,7 +50,10 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToMainModelField(currentModel?.id || ''); if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
@ -275,16 +278,51 @@ export const buildCanvasImageToImageGraph = (
}); });
} }
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); // add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'img2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.image_name,
};
// Add VAE graph.edges.push({
addVAEToGraph(graph, state); source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add dynamic prompts, mutating `graph` // add LoRA support
addDynamicPromptsToGraph(graph, state); addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
// optionally add custom VAE
addVAEToGraph(state, graph);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
return graph; return graph;
}; };

View File

@ -7,7 +7,6 @@ import {
RandomIntInvocation, RandomIntInvocation,
RangeOfSizeInvocation, RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
@ -35,7 +34,7 @@ export const buildCanvasInpaintGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -53,14 +52,17 @@ export const buildCanvasInpaintGraph = (
clipSkip, clipSkip,
} = state.generation; } = state.generation;
if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image // We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToMainModelField(currentModel?.id || '');
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: INPAINT_GRAPH, id: INPAINT_GRAPH,
nodes: { nodes: {
@ -212,10 +214,10 @@ export const buildCanvasInpaintGraph = (
], ],
}; };
addLoRAsToGraph(graph, state, INPAINT); addLoRAsToGraph(state, graph, INPAINT);
// Add VAE // Add VAE
addVAEToGraph(graph, state); addVAEToGraph(state, graph);
// handle seed // handle seed
if (shouldRandomizeSeed) { if (shouldRandomizeSeed) {

View File

@ -1,8 +1,8 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
@ -10,6 +10,7 @@ import {
CLIP_SKIP, CLIP_SKIP,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -17,6 +18,8 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
const moduleLog = log.child({ namespace: 'nodes' });
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
*/ */
@ -26,7 +29,7 @@ export const buildCanvasTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -38,7 +41,10 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToMainModelField(currentModel?.id || ''); if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
@ -180,16 +186,49 @@ export const buildCanvasTextToImageGraph = (
], ],
}; };
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); // add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'txt2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
};
// Add VAE graph.edges.push({
addVAEToGraph(graph, state); source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add dynamic prompts, mutating `graph` // add LoRA support
addDynamicPromptsToGraph(graph, state); addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
// optionally add custom VAE
addVAEToGraph(state, graph);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
return graph; return graph;
}; };

View File

@ -3,25 +3,21 @@ import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { import {
ImageCollectionInvocation,
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation, ImageToLatentsInvocation,
IterateInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
CLIP_SKIP, CLIP_SKIP,
IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE,
IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS, IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
LATENTS_TO_LATENTS, LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -39,7 +35,7 @@ export const buildLinearImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -53,14 +49,15 @@ export const buildLinearImageToImageGraph = (
shouldUseNoiseSettings, shouldUseNoiseSettings,
} = state.generation; } = state.generation;
const { // TODO: add batch functionality
isEnabled: isBatchEnabled, // const {
imageNames: batchImageNames, // isEnabled: isBatchEnabled,
asInitialImage, // imageNames: batchImageNames,
} = state.batch; // asInitialImage,
// } = state.batch;
const shouldBatch = // const shouldBatch =
isBatchEnabled && batchImageNames.length > 0 && asInitialImage; // isBatchEnabled && batchImageNames.length > 0 && asInitialImage;
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -71,12 +68,15 @@ export const buildLinearImageToImageGraph = (
* the `fit` param. These are added to the graph at the end. * the `fit` param. These are added to the graph at the end.
*/ */
if (!initialImage && !shouldBatch) { if (!initialImage) {
moduleLog.error('No initial image found in state'); moduleLog.error('No initial image found in state');
throw new Error('No initial image found in state'); throw new Error('No initial image found in state');
} }
const model = modelIdToMainModelField(currentModel?.id || ''); if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
@ -295,51 +295,87 @@ export const buildLinearImageToImageGraph = (
}); });
} }
if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) { // TODO: add batch functionality
// we are going to connect an iterate up to the init image // if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) {
delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image; // // we are going to connect an iterate up to the init image
// delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image;
const imageCollection: ImageCollectionInvocation = { // const imageCollection: ImageCollectionInvocation = {
id: IMAGE_COLLECTION, // id: IMAGE_COLLECTION,
type: 'image_collection', // type: 'image_collection',
images: batchImageNames.map((image_name) => ({ image_name })), // images: batchImageNames.map((image_name) => ({ image_name })),
// };
// const imageCollectionIterate: IterateInvocation = {
// id: IMAGE_COLLECTION_ITERATE,
// type: 'iterate',
// };
// graph.nodes[IMAGE_COLLECTION] = imageCollection;
// graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate;
// graph.edges.push({
// source: { node_id: IMAGE_COLLECTION, field: 'collection' },
// destination: {
// node_id: IMAGE_COLLECTION_ITERATE,
// field: 'collection',
// },
// });
// graph.edges.push({
// source: { node_id: IMAGE_COLLECTION_ITERATE, field: 'item' },
// destination: {
// node_id: IMAGE_TO_LATENTS,
// field: 'image',
// },
// });
// }
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'img2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.imageName,
}; };
const imageCollectionIterate: IterateInvocation = {
id: IMAGE_COLLECTION_ITERATE,
type: 'iterate',
};
graph.nodes[IMAGE_COLLECTION] = imageCollection;
graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate;
graph.edges.push({ graph.edges.push({
source: { node_id: IMAGE_COLLECTION, field: 'collection' }, source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: { destination: {
node_id: IMAGE_COLLECTION_ITERATE, node_id: LATENTS_TO_IMAGE,
field: 'collection', field: 'metadata',
}, },
}); });
graph.edges.push({ // add LoRA support
source: { node_id: IMAGE_COLLECTION_ITERATE, field: 'item' }, addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
}
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); // optionally add custom VAE
addVAEToGraph(state, graph);
// Add VAE // add dynamic prompts - also sets up core iteration and seed
addVAEToGraph(graph, state); addDynamicPromptsToGraph(state, graph);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
return graph; return graph;
}; };

View File

@ -1,8 +1,8 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
@ -10,6 +10,7 @@ import {
CLIP_SKIP, CLIP_SKIP,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -17,13 +18,15 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
const moduleLog = log.child({ namespace: 'nodes' });
export const buildLinearTextToImageGraph = ( export const buildLinearTextToImageGraph = (
state: RootState state: RootState
): NonNullableGraph => { ): NonNullableGraph => {
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -34,12 +37,15 @@ export const buildLinearTextToImageGraph = (
shouldUseNoiseSettings, shouldUseNoiseSettings,
} = state.generation; } = state.generation;
const model = modelIdToMainModelField(currentModel?.id || '');
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise; : initialGenerationState.shouldUseCpuNoise;
if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -176,16 +182,49 @@ export const buildLinearTextToImageGraph = (
], ],
}; };
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); // add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'txt2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
};
// Add Custom VAE Support graph.edges.push({
addVAEToGraph(graph, state); source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add dynamic prompts, mutating `graph` // add LoRA support
addDynamicPromptsToGraph(graph, state); addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
// optionally add custom VAE
addVAEToGraph(state, graph);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
return graph; return graph;
}; };

View File

@ -19,6 +19,7 @@ export const CONTROL_NET_COLLECT = 'control_net_collect';
export const DYNAMIC_PROMPT = 'dynamic_prompt'; export const DYNAMIC_PROMPT = 'dynamic_prompt';
export const IMAGE_COLLECTION = 'image_collection'; export const IMAGE_COLLECTION = 'image_collection';
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate'; export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
export const METADATA_ACCUMULATOR = 'metadata_accumulator';
// friendly graph ids // friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -5,17 +5,21 @@ import {
InputFieldTemplate, InputFieldTemplate,
InvocationSchemaObject, InvocationSchemaObject,
InvocationTemplate, InvocationTemplate,
isInvocationSchemaObject,
OutputFieldTemplate, OutputFieldTemplate,
isInvocationSchemaObject,
} from '../types/types'; } from '../types/types';
import { import {
buildInputFieldTemplate, buildInputFieldTemplate,
buildOutputFieldTemplates, buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; } from './fieldTemplateBuilders';
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate']; const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'metadata'];
const invocationDenylist = ['Graph', 'InvocationMeta']; const invocationDenylist = [
'Graph',
'InvocationMeta',
'MetadataAccumulatorInvocation',
];
export const parseSchema = (openAPI: OpenAPIV3.Document) => { export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now // filter out non-invocation schemas, plus some tricky invocations for now

View File

@ -2,22 +2,26 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { import {
canvasSelector, canvasSelector,
isStagingSelector, isStagingSelector,
} from 'features/canvas/store/canvasSelectors'; } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice'; import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector(
[canvasSelector, isStagingSelector], [canvasSelector, isStagingSelector, uiSelector],
(canvas, isStaging) => { (canvas, isStaging, ui) => {
const { boundingBoxDimensions } = canvas; const { boundingBoxDimensions } = canvas;
const { aspectRatio } = ui;
return { return {
boundingBoxDimensions, boundingBoxDimensions,
isStaging, isStaging,
aspectRatio,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -25,7 +29,8 @@ const selector = createSelector(
const ParamBoundingBoxWidth = () => { const ParamBoundingBoxWidth = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { boundingBoxDimensions, isStaging } = useAppSelector(selector); const { boundingBoxDimensions, isStaging, aspectRatio } =
useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
@ -36,6 +41,15 @@ const ParamBoundingBoxWidth = () => {
height: Math.floor(v), height: Math.floor(v),
}) })
); );
if (aspectRatio) {
const newWidth = roundToMultiple(v * aspectRatio, 64);
dispatch(
setBoundingBoxDimensions({
width: newWidth,
height: Math.floor(v),
})
);
}
}; };
const handleResetHeight = () => { const handleResetHeight = () => {
@ -45,6 +59,15 @@ const ParamBoundingBoxWidth = () => {
height: Math.floor(512), height: Math.floor(512),
}) })
); );
if (aspectRatio) {
const newWidth = roundToMultiple(512 * aspectRatio, 64);
dispatch(
setBoundingBoxDimensions({
width: newWidth,
height: Math.floor(512),
})
);
}
}; };
return ( return (

View File

@ -0,0 +1,57 @@
import { Flex, Spacer, Text } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { flipBoundingBoxAxes } from 'features/canvas/store/canvasSlice';
import { useTranslation } from 'react-i18next';
import { MdOutlineSwapVert } from 'react-icons/md';
import ParamAspectRatio from '../../Core/ParamAspectRatio';
import ParamBoundingBoxHeight from './ParamBoundingBoxHeight';
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth';
export default function ParamBoundingBoxSize() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
return (
<Flex
sx={{
gap: 2,
p: 4,
borderRadius: 4,
flexDirection: 'column',
w: 'full',
bg: 'base.150',
_dark: {
bg: 'base.750',
},
}}
>
<Flex alignItems="center" gap={2}>
<Text
sx={{
fontSize: 'sm',
width: 'full',
color: 'base.700',
_dark: {
color: 'base.300',
},
}}
>
{t('parameters.aspectRatio')}
</Text>
<Spacer />
<ParamAspectRatio />
<IAIIconButton
tooltip={t('ui.swapSizes')}
aria-label={t('ui.swapSizes')}
size="sm"
icon={<MdOutlineSwapVert />}
fontSize={20}
onClick={() => dispatch(flipBoundingBoxAxes())}
/>
</Flex>
<ParamBoundingBoxWidth />
<ParamBoundingBoxHeight />
</Flex>
);
}

View File

@ -2,22 +2,26 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { import {
canvasSelector, canvasSelector,
isStagingSelector, isStagingSelector,
} from 'features/canvas/store/canvasSelectors'; } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice'; import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector(
[canvasSelector, isStagingSelector], [canvasSelector, isStagingSelector, uiSelector],
(canvas, isStaging) => { (canvas, isStaging, ui) => {
const { boundingBoxDimensions } = canvas; const { boundingBoxDimensions } = canvas;
const { aspectRatio } = ui;
return { return {
boundingBoxDimensions, boundingBoxDimensions,
isStaging, isStaging,
aspectRatio,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -25,7 +29,8 @@ const selector = createSelector(
const ParamBoundingBoxWidth = () => { const ParamBoundingBoxWidth = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { boundingBoxDimensions, isStaging } = useAppSelector(selector); const { boundingBoxDimensions, isStaging, aspectRatio } =
useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
@ -36,6 +41,15 @@ const ParamBoundingBoxWidth = () => {
width: Math.floor(v), width: Math.floor(v),
}) })
); );
if (aspectRatio) {
const newHeight = roundToMultiple(v / aspectRatio, 64);
dispatch(
setBoundingBoxDimensions({
width: Math.floor(v),
height: newHeight,
})
);
}
}; };
const handleResetWidth = () => { const handleResetWidth = () => {
@ -45,6 +59,15 @@ const ParamBoundingBoxWidth = () => {
width: Math.floor(512), width: Math.floor(512),
}) })
); );
if (aspectRatio) {
const newHeight = roundToMultiple(512 / aspectRatio, 64);
dispatch(
setBoundingBoxDimensions({
width: Math.floor(512),
height: newHeight,
})
);
}
}; };
return ( return (

View File

@ -27,6 +27,9 @@ const ParamNoiseCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled; const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled;
const isPerlinNoiseEnabled = useFeatureStatus('perlinNoise').isFeatureEnabled;
const isNoiseThresholdEnabled =
useFeatureStatus('noiseThreshold').isFeatureEnabled;
const { activeLabel } = useAppSelector(selector); const { activeLabel } = useAppSelector(selector);
@ -42,8 +45,8 @@ const ParamNoiseCollapse = () => {
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamNoiseToggle /> <ParamNoiseToggle />
<ParamCpuNoiseToggle /> <ParamCpuNoiseToggle />
<ParamPerlinNoise /> {isPerlinNoiseEnabled && <ParamPerlinNoise />}
<ParamNoiseThreshold /> {isNoiseThresholdEnabled && <ParamNoiseThreshold />}
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );

View File

@ -2,6 +2,7 @@ import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { UnsafeImageMetadata } from 'services/api/endpoints/images';
import { isImageField } from 'services/api/guards'; import { isImageField } from 'services/api/guards';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { initialImageSelected, modelSelected } from '../store/actions'; import { initialImageSelected, modelSelected } from '../store/actions';
@ -162,7 +163,7 @@ export const useRecallParameters = () => {
parameterNotSetToast(); parameterNotSetToast();
return; return;
} }
dispatch(modelSelected(model?.id || '')); dispatch(modelSelected(model));
parameterSetToast(); parameterSetToast();
}, },
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
@ -269,28 +270,24 @@ export const useRecallParameters = () => {
); );
const recallAllParameters = useCallback( const recallAllParameters = useCallback(
(image: ImageDTO | undefined) => { (metadata: UnsafeImageMetadata['metadata'] | undefined) => {
if (!image || !image.metadata) { if (!metadata) {
allParameterNotSetToast(); allParameterNotSetToast();
return; return;
} }
const { const {
cfg_scale, cfg_scale,
height, height,
model, model,
positive_conditioning, positive_prompt,
negative_conditioning, negative_prompt,
scheduler, scheduler,
seed, seed,
steps, steps,
width, width,
strength, strength,
clip, } = metadata;
extra,
latents,
unet,
vae,
} = image.metadata;
if (isValidCfgScale(cfg_scale)) { if (isValidCfgScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale)); dispatch(setCfgScale(cfg_scale));
@ -298,11 +295,11 @@ export const useRecallParameters = () => {
if (isValidMainModel(model)) { if (isValidMainModel(model)) {
dispatch(modelSelected(model)); dispatch(modelSelected(model));
} }
if (isValidPositivePrompt(positive_conditioning)) { if (isValidPositivePrompt(positive_prompt)) {
dispatch(setPositivePrompt(positive_conditioning)); dispatch(setPositivePrompt(positive_prompt));
} }
if (isValidNegativePrompt(negative_conditioning)) { if (isValidNegativePrompt(negative_prompt)) {
dispatch(setNegativePrompt(negative_conditioning)); dispatch(setNegativePrompt(negative_prompt));
} }
if (isValidScheduler(scheduler)) { if (isValidScheduler(scheduler)) {
dispatch(setScheduler(scheduler)); dispatch(setScheduler(scheduler));

View File

@ -1,8 +1,10 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api/types'; import { ImageDTO, MainModelField } from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | string | undefined>( export const initialImageSelected = createAction<ImageDTO | string | undefined>(
'generation/initialImageSelected' 'generation/initialImageSelected'
); );
export const modelSelected = createAction<string>('generation/modelSelected'); export const modelSelected = createAction<MainModelField>(
'generation/modelSelected'
);

View File

@ -8,12 +8,11 @@ import {
setShouldShowAdvancedOptions, setShouldShowAdvancedOptions,
} from 'features/ui/store/uiSlice'; } from 'features/ui/store/uiSlice';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { ImageDTO } from 'services/api/types'; import { ImageDTO, MainModelField } from 'services/api/types';
import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip'; import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
MainModelParam,
NegativePromptParam, NegativePromptParam,
PositivePromptParam, PositivePromptParam,
SchedulerParam, SchedulerParam,
@ -54,7 +53,7 @@ export interface GenerationState {
shouldUseSymmetry: boolean; shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number; horizontalSymmetrySteps: number;
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: MainModelParam | null; model: MainModelField | null;
vae: VaeModelParam | null; vae: VaeModelParam | null;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
@ -227,23 +226,17 @@ export const generationSlice = createSlice({
const { image_name, width, height } = action.payload; const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height }; state.initialImage = { imageName: image_name, width, height };
}, },
modelSelected: (state, action: PayloadAction<string>) => { modelChanged: (state, action: PayloadAction<MainModelField | null>) => {
const [base_model, type, name] = action.payload.split('/'); if (!action.payload) {
state.model = null;
}
state.model = zMainModel.parse({ state.model = zMainModel.parse(action.payload);
id: action.payload,
base_model,
name,
type,
});
// Clamp ClipSkip Based On Selected Model // Clamp ClipSkip Based On Selected Model
const { maxClip } = clipSkipMap[state.model.base_model]; const { maxClip } = clipSkipMap[state.model.base_model];
state.clipSkip = clamp(state.clipSkip, 0, maxClip); state.clipSkip = clamp(state.clipSkip, 0, maxClip);
}, },
modelChanged: (state, action: PayloadAction<MainModelParam>) => {
state.model = action.payload;
},
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => { vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
state.vae = action.payload; state.vae = action.payload;
}, },

View File

@ -135,8 +135,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
* TODO: Make this a dynamically generated enum? * TODO: Make this a dynamically generated enum?
*/ */
export const zMainModel = z.object({ export const zMainModel = z.object({
id: z.string(), model_name: z.string(),
name: z.string(),
base_model: zBaseModel, base_model: zBaseModel,
}); });
@ -171,7 +170,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
*/ */
export const zLoRAModel = z.object({ export const zLoRAModel = z.object({
id: z.string(), id: z.string(),
name: z.string(), model_name: z.string(),
base_model: zBaseModel, base_model: zBaseModel,
}); });
/** /**

Some files were not shown because too many files have changed in this diff Show More