mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into fix/controlnet_cfg_inj_cond
This commit is contained in:
commit
9348dc8e0d
@ -1,25 +1,9 @@
|
||||
# use this file as a whitelist
|
||||
*
|
||||
!invokeai
|
||||
!ldm
|
||||
!pyproject.toml
|
||||
!docker/docker-entrypoint.sh
|
||||
!LICENSE
|
||||
|
||||
# ignore frontend/web but whitelist dist
|
||||
invokeai/frontend/web/
|
||||
!invokeai/frontend/web/dist/
|
||||
|
||||
# ignore invokeai/assets but whitelist invokeai/assets/web
|
||||
invokeai/assets/
|
||||
!invokeai/assets/web/
|
||||
|
||||
# Guard against pulling in any models that might exist in the directory tree
|
||||
**/*.pt*
|
||||
**/*.ckpt
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
**/__pycache__/
|
||||
**/*.py[cod]
|
||||
|
||||
# Distribution / packaging
|
||||
**/*.egg-info/
|
||||
**/*.egg
|
||||
**/node_modules
|
||||
**/__pycache__
|
||||
**/*.egg-info
|
83
.github/workflows/build-container.yml
vendored
83
.github/workflows/build-container.yml
vendored
@ -3,17 +3,15 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'update/ci/docker/*'
|
||||
- 'update/docker/*'
|
||||
- 'dev/ci/docker/*'
|
||||
- 'dev/docker/*'
|
||||
paths:
|
||||
- 'pyproject.toml'
|
||||
- '.dockerignore'
|
||||
- 'invokeai/**'
|
||||
- 'docker/Dockerfile'
|
||||
- 'docker/docker-entrypoint.sh'
|
||||
- 'workflows/build-container.yml'
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
@ -26,23 +24,27 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
flavor:
|
||||
- rocm
|
||||
- cuda
|
||||
- cpu
|
||||
include:
|
||||
- flavor: rocm
|
||||
pip-extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
- flavor: cuda
|
||||
pip-extra-index-url: ''
|
||||
- flavor: cpu
|
||||
pip-extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
gpu-driver:
|
||||
- cuda
|
||||
- cpu
|
||||
- rocm
|
||||
runs-on: ubuntu-latest
|
||||
name: ${{ matrix.flavor }}
|
||||
name: ${{ matrix.gpu-driver }}
|
||||
env:
|
||||
PLATFORMS: 'linux/amd64,linux/arm64'
|
||||
DOCKERFILE: 'docker/Dockerfile'
|
||||
# torch/arm64 does not support GPU currently, so arm64 builds
|
||||
# would not be GPU-accelerated.
|
||||
# re-enable arm64 if there is sufficient demand.
|
||||
# PLATFORMS: 'linux/amd64,linux/arm64'
|
||||
PLATFORMS: 'linux/amd64'
|
||||
steps:
|
||||
- name: Free up more disk space on the runner
|
||||
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
sudo swapoff /mnt/swapfile
|
||||
sudo rm -rf /mnt/swapfile
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
@ -53,7 +55,7 @@ jobs:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
images: |
|
||||
ghcr.io/${{ github.repository }}
|
||||
${{ vars.DOCKERHUB_REPOSITORY }}
|
||||
${{ env.DOCKERHUB_REPOSITORY }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
@ -62,8 +64,8 @@ jobs:
|
||||
type=pep440,pattern={{major}}
|
||||
type=sha,enable=true,prefix=sha-,format=short
|
||||
flavor: |
|
||||
latest=${{ matrix.flavor == 'cuda' && github.ref == 'refs/heads/main' }}
|
||||
suffix=-${{ matrix.flavor }},onlatest=false
|
||||
latest=${{ matrix.gpu-driver == 'cuda' && github.ref == 'refs/heads/main' }}
|
||||
suffix=-${{ matrix.gpu-driver }},onlatest=false
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
@ -81,34 +83,33 @@ jobs:
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: github.event_name != 'pull_request' && vars.DOCKERHUB_REPOSITORY != ''
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
# - name: Login to Docker Hub
|
||||
# if: github.event_name != 'pull_request' && vars.DOCKERHUB_REPOSITORY != ''
|
||||
# uses: docker/login-action@v2
|
||||
# with:
|
||||
# username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
# password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build container
|
||||
id: docker_build
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: ${{ env.DOCKERFILE }}
|
||||
file: docker/Dockerfile
|
||||
platforms: ${{ env.PLATFORMS }}
|
||||
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: PIP_EXTRA_INDEX_URL=${{ matrix.pip-extra-index-url }}
|
||||
cache-from: |
|
||||
type=gha,scope=${{ github.ref_name }}-${{ matrix.flavor }}
|
||||
type=gha,scope=main-${{ matrix.flavor }}
|
||||
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.flavor }}
|
||||
type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
type=gha,scope=main-${{ matrix.gpu-driver }}
|
||||
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
|
||||
- name: Docker Hub Description
|
||||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' && vars.DOCKERHUB_REPOSITORY != ''
|
||||
uses: peter-evans/dockerhub-description@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
repository: ${{ vars.DOCKERHUB_REPOSITORY }}
|
||||
short-description: ${{ github.event.repository.description }}
|
||||
# - name: Docker Hub Description
|
||||
# if: github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' && vars.DOCKERHUB_REPOSITORY != ''
|
||||
# uses: peter-evans/dockerhub-description@v3
|
||||
# with:
|
||||
# username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
# password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
# repository: ${{ vars.DOCKERHUB_REPOSITORY }}
|
||||
# short-description: ${{ github.event.repository.description }}
|
||||
|
13
docker/.env.sample
Normal file
13
docker/.env.sample
Normal file
@ -0,0 +1,13 @@
|
||||
## Make a copy of this file named `.env` and fill in the values below.
|
||||
## Any environment variables supported by InvokeAI can be specified here.
|
||||
|
||||
# INVOKEAI_ROOT is the path to a path on the local filesystem where InvokeAI will store data.
|
||||
# Outputs will also be stored here by default.
|
||||
# This **must** be an absolute path.
|
||||
INVOKEAI_ROOT=
|
||||
|
||||
HUGGINGFACE_TOKEN=
|
||||
|
||||
## optional variables specific to the docker setup
|
||||
# GPU_DRIVER=cuda
|
||||
# CONTAINER_UID=1000
|
@ -1,107 +1,129 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
# syntax=docker/dockerfile:1.4
|
||||
|
||||
ARG PYTHON_VERSION=3.9
|
||||
##################
|
||||
## base image ##
|
||||
##################
|
||||
FROM --platform=${TARGETPLATFORM} python:${PYTHON_VERSION}-slim AS python-base
|
||||
## Builder stage
|
||||
|
||||
LABEL org.opencontainers.image.authors="mauwii@outlook.de"
|
||||
FROM library/ubuntu:22.04 AS builder
|
||||
|
||||
# Prepare apt for buildkit cache
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean \
|
||||
&& echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt update && apt-get install -y \
|
||||
git \
|
||||
python3.10-venv \
|
||||
python3-pip \
|
||||
build-essential
|
||||
|
||||
# Install dependencies
|
||||
RUN \
|
||||
--mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt-get update \
|
||||
&& apt-get install -y \
|
||||
--no-install-recommends \
|
||||
libgl1-mesa-glx=20.3.* \
|
||||
libglib2.0-0=2.66.* \
|
||||
libopencv-dev=4.5.*
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv/invokeai
|
||||
|
||||
# Set working directory and env
|
||||
ARG APPDIR=/usr/src
|
||||
ARG APPNAME=InvokeAI
|
||||
WORKDIR ${APPDIR}
|
||||
ENV PATH ${APPDIR}/${APPNAME}/bin:$PATH
|
||||
# Keeps Python from generating .pyc files in the container
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
# Turns off buffering for easier container logging
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
# Don't fall back to legacy build system
|
||||
ENV PIP_USE_PEP517=1
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
ARG TORCH_VERSION=2.0.1
|
||||
ARG TORCHVISION_VERSION=0.15.2
|
||||
ARG GPU_DRIVER=cuda
|
||||
ARG TARGETPLATFORM="linux/amd64"
|
||||
# unused but available
|
||||
ARG BUILDPLATFORM
|
||||
|
||||
#######################
|
||||
## build pyproject ##
|
||||
#######################
|
||||
FROM python-base AS pyproject-builder
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# Install build dependencies
|
||||
RUN \
|
||||
--mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt-get update \
|
||||
&& apt-get install -y \
|
||||
--no-install-recommends \
|
||||
build-essential=12.9 \
|
||||
gcc=4:10.2.* \
|
||||
python3-dev=3.9.*
|
||||
# Install pytorch before all other pip packages
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is default
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m venv ${VIRTUAL_ENV} &&\
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; \
|
||||
else \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu118"; \
|
||||
fi &&\
|
||||
pip install $extra_index_url_arg \
|
||||
torch==$TORCH_VERSION \
|
||||
torchvision==$TORCHVISION_VERSION
|
||||
|
||||
# Prepare pip for buildkit cache
|
||||
ARG PIP_CACHE_DIR=/var/cache/buildkit/pip
|
||||
ENV PIP_CACHE_DIR ${PIP_CACHE_DIR}
|
||||
RUN mkdir -p ${PIP_CACHE_DIR}
|
||||
# Install the local package.
|
||||
# Editable mode helps use the same image for development:
|
||||
# the local working copy can be bind-mounted into the image
|
||||
# at path defined by ${INVOKEAI_SRC}
|
||||
COPY invokeai ./invokeai
|
||||
COPY pyproject.toml ./
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
# xformers + triton fails to install on arm64
|
||||
if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
|
||||
pip install -e ".[xformers]"; \
|
||||
else \
|
||||
pip install -e "."; \
|
||||
fi
|
||||
|
||||
# Create virtual environment
|
||||
RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
|
||||
python3 -m venv "${APPNAME}" \
|
||||
--upgrade-deps
|
||||
# #### Build the Web UI ------------------------------------
|
||||
|
||||
# Install requirements
|
||||
COPY --link pyproject.toml .
|
||||
COPY --link invokeai/version/invokeai_version.py invokeai/version/__init__.py invokeai/version/
|
||||
ARG PIP_EXTRA_INDEX_URL
|
||||
ENV PIP_EXTRA_INDEX_URL ${PIP_EXTRA_INDEX_URL}
|
||||
RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
|
||||
"${APPNAME}"/bin/pip install .
|
||||
FROM node:18 AS web-builder
|
||||
WORKDIR /build
|
||||
COPY invokeai/frontend/web/ ./
|
||||
RUN --mount=type=cache,target=/usr/lib/node_modules \
|
||||
npm install --include dev
|
||||
RUN --mount=type=cache,target=/usr/lib/node_modules \
|
||||
yarn vite build
|
||||
|
||||
# Install pyproject.toml
|
||||
COPY --link . .
|
||||
RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
|
||||
"${APPNAME}/bin/pip" install .
|
||||
|
||||
# Build patchmatch
|
||||
#### Runtime stage ---------------------------------------
|
||||
|
||||
FROM library/ubuntu:22.04 AS runtime
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
RUN apt update && apt install -y --no-install-recommends \
|
||||
git \
|
||||
curl \
|
||||
vim \
|
||||
tmux \
|
||||
ncdu \
|
||||
iotop \
|
||||
bzip2 \
|
||||
gosu \
|
||||
libglib2.0-0 \
|
||||
libgl1-mesa-glx \
|
||||
python3-venv \
|
||||
python3-pip \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev &&\
|
||||
apt-get clean && apt-get autoclean
|
||||
|
||||
# globally add magic-wormhole
|
||||
# for ease of transferring data to and from the container
|
||||
# when running in sandboxed cloud environments; e.g. Runpod etc.
|
||||
RUN pip install magic-wormhole
|
||||
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv/invokeai
|
||||
ENV INVOKEAI_ROOT=/invokeai
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4
|
||||
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
|
||||
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
|
||||
|
||||
# Link amdgpu.ids for ROCm builds
|
||||
# contributed by https://github.com/Rubonnek
|
||||
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# build patchmatch
|
||||
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
|
||||
RUN python3 -c "from patchmatch import patch_match"
|
||||
|
||||
#####################
|
||||
## runtime image ##
|
||||
#####################
|
||||
FROM python-base AS runtime
|
||||
# Create unprivileged user and make the local dir
|
||||
RUN useradd --create-home --shell /bin/bash -u 1000 --comment "container local user" invoke
|
||||
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R invoke:invoke ${INVOKEAI_ROOT}
|
||||
|
||||
# Create a new user
|
||||
ARG UNAME=appuser
|
||||
RUN useradd \
|
||||
--no-log-init \
|
||||
-m \
|
||||
-U \
|
||||
"${UNAME}"
|
||||
|
||||
# Create volume directory
|
||||
ARG VOLUME_DIR=/data
|
||||
RUN mkdir -p "${VOLUME_DIR}" \
|
||||
&& chown -hR "${UNAME}:${UNAME}" "${VOLUME_DIR}"
|
||||
|
||||
# Setup runtime environment
|
||||
USER ${UNAME}:${UNAME}
|
||||
COPY --chown=${UNAME}:${UNAME} --from=pyproject-builder ${APPDIR}/${APPNAME} ${APPNAME}
|
||||
ENV INVOKEAI_ROOT ${VOLUME_DIR}
|
||||
ENV TRANSFORMERS_CACHE ${VOLUME_DIR}/.cache
|
||||
ENV INVOKE_MODEL_RECONFIGURE "--yes --default_only"
|
||||
EXPOSE 9090
|
||||
ENTRYPOINT [ "invokeai" ]
|
||||
CMD [ "--web", "--host", "0.0.0.0", "--port", "9090" ]
|
||||
VOLUME [ "${VOLUME_DIR}" ]
|
||||
COPY docker/docker-entrypoint.sh ./
|
||||
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
|
||||
CMD ["invokeai-web", "--host", "0.0.0.0"]
|
||||
|
77
docker/README.md
Normal file
77
docker/README.md
Normal file
@ -0,0 +1,77 @@
|
||||
# InvokeAI Containerized
|
||||
|
||||
All commands are to be run from the `docker` directory: `cd docker`
|
||||
|
||||
#### Linux
|
||||
|
||||
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
|
||||
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://www.digitalocean.com/community/tutorials/how-to-install-and-use-docker-compose-on-ubuntu-22-04).
|
||||
- The deprecated `docker-compose` (hyphenated) CLI continues to work for now.
|
||||
3. Ensure docker daemon is able to access the GPU.
|
||||
- You may need to install [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
|
||||
|
||||
#### macOS
|
||||
|
||||
1. Ensure Docker has at least 16GB RAM
|
||||
2. Enable VirtioFS for file sharing
|
||||
3. Enable `docker compose` V2 support
|
||||
|
||||
This is done via Docker Desktop preferences
|
||||
|
||||
## Quickstart
|
||||
|
||||
|
||||
1. Make a copy of `env.sample` and name it `.env` (`cp env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
|
||||
a. the desired location of the InvokeAI runtime directory, or
|
||||
b. an existing, v3.0.0 compatible runtime directory.
|
||||
1. `docker compose up`
|
||||
|
||||
The image will be built automatically if needed.
|
||||
|
||||
The runtime directory (holding models and outputs) will be created in the location specified by `INVOKEAI_ROOT`. The default location is `~/invokeai`. The runtime directory will be populated with the base configs and models necessary to start generating.
|
||||
|
||||
### Use a GPU
|
||||
|
||||
- Linux is *recommended* for GPU support in Docker.
|
||||
- WSL2 is *required* for Windows.
|
||||
- only `x86_64` architecture is supported.
|
||||
|
||||
The Docker daemon on the system must be already set up to use the GPU. In case of Linux, this involves installing `nvidia-docker-runtime` and configuring the `nvidia` runtime as default. Steps will be different for AMD. Please see Docker documentation for the most up-to-date instructions for using your GPU with Docker.
|
||||
|
||||
## Customize
|
||||
|
||||
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `docker compose up`, your custom values will be used.
|
||||
|
||||
You can also set these values in `docker compose.yml` directly, but `.env` will help avoid conflicts when code is updated.
|
||||
|
||||
Example (most values are optional):
|
||||
|
||||
```
|
||||
INVOKEAI_ROOT=/Volumes/WorkDrive/invokeai
|
||||
HUGGINGFACE_TOKEN=the_actual_token
|
||||
CONTAINER_UID=1000
|
||||
GPU_DRIVER=cuda
|
||||
```
|
||||
|
||||
## Even Moar Customizing!
|
||||
|
||||
See the `docker compose.yaml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
|
||||
|
||||
### Reconfigure the runtime directory
|
||||
|
||||
Can be used to download additional models from the supported model list
|
||||
|
||||
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
|
||||
|
||||
```
|
||||
command:
|
||||
- invokeai-configure
|
||||
- --yes
|
||||
```
|
||||
|
||||
Or install models:
|
||||
|
||||
```
|
||||
command:
|
||||
- invokeai-model-install
|
||||
```
|
@ -1,51 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
# If you want to build a specific flavor, set the CONTAINER_FLAVOR environment variable
|
||||
# e.g. CONTAINER_FLAVOR=cpu ./build.sh
|
||||
# Possible Values are:
|
||||
# - cpu
|
||||
# - cuda
|
||||
# - rocm
|
||||
# Don't forget to also set it when executing run.sh
|
||||
# if it is not set, the script will try to detect the flavor by itself.
|
||||
#
|
||||
# Doc can be found here:
|
||||
# https://invoke-ai.github.io/InvokeAI/installation/040_INSTALL_DOCKER/
|
||||
build_args=""
|
||||
|
||||
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
|
||||
cd "$SCRIPTDIR" || exit 1
|
||||
[[ -f ".env" ]] && build_args=$(awk '$1 ~ /\=[^$]/ {print "--build-arg " $0 " "}' .env)
|
||||
|
||||
source ./env.sh
|
||||
echo "docker-compose build args:"
|
||||
echo $build_args
|
||||
|
||||
DOCKERFILE=${INVOKE_DOCKERFILE:-./Dockerfile}
|
||||
|
||||
# print the settings
|
||||
echo -e "You are using these values:\n"
|
||||
echo -e "Dockerfile:\t\t${DOCKERFILE}"
|
||||
echo -e "index-url:\t\t${PIP_EXTRA_INDEX_URL:-none}"
|
||||
echo -e "Volumename:\t\t${VOLUMENAME}"
|
||||
echo -e "Platform:\t\t${PLATFORM}"
|
||||
echo -e "Container Registry:\t${CONTAINER_REGISTRY}"
|
||||
echo -e "Container Repository:\t${CONTAINER_REPOSITORY}"
|
||||
echo -e "Container Tag:\t\t${CONTAINER_TAG}"
|
||||
echo -e "Container Flavor:\t${CONTAINER_FLAVOR}"
|
||||
echo -e "Container Image:\t${CONTAINER_IMAGE}\n"
|
||||
|
||||
# Create docker volume
|
||||
if [[ -n "$(docker volume ls -f name="${VOLUMENAME}" -q)" ]]; then
|
||||
echo -e "Volume already exists\n"
|
||||
else
|
||||
echo -n "creating docker volume "
|
||||
docker volume create "${VOLUMENAME}"
|
||||
fi
|
||||
|
||||
# Build Container
|
||||
docker build \
|
||||
--platform="${PLATFORM:-linux/amd64}" \
|
||||
--tag="${CONTAINER_IMAGE:-invokeai}" \
|
||||
${CONTAINER_FLAVOR:+--build-arg="CONTAINER_FLAVOR=${CONTAINER_FLAVOR}"} \
|
||||
${PIP_EXTRA_INDEX_URL:+--build-arg="PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}"} \
|
||||
${PIP_PACKAGE:+--build-arg="PIP_PACKAGE=${PIP_PACKAGE}"} \
|
||||
--file="${DOCKERFILE}" \
|
||||
..
|
||||
docker-compose build $build_args
|
||||
|
48
docker/docker-compose.yml
Normal file
48
docker/docker-compose.yml
Normal file
@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2023 Eugene Brodsky https://github.com/ebr
|
||||
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
invokeai:
|
||||
image: "local/invokeai:latest"
|
||||
# edit below to run on a container runtime other than nvidia-container-runtime.
|
||||
# not yet tested with rocm/AMD GPUs
|
||||
# Comment out the "deploy" section to run on CPU only
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile
|
||||
|
||||
# variables without a default will automatically inherit from the host environment
|
||||
environment:
|
||||
- INVOKEAI_ROOT
|
||||
- HF_HOME
|
||||
|
||||
# Create a .env file in the same directory as this docker-compose.yml file
|
||||
# and populate it with environment variables. See .env.sample
|
||||
env_file:
|
||||
- .env
|
||||
|
||||
ports:
|
||||
- "${INVOKEAI_PORT:-9090}:9090"
|
||||
volumes:
|
||||
- ${INVOKEAI_ROOT:-~/invokeai}:${INVOKEAI_ROOT:-/invokeai}
|
||||
- ${HF_HOME:-~/.cache/huggingface}:${HF_HOME:-/invokeai/.cache/huggingface}
|
||||
# - ${INVOKEAI_MODELS_DIR:-${INVOKEAI_ROOT:-/invokeai/models}}
|
||||
# - ${INVOKEAI_MODELS_CONFIG_PATH:-${INVOKEAI_ROOT:-/invokeai/configs/models.yaml}}
|
||||
tty: true
|
||||
stdin_open: true
|
||||
|
||||
# # Example of running alternative commands/scripts in the container
|
||||
# command:
|
||||
# - bash
|
||||
# - -c
|
||||
# - |
|
||||
# invokeai-model-install --yes --default-only --config_file ${INVOKEAI_ROOT}/config_custom.yaml
|
||||
# invokeai-nodes-web --host 0.0.0.0
|
65
docker/docker-entrypoint.sh
Executable file
65
docker/docker-entrypoint.sh
Executable file
@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
set -e -o pipefail
|
||||
|
||||
### Container entrypoint
|
||||
# Runs the CMD as defined by the Dockerfile or passed to `docker run`
|
||||
# Can be used to configure the runtime dir
|
||||
# Bypass by using ENTRYPOINT or `--entrypoint`
|
||||
|
||||
### Set INVOKEAI_ROOT pointing to a valid runtime directory
|
||||
# Otherwise configure the runtime dir first.
|
||||
|
||||
### Configure the InvokeAI runtime directory (done by default)):
|
||||
# docker run --rm -it <this image> --configure
|
||||
# or skip with --no-configure
|
||||
|
||||
### Set the CONTAINER_UID envvar to match your user.
|
||||
# Ensures files created in the container are owned by you:
|
||||
# docker run --rm -it -v /some/path:/invokeai -e CONTAINER_UID=$(id -u) <this image>
|
||||
# Default UID: 1000 chosen due to popularity on Linux systems. Possibly 501 on MacOS.
|
||||
|
||||
USER_ID=${CONTAINER_UID:-1000}
|
||||
USER=invoke
|
||||
usermod -u ${USER_ID} ${USER} 1>/dev/null
|
||||
|
||||
configure() {
|
||||
# Configure the runtime directory
|
||||
if [[ -f ${INVOKEAI_ROOT}/invokeai.yaml ]]; then
|
||||
echo "${INVOKEAI_ROOT}/invokeai.yaml exists. InvokeAI is already configured."
|
||||
echo "To reconfigure InvokeAI, delete the above file."
|
||||
echo "======================================================================"
|
||||
else
|
||||
mkdir -p ${INVOKEAI_ROOT}
|
||||
chown --recursive ${USER} ${INVOKEAI_ROOT}
|
||||
gosu ${USER} invokeai-configure --yes --default_only
|
||||
fi
|
||||
}
|
||||
|
||||
## Skip attempting to configure.
|
||||
## Must be passed first, before any other args.
|
||||
if [[ $1 != "--no-configure" ]]; then
|
||||
configure
|
||||
else
|
||||
shift
|
||||
fi
|
||||
|
||||
### Set the $PUBLIC_KEY env var to enable SSH access.
|
||||
# We do not install openssh-server in the image by default to avoid bloat.
|
||||
# but it is useful to have the full SSH server e.g. on Runpod.
|
||||
# (use SCP to copy files to/from the image, etc)
|
||||
if [[ -v "PUBLIC_KEY" ]] && [[ ! -d "${HOME}/.ssh" ]]; then
|
||||
apt-get update
|
||||
apt-get install -y openssh-server
|
||||
pushd $HOME
|
||||
mkdir -p .ssh
|
||||
echo ${PUBLIC_KEY} > .ssh/authorized_keys
|
||||
chmod -R 700 .ssh
|
||||
popd
|
||||
service ssh start
|
||||
fi
|
||||
|
||||
|
||||
cd ${INVOKEAI_ROOT}
|
||||
|
||||
# Run the CMD as the Container User (not root).
|
||||
exec gosu ${USER} "$@"
|
@ -1,54 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This file is used to set environment variables for the build.sh and run.sh scripts.
|
||||
|
||||
# Try to detect the container flavor if no PIP_EXTRA_INDEX_URL got specified
|
||||
if [[ -z "$PIP_EXTRA_INDEX_URL" ]]; then
|
||||
|
||||
# Activate virtual environment if not already activated and exists
|
||||
if [[ -z $VIRTUAL_ENV ]]; then
|
||||
[[ -e "$(dirname "${BASH_SOURCE[0]}")/../.venv/bin/activate" ]] \
|
||||
&& source "$(dirname "${BASH_SOURCE[0]}")/../.venv/bin/activate" \
|
||||
&& echo "Activated virtual environment: $VIRTUAL_ENV"
|
||||
fi
|
||||
|
||||
# Decide which container flavor to build if not specified
|
||||
if [[ -z "$CONTAINER_FLAVOR" ]] && python -c "import torch" &>/dev/null; then
|
||||
# Check for CUDA and ROCm
|
||||
CUDA_AVAILABLE=$(python -c "import torch;print(torch.cuda.is_available())")
|
||||
ROCM_AVAILABLE=$(python -c "import torch;print(torch.version.hip is not None)")
|
||||
if [[ "${CUDA_AVAILABLE}" == "True" ]]; then
|
||||
CONTAINER_FLAVOR="cuda"
|
||||
elif [[ "${ROCM_AVAILABLE}" == "True" ]]; then
|
||||
CONTAINER_FLAVOR="rocm"
|
||||
else
|
||||
CONTAINER_FLAVOR="cpu"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Set PIP_EXTRA_INDEX_URL based on container flavor
|
||||
if [[ "$CONTAINER_FLAVOR" == "rocm" ]]; then
|
||||
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/rocm"
|
||||
elif [[ "$CONTAINER_FLAVOR" == "cpu" ]]; then
|
||||
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
# elif [[ -z "$CONTAINER_FLAVOR" || "$CONTAINER_FLAVOR" == "cuda" ]]; then
|
||||
# PIP_PACKAGE=${PIP_PACKAGE-".[xformers]"}
|
||||
fi
|
||||
fi
|
||||
|
||||
# Variables shared by build.sh and run.sh
|
||||
REPOSITORY_NAME="${REPOSITORY_NAME-$(basename "$(git rev-parse --show-toplevel)")}"
|
||||
REPOSITORY_NAME="${REPOSITORY_NAME,,}"
|
||||
VOLUMENAME="${VOLUMENAME-"${REPOSITORY_NAME}_data"}"
|
||||
ARCH="${ARCH-$(uname -m)}"
|
||||
PLATFORM="${PLATFORM-linux/${ARCH}}"
|
||||
INVOKEAI_BRANCH="${INVOKEAI_BRANCH-$(git branch --show)}"
|
||||
CONTAINER_REGISTRY="${CONTAINER_REGISTRY-"ghcr.io"}"
|
||||
CONTAINER_REPOSITORY="${CONTAINER_REPOSITORY-"$(whoami)/${REPOSITORY_NAME}"}"
|
||||
CONTAINER_FLAVOR="${CONTAINER_FLAVOR-cuda}"
|
||||
CONTAINER_TAG="${CONTAINER_TAG-"${INVOKEAI_BRANCH##*/}-${CONTAINER_FLAVOR}"}"
|
||||
CONTAINER_IMAGE="${CONTAINER_REGISTRY}/${CONTAINER_REPOSITORY}:${CONTAINER_TAG}"
|
||||
CONTAINER_IMAGE="${CONTAINER_IMAGE,,}"
|
||||
|
||||
# enable docker buildkit
|
||||
export DOCKER_BUILDKIT=1
|
@ -1,41 +1,8 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
# How to use: https://invoke-ai.github.io/InvokeAI/installation/040_INSTALL_DOCKER/
|
||||
|
||||
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
|
||||
cd "$SCRIPTDIR" || exit 1
|
||||
|
||||
source ./env.sh
|
||||
|
||||
# Create outputs directory if it does not exist
|
||||
[[ -d ./outputs ]] || mkdir ./outputs
|
||||
|
||||
echo -e "You are using these values:\n"
|
||||
echo -e "Volumename:\t${VOLUMENAME}"
|
||||
echo -e "Invokeai_tag:\t${CONTAINER_IMAGE}"
|
||||
echo -e "local Models:\t${MODELSPATH:-unset}\n"
|
||||
|
||||
docker run \
|
||||
--interactive \
|
||||
--tty \
|
||||
--rm \
|
||||
--platform="${PLATFORM}" \
|
||||
--name="${REPOSITORY_NAME}" \
|
||||
--hostname="${REPOSITORY_NAME}" \
|
||||
--mount type=volume,volume-driver=local,source="${VOLUMENAME}",target=/data \
|
||||
--mount type=bind,source="$(pwd)"/outputs/,target=/data/outputs/ \
|
||||
${MODELSPATH:+--mount="type=bind,source=${MODELSPATH},target=/data/models"} \
|
||||
${HUGGING_FACE_HUB_TOKEN:+--env="HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN}"} \
|
||||
--publish=9090:9090 \
|
||||
--cap-add=sys_nice \
|
||||
${GPU_FLAGS:+--gpus="${GPU_FLAGS}"} \
|
||||
"${CONTAINER_IMAGE}" ${@:+$@}
|
||||
|
||||
echo -e "\nCleaning trash folder ..."
|
||||
for f in outputs/.Trash*; do
|
||||
if [ -e "$f" ]; then
|
||||
rm -Rf "$f"
|
||||
break
|
||||
fi
|
||||
done
|
||||
docker-compose up --build -d
|
||||
docker-compose logs -f
|
||||
|
60
docker/runpod-readme.md
Normal file
60
docker/runpod-readme.md
Normal file
@ -0,0 +1,60 @@
|
||||
# InvokeAI - A Stable Diffusion Toolkit
|
||||
|
||||
Stable Diffusion distribution by InvokeAI: https://github.com/invoke-ai
|
||||
|
||||
The Docker image tracks the `main` branch of the InvokeAI project, which means it includes the latest features, but may contain some bugs.
|
||||
|
||||
Your working directory is mounted under the `/workspace` path inside the pod. The models are in `/workspace/invokeai/models`, and outputs are in `/workspace/invokeai/outputs`.
|
||||
|
||||
> **Only the /workspace directory will persist between pod restarts!**
|
||||
|
||||
> **If you _terminate_ (not just _stop_) the pod, the /workspace will be lost.**
|
||||
|
||||
## Quickstart
|
||||
|
||||
1. Launch a pod from this template. **It will take about 5-10 minutes to run through the initial setup**. Be patient.
|
||||
1. Wait for the application to load.
|
||||
- TIP: you know it's ready when the CPU usage goes idle
|
||||
- You can also check the logs for a line that says "_Point your browser at..._"
|
||||
1. Open the Invoke AI web UI: click the `Connect` => `connect over HTTP` button.
|
||||
1. Generate some art!
|
||||
|
||||
## Other things you can do
|
||||
|
||||
At any point you may edit the pod configuration and set an arbitrary Docker command. For example, you could run a command to downloads some models using `curl`, or fetch some images and place them into your outputs to continue a working session.
|
||||
|
||||
If you need to run *multiple commands*, define them in the Docker Command field like this:
|
||||
|
||||
`bash -c "cd ${INVOKEAI_ROOT}/outputs; wormhole receive 2-foo-bar; invoke.py --web --host 0.0.0.0"`
|
||||
|
||||
### Copying your data in and out of the pod
|
||||
|
||||
This image includes a couple of handy tools to help you get the data into the pod (such as your custom models or embeddings), and out of the pod (such as downloading your outputs). Here are your options for getting your data in and out of the pod:
|
||||
|
||||
- **SSH server**:
|
||||
1. Make sure to create and set your Public Key in the RunPod settings (follow the official instructions)
|
||||
1. Add an exposed port 22 (TCP) in the pod settings!
|
||||
1. When your pod restarts, you will see a new entry in the `Connect` dialog. Use this SSH server to `scp` or `sftp` your files as necessary, or SSH into the pod using the fully fledged SSH server.
|
||||
|
||||
- [**Magic Wormhole**](https://magic-wormhole.readthedocs.io/en/latest/welcome.html):
|
||||
1. On your computer, `pip install magic-wormhole` (see above instructions for details)
|
||||
1. Connect to the command line **using the "light" SSH client** or the browser-based console. _Currently there's a bug where `wormhole` isn't available when connected to "full" SSH server, as described above_.
|
||||
1. `wormhole send /workspace/invokeai/outputs` will send the entire `outputs` directory. You can also send individual files.
|
||||
1. Once packaged, you will see a `wormhole receive <123-some-words>` command. Copy it
|
||||
1. Paste this command into the terminal on your local machine to securely download the payload.
|
||||
1. It works the same in reverse: you can `wormhole send` some models from your computer to the pod. Again, save your files somewhere in `/workspace` or they will be lost when the pod is stopped.
|
||||
|
||||
- **RunPod's Cloud Sync feature** may be used to sync the persistent volume to cloud storage. You could, for example, copy the entire `/workspace` to S3, add some custom models to it, and copy it back from S3 when launching new pod configurations. Follow the Cloud Sync instructions.
|
||||
|
||||
|
||||
### Disable the NSFW checker
|
||||
|
||||
The NSFW checker is enabled by default. To disable it, edit the pod configuration and set the following command:
|
||||
|
||||
```
|
||||
invoke --web --host 0.0.0.0 --no-nsfw_checker
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Template ©2023 Eugene Brodsky [ebr](https://github.com/ebr)
|
@ -248,6 +248,7 @@ class InvokeAiInstance:
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"torch~=2.0.0",
|
||||
"torchmetrics==0.11.4",
|
||||
"torchvision>=0.14.1",
|
||||
"--force-reinstall",
|
||||
"--find-links" if find_links is not None else None,
|
||||
|
@ -13,7 +13,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -75,7 +74,6 @@ class ApiDependencies:
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
@ -111,7 +109,6 @@ class ApiDependencies:
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
|
@ -1,20 +1,19 @@
|
||||
import io
|
||||
from typing import Optional
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from fastapi import (Body, HTTPException, Path, Query, Request, Response,
|
||||
UploadFile)
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
ImageUrlsDTO,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.services.models.image_record import (ImageDTO,
|
||||
ImageRecordChanges,
|
||||
ImageUrlsDTO)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -103,23 +102,38 @@ async def update_image(
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
"/{image_name}",
|
||||
operation_id="get_image_dto",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def get_image_metadata(
|
||||
async def get_image_dto(
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image's metadata"""
|
||||
"""Gets an image's DTO"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
response_model=ImageMetadata,
|
||||
)
|
||||
async def get_image_metadata(
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageMetadata:
|
||||
"""Gets an image's metadata"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_metadata(image_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}",
|
||||
"/{image_name}/full",
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@ -208,10 +222,10 @@ async def get_image_urls(
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images_with_metadata",
|
||||
operation_id="list_image_dtos",
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_images_with_metadata(
|
||||
async def list_image_dtos(
|
||||
image_origin: Optional[ResourceOrigin] = Query(
|
||||
default=None, description="The origin of images to list"
|
||||
),
|
||||
@ -227,7 +241,7 @@ async def list_images_with_metadata(
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images"""
|
||||
"""Gets a list of image DTOs"""
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
offset,
|
||||
|
@ -34,7 +34,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from .services.default_graphs import (default_text_to_image_graph_id,
|
||||
@ -244,7 +243,6 @@ def invoke_cli():
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
@ -277,7 +275,6 @@ def invoke_cli():
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
|
@ -154,40 +154,42 @@ class InpaintInvocation(BaseInvocation):
|
||||
|
||||
@contextmanager
|
||||
def load_model_old_way(self, context, scheduler):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
|
||||
#unet = unet_info.context.model
|
||||
#vae = vae_info.context.model
|
||||
with vae_info as vae,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
with ExitStack() as stack:
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
device = context.services.model_manager.mgr.cache.execution_device
|
||||
dtype = context.services.model_manager.mgr.cache.precision
|
||||
|
||||
with vae_info as vae,\
|
||||
unet_info as unet,\
|
||||
ModelPatcher.apply_lora_unet(unet, loras):
|
||||
pipeline = StableDiffusionGeneratorPipeline(
|
||||
vae=vae,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if dtype == torch.float16 else "float32",
|
||||
execution_device=device,
|
||||
)
|
||||
|
||||
device = context.services.model_manager.mgr.cache.execution_device
|
||||
dtype = context.services.model_manager.mgr.cache.precision
|
||||
|
||||
pipeline = StableDiffusionGeneratorPipeline(
|
||||
vae=vae,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if dtype == torch.float16 else "float32",
|
||||
execution_device=device,
|
||||
)
|
||||
|
||||
yield OldModelInfo(
|
||||
name=self.unet.unet.model_name,
|
||||
hash="<NO-HASH>",
|
||||
model=pipeline,
|
||||
)
|
||||
yield OldModelInfo(
|
||||
name=self.unet.unet.model_name,
|
||||
hash="<NO-HASH>",
|
||||
model=pipeline,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
@ -226,21 +228,21 @@ class InpaintInvocation(BaseInvocation):
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -9,9 +9,9 @@ from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
@ -21,6 +21,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||
PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .compel import ConditioningField
|
||||
@ -449,6 +450,8 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
tiled: bool = Field(
|
||||
default=False,
|
||||
description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@ -493,7 +496,8 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
124
invokeai/app/invocations/metadata.py
Normal file
124
invokeai/app/invocations/metadata.py
Normal file
@ -0,0 +1,124 @@
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
|
||||
VAEModelField)
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModel):
|
||||
"""LoRA metadata for an image generated in InvokeAI."""
|
||||
lora: LoRAModelField = Field(description="The LoRA model")
|
||||
weight: float = Field(description="The weight of the LoRA model")
|
||||
|
||||
|
||||
class CoreMetadata(BaseModel):
|
||||
"""Core generation metadata for an image generated in InvokeAI."""
|
||||
|
||||
generation_mode: str = Field(description="The generation mode that output this image",)
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
height: int = Field(description="The height parameter")
|
||||
seed: int = Field(description="The seed used for noise generation")
|
||||
rand_device: str = Field(description="The device used for random number generation")
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
strength: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(
|
||||
default=None, description="The name of the initial image"
|
||||
)
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""An image's generation metadata"""
|
||||
|
||||
metadata: Optional[dict] = Field(
|
||||
default=None,
|
||||
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
||||
)
|
||||
graph: Optional[dict] = Field(
|
||||
default=None, description="The graph that created the image"
|
||||
)
|
||||
|
||||
|
||||
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
"""The output of the MetadataAccumulator node"""
|
||||
|
||||
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
|
||||
|
||||
metadata: CoreMetadata = Field(description="The core metadata for the image")
|
||||
|
||||
|
||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
"""Outputs a Core Metadata Object"""
|
||||
|
||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||
|
||||
generation_mode: str = Field(description="The generation mode that output this image",)
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
height: int = Field(description="The height parameter")
|
||||
seed: int = Field(description="The seed used for noise generation")
|
||||
rand_device: str = Field(description="The device used for random number generation")
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
strength: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(
|
||||
default=None, description="The name of the initial image"
|
||||
)
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
return MetadataAccumulatorOutput(
|
||||
metadata=CoreMetadata(
|
||||
generation_mode=self.generation_mode,
|
||||
positive_prompt=self.positive_prompt,
|
||||
negative_prompt=self.negative_prompt,
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
seed=self.seed,
|
||||
rand_device=self.rand_device,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
scheduler=self.scheduler,
|
||||
model=self.model,
|
||||
strength=self.strength,
|
||||
init_image=self.init_image,
|
||||
vae=self.vae,
|
||||
controlnets=self.controlnets,
|
||||
loras=self.loras,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
)
|
@ -1,93 +0,0 @@
|
||||
from typing import Optional, Union, List
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""
|
||||
Core generation metadata for an image/tensor generated in InvokeAI.
|
||||
|
||||
Also includes any metadata from the image's PNG tEXt chunks.
|
||||
|
||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
|
||||
of a given node.
|
||||
|
||||
Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
"""
|
||||
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
|
||||
won't add any fields that are not already defined, but other a different metadata service
|
||||
implementation might.
|
||||
"""
|
||||
|
||||
type: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The type of the ancestor node of the image output node.",
|
||||
)
|
||||
"""The type of the ancestor node of the image output node."""
|
||||
positive_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The positive conditioning."
|
||||
)
|
||||
"""The positive conditioning"""
|
||||
negative_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The negative conditioning."
|
||||
)
|
||||
"""The negative conditioning"""
|
||||
width: Optional[StrictInt] = Field(
|
||||
default=None, description="Width of the image/latents in pixels."
|
||||
)
|
||||
"""Width of the image/latents in pixels"""
|
||||
height: Optional[StrictInt] = Field(
|
||||
default=None, description="Height of the image/latents in pixels."
|
||||
)
|
||||
"""Height of the image/latents in pixels"""
|
||||
seed: Optional[StrictInt] = Field(
|
||||
default=None, description="The seed used for noise generation."
|
||||
)
|
||||
"""The seed used for noise generation"""
|
||||
# cfg_scale: Optional[StrictFloat] = Field(
|
||||
# cfg_scale: Union[float, list[float]] = Field(
|
||||
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
|
||||
default=None, description="The classifier-free guidance scale."
|
||||
)
|
||||
"""The classifier-free guidance scale"""
|
||||
steps: Optional[StrictInt] = Field(
|
||||
default=None, description="The number of steps used for inference."
|
||||
)
|
||||
"""The number of steps used for inference"""
|
||||
scheduler: Optional[StrictStr] = Field(
|
||||
default=None, description="The scheduler used for inference."
|
||||
)
|
||||
"""The scheduler used for inference"""
|
||||
model: Optional[StrictStr] = Field(
|
||||
default=None, description="The model used for inference."
|
||||
)
|
||||
"""The model used for inference"""
|
||||
strength: Optional[StrictFloat] = Field(
|
||||
default=None,
|
||||
description="The strength used for image-to-image/latents-to-latents.",
|
||||
)
|
||||
"""The strength used for image-to-image/latents-to-latents."""
|
||||
latents: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial latents."
|
||||
)
|
||||
"""The ID of the initial latents"""
|
||||
vae: Optional[StrictStr] = Field(
|
||||
default=None, description="The VAE used for decoding."
|
||||
)
|
||||
"""The VAE used for decoding"""
|
||||
unet: Optional[StrictStr] = Field(
|
||||
default=None, description="The UNet used dor inference."
|
||||
)
|
||||
"""The UNet used dor inference"""
|
||||
clip: Optional[StrictStr] = Field(
|
||||
default=None, description="The CLIP Encoder used for conditioning."
|
||||
)
|
||||
"""The CLIP Encoder used for conditioning"""
|
||||
extra: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
||||
)
|
||||
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""
|
@ -1,14 +1,14 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
@ -59,7 +59,8 @@ class ImageFileStorageBase(ABC):
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
graph: Optional[dict] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
@ -110,20 +111,22 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
graph: Optional[dict] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
self.__validate_storage_folders()
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text("invokeai", metadata.json())
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path, "PNG")
|
||||
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
|
||||
if graph is not None:
|
||||
pnginfo.add_text("invokeai_graph", json.dumps(graph))
|
||||
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
@ -8,7 +9,6 @@ from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord, ImageRecordChanges, deserialize_image_record)
|
||||
|
||||
@ -48,6 +48,28 @@ class ImageRecordDeleteException(Exception):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
IMAGE_DTO_COLS = ", ".join(
|
||||
list(
|
||||
map(
|
||||
lambda c: "images." + c,
|
||||
[
|
||||
"image_name",
|
||||
"image_origin",
|
||||
"image_category",
|
||||
"width",
|
||||
"height",
|
||||
"session_id",
|
||||
"node_id",
|
||||
"is_intermediate",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ImageRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the image record store."""
|
||||
|
||||
@ -58,6 +80,11 @@ class ImageRecordStorageBase(ABC):
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||
"""Gets an image's metadata'."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
@ -102,7 +129,7 @@ class ImageRecordStorageBase(ABC):
|
||||
height: int,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
@ -206,7 +233,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
@ -224,6 +251,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT images.metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
if not result or not result[0]:
|
||||
return None
|
||||
return json.loads(result[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
@ -291,8 +340,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = """--sql
|
||||
SELECT images.*
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
@ -410,12 +459,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
width: int,
|
||||
height: int,
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = (
|
||||
None if metadata is None else metadata.json(exclude_none=True)
|
||||
None if metadata is None else json.dumps(metadata)
|
||||
)
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -465,9 +514,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_most_recent_image_for_board(
|
||||
self, board_id: str
|
||||
) -> Optional[ImageRecord]:
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
|
@ -1,39 +1,30 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
InvalidImageCategoryException,
|
||||
InvalidOriginException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
image_record_to_dto,
|
||||
)
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import (ImageCategory,
|
||||
InvalidImageCategoryException,
|
||||
InvalidOriginException, ResourceOrigin)
|
||||
from invokeai.app.services.board_image_record_storage import \
|
||||
BoardImageRecordStorageBase
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.services.image_file_storage import (
|
||||
ImageFileDeleteException,
|
||||
ImageFileNotFoundException,
|
||||
ImageFileSaveException,
|
||||
ImageFileStorageBase,
|
||||
)
|
||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
ImageFileDeleteException, ImageFileNotFoundException,
|
||||
ImageFileSaveException, ImageFileStorageBase)
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException, ImageRecordNotFoundException,
|
||||
ImageRecordSaveException, ImageRecordStorageBase, OffsetPaginatedResults)
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.models.image_record import (ImageDTO, ImageRecord,
|
||||
ImageRecordChanges,
|
||||
image_record_to_dto)
|
||||
from invokeai.app.services.resource_name import NameServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.graph import GraphExecutionState
|
||||
@ -51,6 +42,7 @@ class ImageServiceABC(ABC):
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@ -79,6 +71,11 @@ class ImageServiceABC(ABC):
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||
"""Gets an image's metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's path."""
|
||||
@ -124,7 +121,6 @@ class ImageServiceDependencies:
|
||||
image_records: ImageRecordStorageBase
|
||||
image_files: ImageFileStorageBase
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
names: NameServiceBase
|
||||
@ -135,7 +131,6 @@ class ImageServiceDependencies:
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
@ -144,7 +139,6 @@ class ImageServiceDependencies:
|
||||
self.image_records = image_record_storage
|
||||
self.image_files = image_file_storage
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
self.names = names
|
||||
@ -165,6 +159,7 @@ class ImageService(ImageServiceABC):
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
@ -174,7 +169,16 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
image_name = self._services.names.create_image_name()
|
||||
|
||||
metadata = self._get_metadata(session_id, node_id)
|
||||
graph = None
|
||||
|
||||
if session_id is not None:
|
||||
session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
||||
if session_raw is not None:
|
||||
try:
|
||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||
except Exception as e:
|
||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
graph = None
|
||||
|
||||
(width, height) = image.size
|
||||
|
||||
@ -191,14 +195,12 @@ class ImageService(ImageServiceABC):
|
||||
is_intermediate=is_intermediate,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
self._services.image_files.save(
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
image_name=image_name, image=image, metadata=metadata, graph=graph
|
||||
)
|
||||
|
||||
image_dto = self.get_dto(image_name)
|
||||
@ -268,6 +270,34 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
||||
try:
|
||||
image_record = self._services.image_records.get(image_name)
|
||||
|
||||
if not image_record.session_id:
|
||||
return ImageMetadata()
|
||||
|
||||
session_raw = self._services.graph_execution_manager.get_raw(
|
||||
image_record.session_id
|
||||
)
|
||||
graph = None
|
||||
|
||||
if session_raw:
|
||||
try:
|
||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||
except Exception as e:
|
||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
graph = None
|
||||
|
||||
metadata = self._services.image_records.get_metadata(image_name)
|
||||
return ImageMetadata(graph=graph, metadata=metadata)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.image_files.get_path(image_name, thumbnail)
|
||||
@ -367,15 +397,3 @@ class ImageService(ImageServiceABC):
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image records and files")
|
||||
raise e
|
||||
|
||||
def _get_metadata(
|
||||
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
||||
) -> Optional[ImageMetadata]:
|
||||
"""Get the metadata for a node."""
|
||||
metadata = None
|
||||
|
||||
if node_id is not None and session_id is not None:
|
||||
session = self._services.graph_execution_manager.get(session_id)
|
||||
metadata = self._services.metadata.create_image_metadata(session, node_id)
|
||||
|
||||
return metadata
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, TypeVar
|
||||
from typing import Callable, Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
@ -29,14 +29,22 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def get(self, item_id: str) -> T:
|
||||
"""Gets the item, parsing it into a Pydantic model"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_raw(self, item_id: str) -> Optional[str]:
|
||||
"""Gets the raw item as a string, skipping Pydantic parsing"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, item: T) -> None:
|
||||
"""Sets the item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
"""Gets a paginated list of items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
@ -1,142 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
import networkx as nx
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
class MetadataServiceBase(ABC):
|
||||
"""Handles building metadata for nodes, images, and outputs."""
|
||||
|
||||
@abstractmethod
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""Builds an ImageMetadata object for a node."""
|
||||
pass
|
||||
|
||||
|
||||
class CoreMetadataService(MetadataServiceBase):
|
||||
_ANCESTOR_TYPES = ["t2l", "l2l"]
|
||||
"""The ancestor types that contain the core metadata"""
|
||||
|
||||
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
|
||||
"""The core metadata parameters in the ancestor types"""
|
||||
|
||||
_NOISE_FIELDS = ["seed", "width", "height"]
|
||||
"""The core metadata parameters in the noise node"""
|
||||
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
metadata = self._build_metadata_from_graph(session, node_id)
|
||||
|
||||
return metadata
|
||||
|
||||
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Optional[str]:
|
||||
"""
|
||||
Finds the id of the nearest ancestor (of a valid type) of a given node.
|
||||
|
||||
Parameters:
|
||||
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
|
||||
have the same data as the execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
|
||||
"""
|
||||
|
||||
# Retrieve the node from the graph
|
||||
node = G.nodes[node_id]
|
||||
|
||||
# If the node type is one of the core metadata node types, return its id
|
||||
if node.get("type") in self._ANCESTOR_TYPES:
|
||||
return node.get("id")
|
||||
|
||||
# Else, look for the ancestor in the predecessor nodes
|
||||
for predecessor in G.predecessors(node_id):
|
||||
result = self._find_nearest_ancestor(G, predecessor)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# If there are no valid ancestors, return None
|
||||
return None
|
||||
|
||||
def _get_additional_metadata(
|
||||
self, graph: Graph, node_id: str
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
Returns additional metadata for a given node.
|
||||
|
||||
Parameters:
|
||||
graph (Graph): The execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
dict[str, Any] | None: A dictionary of additional metadata.
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
|
||||
# Iterate over all edges in the graph
|
||||
for edge in graph.edges:
|
||||
dest_node_id = edge.destination.node_id
|
||||
dest_field = edge.destination.field
|
||||
source_node_dict = graph.nodes[edge.source.node_id].dict()
|
||||
|
||||
# If the destination node ID matches the given node ID, gather necessary metadata
|
||||
if dest_node_id == node_id:
|
||||
# Prompt
|
||||
if dest_field == "positive_conditioning":
|
||||
metadata["positive_conditioning"] = source_node_dict.get("prompt")
|
||||
# Negative prompt
|
||||
if dest_field == "negative_conditioning":
|
||||
metadata["negative_conditioning"] = source_node_dict.get("prompt")
|
||||
# Seed, width and height
|
||||
if dest_field == "noise":
|
||||
for field in self._NOISE_FIELDS:
|
||||
metadata[field] = source_node_dict.get(field)
|
||||
return metadata
|
||||
|
||||
def _build_metadata_from_graph(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""
|
||||
Builds an ImageMetadata object for a node.
|
||||
|
||||
Parameters:
|
||||
session (GraphExecutionState): The session.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
ImageMetadata: The metadata for the node.
|
||||
"""
|
||||
|
||||
# We need to do all the traversal on the execution graph
|
||||
graph = session.execution_graph
|
||||
|
||||
# Find the nearest `t2l`/`l2l` ancestor of the given node
|
||||
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
|
||||
|
||||
# If no ancestor was found, return an empty ImageMetadata object
|
||||
if ancestor_id is None:
|
||||
return ImageMetadata()
|
||||
|
||||
ancestor_node = graph.get_node(ancestor_id)
|
||||
|
||||
# Grab all the core metadata from the ancestor node
|
||||
ancestor_metadata = {
|
||||
param: val
|
||||
for param, val in ancestor_node.dict().items()
|
||||
if param in self._ANCESTOR_PARAMS
|
||||
}
|
||||
|
||||
# Get this image's prompts and noise parameters
|
||||
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
|
||||
|
||||
# If additional metadata was found, add it to the main metadata
|
||||
if addl_metadata is not None:
|
||||
ancestor_metadata.update(addl_metadata)
|
||||
|
||||
return ImageMetadata(**ancestor_metadata)
|
@ -1,13 +1,14 @@
|
||||
import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class ImageRecord(BaseModel):
|
||||
"""Deserialized image record."""
|
||||
"""Deserialized image record without metadata."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
@ -43,11 +44,6 @@ class ImageRecord(BaseModel):
|
||||
description="The node ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The node ID that generated this image, if it is a generated image."""
|
||||
metadata: Optional[ImageMetadata] = Field(
|
||||
default=None,
|
||||
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
|
||||
)
|
||||
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
||||
|
||||
|
||||
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
@ -112,6 +108,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_origin = ResourceOrigin(
|
||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||
@ -128,13 +125,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
|
||||
raw_metadata = image_dict.get("metadata")
|
||||
|
||||
if raw_metadata is not None:
|
||||
metadata = ImageMetadata.parse_raw(raw_metadata)
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
@ -143,7 +133,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
height=height,
|
||||
session_id=session_id,
|
||||
node_id=node_id,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
|
@ -1,6 +1,6 @@
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Generic, TypeVar, Optional, Union, get_args
|
||||
from typing import Generic, Optional, TypeVar, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
|
||||
@ -78,6 +78,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
return self._parse_item(result[0])
|
||||
|
||||
def get_raw(self, id: str) -> Optional[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return result[0]
|
||||
|
||||
def delete(self, id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
@ -22,4 +22,4 @@ class LocalUrlService(UrlServiceBase):
|
||||
if thumbnail:
|
||||
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/{image_basename}"
|
||||
return f"{self._base_url}/images/{image_basename}/full"
|
||||
|
55
invokeai/app/util/metadata.py
Normal file
55
invokeai/app/util/metadata.py
Normal file
@ -0,0 +1,55 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.graph import Edge
|
||||
|
||||
|
||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
"""
|
||||
Parses raw session string, returning a dict of the graph.
|
||||
|
||||
Only the general graph shape is validated; none of the fields are validated.
|
||||
|
||||
Any `metadata_accumulator` nodes and edges are removed.
|
||||
|
||||
Any validation failure will return None.
|
||||
"""
|
||||
|
||||
graph = json.loads(session_raw).get("graph", None)
|
||||
|
||||
# sanity check make sure the graph is at least reasonably shaped
|
||||
if (
|
||||
type(graph) is not dict
|
||||
or "nodes" not in graph
|
||||
or type(graph["nodes"]) is not dict
|
||||
or "edges" not in graph
|
||||
or type(graph["edges"]) is not list
|
||||
):
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
try:
|
||||
# delete the `metadata_accumulator` node
|
||||
del graph["nodes"]["metadata_accumulator"]
|
||||
except KeyError:
|
||||
# no accumulator node, all good
|
||||
pass
|
||||
|
||||
# delete any edges to or from it
|
||||
for i, edge in enumerate(graph["edges"]):
|
||||
try:
|
||||
# try to parse the edge
|
||||
Edge(**edge)
|
||||
except ValidationError:
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
if (
|
||||
edge["source"]["node_id"] == "metadata_accumulator"
|
||||
or edge["destination"]["node_id"] == "metadata_accumulator"
|
||||
):
|
||||
del graph["edges"][i]
|
||||
|
||||
return graph
|
@ -121,8 +121,8 @@ class ModelInstall(object):
|
||||
installed_models = self.mgr.list_models()
|
||||
for md in installed_models:
|
||||
base = md['base_model']
|
||||
model_type = md['type']
|
||||
name = md['name']
|
||||
model_type = md['model_type']
|
||||
name = md['model_name']
|
||||
key = ModelManager.create_key(name, base, model_type)
|
||||
if key in model_dict:
|
||||
model_dict[key].installed = True
|
||||
|
@ -250,8 +250,8 @@ from .model_cache import ModelCache, ModelLocker
|
||||
from .models import (
|
||||
BaseModelType, ModelType, SubModelType,
|
||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||
ModelConfigBase, ModelNotFoundException,
|
||||
)
|
||||
ModelConfigBase, ModelNotFoundException, InvalidModelException,
|
||||
)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
# The config file version doesn't have to start at release version, but it will help
|
||||
@ -275,10 +275,6 @@ class ModelInfo():
|
||||
def __exit__(self,*args, **kwargs):
|
||||
self.context.__exit__(*args, **kwargs)
|
||||
|
||||
class InvalidModelError(Exception):
|
||||
"Raised when an invalid model is requested"
|
||||
pass
|
||||
|
||||
class AddModelResult(BaseModel):
|
||||
name: str = Field(description="The name of the model after installation")
|
||||
model_type: ModelType = Field(description="The type of model")
|
||||
@ -542,9 +538,9 @@ class ModelManager(object):
|
||||
model_dict = dict(
|
||||
**model_config.dict(exclude_defaults=True),
|
||||
# OpenAPIModelInfoBase
|
||||
name=cur_model_name,
|
||||
model_name=cur_model_name,
|
||||
base_model=cur_base_model,
|
||||
type=cur_model_type,
|
||||
model_type=cur_model_type,
|
||||
)
|
||||
|
||||
models.append(model_dict)
|
||||
@ -817,6 +813,8 @@ class ModelManager(object):
|
||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||
self.models[model_key] = model_config
|
||||
new_models_found = True
|
||||
except InvalidModelException:
|
||||
self.logger.warning(f"Not a valid model: {model_path}")
|
||||
except NotImplementedError as e:
|
||||
self.logger.warning(e)
|
||||
|
||||
|
@ -2,7 +2,7 @@ import inspect
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException, InvalidModelException
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .vae import VaeModel
|
||||
from .lora import LoRAModel
|
||||
@ -37,9 +37,9 @@ MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
name: str
|
||||
model_name: str
|
||||
base_model: BaseModelType
|
||||
type: ModelType
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
for base_model, models in MODEL_CLASSES.items():
|
||||
@ -56,7 +56,7 @@ for base_model, models in MODEL_CLASSES.items():
|
||||
|
||||
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
||||
__annotations__ = dict(
|
||||
type=Literal[model_type.value],
|
||||
model_type=Literal[model_type.value],
|
||||
),
|
||||
))
|
||||
|
||||
|
@ -15,6 +15,9 @@ from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
class InvalidModelException(Exception):
|
||||
pass
|
||||
|
||||
class ModelNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
@ -13,6 +13,7 @@ from .base import (
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
|
||||
@ -73,10 +74,18 @@ class ControlNetModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
return ControlNetModelFormat.Diffusers
|
||||
else:
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
return ControlNetModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -9,6 +9,7 @@ from .base import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
@ -56,10 +57,18 @@ class LoRAModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
return LoRAModelFormat.Diffusers
|
||||
else:
|
||||
return LoRAModelFormat.LyCORIS
|
||||
if os.path.exists(os.path.join(path, "pytorch_lora_weights.bin")):
|
||||
return LoRAModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return LoRAModelFormat.LyCORIS
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -16,6 +16,7 @@ from .base import (
|
||||
SilenceWarnings,
|
||||
read_checkpoint_meta,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from omegaconf import OmegaConf
|
||||
@ -98,10 +99,18 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if not os.path.exists(model_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(model_path):
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
else:
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -200,10 +209,18 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if not os.path.exists(model_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(model_path):
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
else:
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -9,6 +9,7 @@ from .base import (
|
||||
SubModelType,
|
||||
classproperty,
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
@ -59,7 +60,18 @@ class TextualInversionModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
return None
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
|
||||
return None # diffusers-ti
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return None
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -15,6 +15,7 @@ from .base import (
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
@ -75,10 +76,18 @@ class VaeModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
return VaeModelFormat.Diffusers
|
||||
else:
|
||||
return VaeModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
return VaeModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return VaeModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -127,7 +127,7 @@ class AddsMaskGuidance:
|
||||
|
||||
def _t_for_field(self, field_name: str, t):
|
||||
if field_name == "pred_original_sample":
|
||||
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0
|
||||
return self.scheduler.timesteps[-1]
|
||||
return t
|
||||
|
||||
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||
|
@ -108,6 +108,7 @@
|
||||
"roarr": "^7.15.0",
|
||||
"serialize-error": "^11.0.0",
|
||||
"socket.io-client": "^4.7.0",
|
||||
"use-debounce": "^9.0.4",
|
||||
"use-image": "^1.1.1",
|
||||
"uuid": "^9.0.0",
|
||||
"zod": "^3.21.4"
|
||||
|
@ -102,7 +102,8 @@
|
||||
"openInNewTab": "Open in New Tab",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"areYouSure": "Are you sure?",
|
||||
"imagePrompt": "Image Prompt"
|
||||
"imagePrompt": "Image Prompt",
|
||||
"clearNodes": "Are you sure you want to clear all nodes?"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generations",
|
||||
@ -528,7 +529,7 @@
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview",
|
||||
"controlNetControlMode": "Control Mode",
|
||||
"clipSkip": "Clip Skip",
|
||||
"clipSkip": "CLIP Skip",
|
||||
"aspectRatio": "Ratio"
|
||||
},
|
||||
"settings": {
|
||||
@ -593,7 +594,12 @@
|
||||
"metadataLoadFailed": "Failed to load metadata",
|
||||
"initialImageSet": "Initial Image Set",
|
||||
"initialImageNotSet": "Initial Image Not Set",
|
||||
"initialImageNotSetDesc": "Could not load initial image"
|
||||
"initialImageNotSetDesc": "Could not load initial image",
|
||||
"nodesSaved": "Nodes Saved",
|
||||
"nodesLoaded": "Nodes Loaded",
|
||||
"nodesLoadedFailed": "Failed To Load Nodes",
|
||||
"nodesCleared": "Nodes Cleared"
|
||||
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@ -674,5 +680,11 @@
|
||||
"showProgressImages": "Show Progress Images",
|
||||
"hideProgressImages": "Hide Progress Images",
|
||||
"swapSizes": "Swap Sizes"
|
||||
},
|
||||
"nodes": {
|
||||
"reloadSchema": "Reload Schema",
|
||||
"saveNodes": "Save Nodes",
|
||||
"loadNodes": "Load Nodes",
|
||||
"clearNodes": "Clear Nodes"
|
||||
}
|
||||
}
|
||||
|
@ -51,6 +51,7 @@ import {
|
||||
} from './listeners/imageUrlsReceived';
|
||||
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||
import { addModelSelectedListener } from './listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
||||
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||
import {
|
||||
addReceivedPageOfImagesFulfilledListener,
|
||||
@ -224,3 +225,4 @@ addModelSelectedListener();
|
||||
|
||||
// app startup
|
||||
addAppStartedListener();
|
||||
addModelsLoadedListener();
|
||||
|
@ -1,13 +1,13 @@
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/api/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
import { isImageOutput } from 'services/api/guards';
|
||||
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { isImageOutput } from 'services/api/guards';
|
||||
import { imageDTOReceived } from 'services/api/thunks/image';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'controlNet' });
|
||||
|
||||
@ -63,10 +63,8 @@ export const addControlNetImageProcessedListener = () => {
|
||||
|
||||
// Wait for the ImageDTO to be received
|
||||
const [imageMetadataReceivedAction] = await take(
|
||||
(
|
||||
action
|
||||
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
|
||||
imageMetadataReceived.fulfilled.match(action) &&
|
||||
(action): action is ReturnType<typeof imageDTOReceived.fulfilled> =>
|
||||
imageDTOReceived.fulfilled.match(action) &&
|
||||
action.payload.image_name === image_name
|
||||
);
|
||||
const processedControlImage = imageMetadataReceivedAction.payload;
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/api/thunks/image';
|
||||
import { boardImagesApi } from 'services/api/endpoints/boardImages';
|
||||
import { imageDTOReceived } from 'services/api/thunks/image';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
@ -17,7 +17,7 @@ export const addImageAddedToBoardFulfilledListener = () => {
|
||||
);
|
||||
|
||||
dispatch(
|
||||
imageMetadataReceived({
|
||||
imageDTOReceived({
|
||||
image_name,
|
||||
})
|
||||
);
|
||||
|
@ -1,13 +1,13 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image';
|
||||
import { imageUpserted } from 'features/gallery/store/gallerySlice';
|
||||
import { imageDTOReceived, imageUpdated } from 'services/api/thunks/image';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
export const addImageMetadataReceivedFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageMetadataReceived.fulfilled,
|
||||
actionCreator: imageDTOReceived.fulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const image = action.payload;
|
||||
|
||||
@ -40,7 +40,7 @@ export const addImageMetadataReceivedFulfilledListener = () => {
|
||||
|
||||
export const addImageMetadataReceivedRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageMetadataReceived.rejected,
|
||||
actionCreator: imageDTOReceived.rejected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
moduleLog.debug(
|
||||
{ data: { image: action.meta.arg } },
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/api/thunks/image';
|
||||
import { boardImagesApi } from 'services/api/endpoints/boardImages';
|
||||
import { imageDTOReceived } from 'services/api/thunks/image';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
@ -17,7 +17,7 @@ export const addImageRemovedFromBoardFulfilledListener = () => {
|
||||
);
|
||||
|
||||
dispatch(
|
||||
imageMetadataReceived({
|
||||
imageDTOReceived({
|
||||
image_name,
|
||||
})
|
||||
);
|
||||
|
@ -14,7 +14,7 @@ export const addModelSelectedListener = () => {
|
||||
actionCreator: modelSelected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const [base_model, type, name] = action.payload.split('/');
|
||||
const { base_model, model_name } = action.payload;
|
||||
|
||||
if (state.generation.model?.base_model !== base_model) {
|
||||
dispatch(
|
||||
@ -30,11 +30,7 @@ export const addModelSelectedListener = () => {
|
||||
// TODO: controlnet cleared
|
||||
}
|
||||
|
||||
const newModel = zMainModel.parse({
|
||||
id: action.payload,
|
||||
base_model,
|
||||
name,
|
||||
});
|
||||
const newModel = zMainModel.parse(action.payload);
|
||||
|
||||
dispatch(modelChanged(newModel));
|
||||
},
|
||||
|
@ -0,0 +1,42 @@
|
||||
import { modelChanged } from 'features/parameters/store/generationSlice';
|
||||
import { some } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addModelsLoadedListener = () => {
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||
|
||||
const currentModel = getState().generation.model;
|
||||
|
||||
const isCurrentModelAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) =>
|
||||
m?.model_name === currentModel?.model_name &&
|
||||
m?.base_model === currentModel?.base_model
|
||||
);
|
||||
|
||||
if (isCurrentModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModelId = action.payload.ids[0];
|
||||
const firstModel = action.payload.entities[firstModelId];
|
||||
|
||||
if (!firstModel) {
|
||||
// No models loaded at all
|
||||
dispatch(modelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
modelChanged({
|
||||
base_model: firstModel.base_model,
|
||||
model_name: firstModel.model_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -30,6 +30,7 @@ export const addSessionCreatedRejectedListener = () => {
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (action.payload) {
|
||||
const { arg, error } = action.payload;
|
||||
const stringifiedError = JSON.stringify(error);
|
||||
moduleLog.error(
|
||||
{
|
||||
data: {
|
||||
@ -37,7 +38,7 @@ export const addSessionCreatedRejectedListener = () => {
|
||||
error: serializeError(error),
|
||||
},
|
||||
},
|
||||
`Problem creating session`
|
||||
`Problem creating session: ${stringifiedError}`
|
||||
);
|
||||
}
|
||||
},
|
||||
|
@ -33,6 +33,7 @@ export const addSessionInvokedRejectedListener = () => {
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (action.payload) {
|
||||
const { arg, error } = action.payload;
|
||||
const stringifiedError = JSON.stringify(error);
|
||||
moduleLog.error(
|
||||
{
|
||||
data: {
|
||||
@ -40,7 +41,7 @@ export const addSessionInvokedRejectedListener = () => {
|
||||
error: serializeError(error),
|
||||
},
|
||||
},
|
||||
`Problem invoking session`
|
||||
`Problem invoking session: ${stringifiedError}`
|
||||
);
|
||||
}
|
||||
},
|
||||
|
@ -1,15 +1,15 @@
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import { progressImageSet } from 'features/system/store/systemSlice';
|
||||
import { boardImagesApi } from 'services/api/endpoints/boardImages';
|
||||
import { isImageOutput } from 'services/api/guards';
|
||||
import { imageDTOReceived } from 'services/api/thunks/image';
|
||||
import { sessionCanceled } from 'services/api/thunks/session';
|
||||
import {
|
||||
appSocketInvocationComplete,
|
||||
socketInvocationComplete,
|
||||
} from 'services/events/actions';
|
||||
import { imageMetadataReceived } from 'services/api/thunks/image';
|
||||
import { sessionCanceled } from 'services/api/thunks/session';
|
||||
import { isImageOutput } from 'services/api/guards';
|
||||
import { progressImageSet } from 'features/system/store/systemSlice';
|
||||
import { boardImagesApi } from 'services/api/endpoints/boardImages';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
const nodeDenylist = ['dataURL_image'];
|
||||
@ -42,13 +42,13 @@ export const addInvocationCompleteEventListener = () => {
|
||||
|
||||
// Get its metadata
|
||||
dispatch(
|
||||
imageMetadataReceived({
|
||||
imageDTOReceived({
|
||||
image_name,
|
||||
})
|
||||
);
|
||||
|
||||
const [{ payload: imageDTO }] = await take(
|
||||
imageMetadataReceived.fulfilled.match
|
||||
imageDTOReceived.fulfilled.match
|
||||
);
|
||||
|
||||
// Handle canvas image
|
||||
|
@ -13,7 +13,7 @@ export const addInvocationErrorEventListener = () => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
moduleLog.error(
|
||||
action.payload,
|
||||
`Invocation error (${action.payload.data.node.type})`
|
||||
`Invocation error (${action.payload.data.node.type}): ${action.payload.data.error}`
|
||||
);
|
||||
dispatch(appSocketInvocationError(action.payload));
|
||||
},
|
||||
|
@ -1,6 +1,7 @@
|
||||
import {
|
||||
AnyAction,
|
||||
ThunkDispatch,
|
||||
autoBatchEnhancer,
|
||||
combineReducers,
|
||||
configureStore,
|
||||
} from '@reduxjs/toolkit';
|
||||
@ -79,14 +80,18 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
|
||||
export const store = configureStore({
|
||||
reducer: rememberedRootReducer,
|
||||
enhancers: [
|
||||
rememberEnhancer(window.localStorage, rememberedKeys, {
|
||||
persistDebounce: 300,
|
||||
serialize,
|
||||
unserialize,
|
||||
prefix: LOCALSTORAGE_PREFIX,
|
||||
}),
|
||||
],
|
||||
enhancers: (existingEnhancers) => {
|
||||
return existingEnhancers
|
||||
.concat(
|
||||
rememberEnhancer(window.localStorage, rememberedKeys, {
|
||||
persistDebounce: 300,
|
||||
serialize,
|
||||
unserialize,
|
||||
prefix: LOCALSTORAGE_PREFIX,
|
||||
})
|
||||
)
|
||||
.concat(autoBatchEnhancer());
|
||||
},
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
immutableCheck: false,
|
||||
|
@ -102,6 +102,8 @@ export type AppFeature =
|
||||
export type SDFeature =
|
||||
| 'controlNet'
|
||||
| 'noise'
|
||||
| 'perlinNoise'
|
||||
| 'noiseThreshold'
|
||||
| 'variation'
|
||||
| 'symmetry'
|
||||
| 'seamless'
|
||||
|
@ -12,6 +12,7 @@ import {
|
||||
setIsMovingBoundingBox,
|
||||
setIsTransformingBoundingBox,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import Konva from 'konva';
|
||||
import { GroupConfig } from 'konva/lib/Group';
|
||||
import { KonvaEventObject } from 'konva/lib/Node';
|
||||
@ -22,8 +23,8 @@ import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { Group, Rect, Transformer } from 'react-konva';
|
||||
|
||||
const boundingBoxPreviewSelector = createSelector(
|
||||
canvasSelector,
|
||||
(canvas) => {
|
||||
[canvasSelector, uiSelector],
|
||||
(canvas, ui) => {
|
||||
const {
|
||||
boundingBoxCoordinates,
|
||||
boundingBoxDimensions,
|
||||
@ -35,6 +36,8 @@ const boundingBoxPreviewSelector = createSelector(
|
||||
shouldSnapToGrid,
|
||||
} = canvas;
|
||||
|
||||
const { aspectRatio } = ui;
|
||||
|
||||
return {
|
||||
boundingBoxCoordinates,
|
||||
boundingBoxDimensions,
|
||||
@ -45,6 +48,7 @@ const boundingBoxPreviewSelector = createSelector(
|
||||
shouldSnapToGrid,
|
||||
tool,
|
||||
hitStrokeWidth: 20 / stageScale,
|
||||
aspectRatio,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -70,6 +74,7 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
||||
shouldSnapToGrid,
|
||||
tool,
|
||||
hitStrokeWidth,
|
||||
aspectRatio,
|
||||
} = useAppSelector(boundingBoxPreviewSelector);
|
||||
|
||||
const transformerRef = useRef<Konva.Transformer>(null);
|
||||
@ -137,12 +142,22 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
||||
const x = Math.round(rect.x());
|
||||
const y = Math.round(rect.y());
|
||||
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width,
|
||||
height,
|
||||
})
|
||||
);
|
||||
if (aspectRatio) {
|
||||
const newHeight = roundToMultiple(width / aspectRatio, 64);
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: width,
|
||||
height: newHeight,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width,
|
||||
height,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(
|
||||
setBoundingBoxCoordinates({
|
||||
@ -154,7 +169,7 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
||||
// Reset the scale now that the coords/dimensions have been un-scaled
|
||||
rect.scaleX(1);
|
||||
rect.scaleY(1);
|
||||
}, [dispatch, shouldSnapToGrid]);
|
||||
}, [dispatch, shouldSnapToGrid, aspectRatio]);
|
||||
|
||||
const anchorDragBoundFunc = useCallback(
|
||||
(
|
||||
|
@ -7,7 +7,14 @@ import {
|
||||
import { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { clamp, cloneDeep } from 'lodash-es';
|
||||
//
|
||||
import {
|
||||
setActiveTab,
|
||||
setAspectRatio,
|
||||
setShouldUseCanvasBetaLayout,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { sessionCanceled } from 'services/api/thunks/session';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import calculateCoordinates from '../util/calculateCoordinates';
|
||||
import calculateScale from '../util/calculateScale';
|
||||
import { STAGE_PADDING_PERCENTAGE } from '../util/constants';
|
||||
@ -28,13 +35,6 @@ import {
|
||||
isCanvasBaseImage,
|
||||
isCanvasMaskLine,
|
||||
} from './canvasTypes';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { sessionCanceled } from 'services/api/thunks/session';
|
||||
import {
|
||||
setActiveTab,
|
||||
setShouldUseCanvasBetaLayout,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { imageUrlsReceived } from 'services/api/thunks/image';
|
||||
|
||||
export const initialLayerState: CanvasLayerState = {
|
||||
objects: [],
|
||||
@ -240,6 +240,16 @@ export const canvasSlice = createSlice({
|
||||
state.scaledBoundingBoxDimensions = scaledDimensions;
|
||||
}
|
||||
},
|
||||
flipBoundingBoxAxes: (state) => {
|
||||
const [currWidth, currHeight] = [
|
||||
state.boundingBoxDimensions.width,
|
||||
state.boundingBoxDimensions.height,
|
||||
];
|
||||
state.boundingBoxDimensions = {
|
||||
width: currHeight,
|
||||
height: currWidth,
|
||||
};
|
||||
},
|
||||
setBoundingBoxCoordinates: (state, action: PayloadAction<Vector2d>) => {
|
||||
state.boundingBoxCoordinates = floorCoordinates(action.payload);
|
||||
},
|
||||
@ -864,6 +874,15 @@ export const canvasSlice = createSlice({
|
||||
builder.addCase(setActiveTab, (state, action) => {
|
||||
state.doesCanvasNeedScaling = true;
|
||||
});
|
||||
builder.addCase(setAspectRatio, (state, action) => {
|
||||
const ratio = action.payload;
|
||||
if (ratio) {
|
||||
state.boundingBoxDimensions.height = roundToMultiple(
|
||||
state.boundingBoxDimensions.width / ratio,
|
||||
64
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
@ -912,6 +931,7 @@ export const {
|
||||
setBoundingBoxDimensions,
|
||||
setBoundingBoxPreviewFill,
|
||||
setBoundingBoxScaleMethod,
|
||||
flipBoundingBoxAxes,
|
||||
setBrushColor,
|
||||
setBrushSize,
|
||||
setCanvasContainerDimensions,
|
||||
|
@ -47,8 +47,8 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
const disabled = currentMainModel?.base_model !== embedding.base_model;
|
||||
|
||||
data.push({
|
||||
value: embedding.name,
|
||||
label: embedding.name,
|
||||
value: embedding.model_name,
|
||||
label: embedding.model_name,
|
||||
group: MODEL_TYPE_MAP[embedding.base_model],
|
||||
disabled,
|
||||
tooltip: disabled
|
||||
|
@ -45,7 +45,11 @@ import {
|
||||
FaShare,
|
||||
FaShareAlt,
|
||||
} from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import {
|
||||
useGetImageDTOQuery,
|
||||
useGetImageMetadataQuery,
|
||||
} from 'services/api/endpoints/images';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
|
||||
|
||||
const currentImageButtonsSelector = createSelector(
|
||||
@ -128,10 +132,23 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||
useRecallParameters();
|
||||
|
||||
const { currentData: image } = useGetImageDTOQuery(
|
||||
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
|
||||
lastSelectedImage,
|
||||
500
|
||||
);
|
||||
|
||||
const { currentData: image, isFetching } = useGetImageDTOQuery(
|
||||
lastSelectedImage ?? skipToken
|
||||
);
|
||||
|
||||
const { currentData: metadataData } = useGetImageMetadataQuery(
|
||||
debounceState.isPending()
|
||||
? skipToken
|
||||
: debouncedMetadataQueryArg ?? skipToken
|
||||
);
|
||||
|
||||
const metadata = metadataData?.metadata;
|
||||
|
||||
// const handleCopyImage = useCallback(async () => {
|
||||
// if (!image?.url) {
|
||||
// return;
|
||||
@ -193,29 +210,26 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}, [toaster, t, image]);
|
||||
|
||||
const handleClickUseAllParameters = useCallback(() => {
|
||||
recallAllParameters(image);
|
||||
}, [image, recallAllParameters]);
|
||||
recallAllParameters(metadata);
|
||||
}, [metadata, recallAllParameters]);
|
||||
|
||||
useHotkeys(
|
||||
'a',
|
||||
() => {
|
||||
handleClickUseAllParameters;
|
||||
},
|
||||
[image, recallAllParameters]
|
||||
[metadata, recallAllParameters]
|
||||
);
|
||||
|
||||
const handleUseSeed = useCallback(() => {
|
||||
recallSeed(image?.metadata?.seed);
|
||||
}, [image, recallSeed]);
|
||||
recallSeed(metadata?.seed);
|
||||
}, [metadata?.seed, recallSeed]);
|
||||
|
||||
useHotkeys('s', handleUseSeed, [image]);
|
||||
|
||||
const handleUsePrompt = useCallback(() => {
|
||||
recallBothPrompts(
|
||||
image?.metadata?.positive_conditioning,
|
||||
image?.metadata?.negative_conditioning
|
||||
);
|
||||
}, [image, recallBothPrompts]);
|
||||
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
|
||||
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
|
||||
|
||||
useHotkeys('p', handleUsePrompt, [image]);
|
||||
|
||||
@ -440,7 +454,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
icon={<FaQuoteRight />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||
isDisabled={!image?.metadata?.positive_conditioning}
|
||||
isDisabled={!metadata?.positive_prompt}
|
||||
onClick={handleUsePrompt}
|
||||
/>
|
||||
|
||||
@ -448,7 +462,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
icon={<FaSeedling />}
|
||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||
isDisabled={!image?.metadata?.seed}
|
||||
isDisabled={!metadata?.seed}
|
||||
onClick={handleUseSeed}
|
||||
/>
|
||||
|
||||
@ -456,10 +470,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
icon={<FaAsterisk />}
|
||||
tooltip={`${t('parameters.useAll')} (A)`}
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
isDisabled={
|
||||
// not sure what this list should be
|
||||
!['t2l', 'l2l', 'inpaint'].includes(String(image?.metadata?.type))
|
||||
}
|
||||
isDisabled={!metadata}
|
||||
onClick={handleClickUseAllParameters}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
@ -11,7 +11,9 @@ import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { useNextPrevImage } from '../hooks/useNextPrevImage';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
|
||||
@ -49,6 +51,45 @@ const CurrentImagePreview = () => {
|
||||
shouldAntialiasProgressImage,
|
||||
} = useAppSelector(imagesSelector);
|
||||
|
||||
const {
|
||||
handlePrevImage,
|
||||
handleNextImage,
|
||||
prevImageId,
|
||||
nextImageId,
|
||||
isOnLastImage,
|
||||
handleLoadMoreImages,
|
||||
areMoreImagesAvailable,
|
||||
isFetching,
|
||||
} = useNextPrevImage();
|
||||
|
||||
useHotkeys(
|
||||
'left',
|
||||
() => {
|
||||
handlePrevImage();
|
||||
},
|
||||
[prevImageId]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'right',
|
||||
() => {
|
||||
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
|
||||
handleLoadMoreImages();
|
||||
return;
|
||||
}
|
||||
if (!isOnLastImage) {
|
||||
handleNextImage();
|
||||
}
|
||||
},
|
||||
[
|
||||
nextImageId,
|
||||
isOnLastImage,
|
||||
areMoreImagesAvailable,
|
||||
handleLoadMoreImages,
|
||||
isFetching,
|
||||
]
|
||||
);
|
||||
|
||||
const {
|
||||
currentData: imageDTO,
|
||||
isLoading,
|
||||
@ -118,7 +159,6 @@ const CurrentImagePreview = () => {
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
borderRadius: 'base',
|
||||
overflow: 'scroll',
|
||||
}}
|
||||
>
|
||||
<ImageMetadataViewer image={imageDTO} />
|
||||
|
@ -6,10 +6,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
|
||||
import {
|
||||
imagesAddedToBatch,
|
||||
selectionAddedToBatch,
|
||||
} from 'features/batch/store/batchSlice';
|
||||
import { imagesAddedToBatch } from 'features/batch/store/batchSlice';
|
||||
import {
|
||||
resizeAndScaleCanvas,
|
||||
setInitialCanvasImage,
|
||||
@ -24,6 +21,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import { FaExpand, FaFolder, FaShare, FaTrash } from 'react-icons/fa';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
|
||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
|
||||
@ -38,24 +36,17 @@ const ImageContextMenu = ({ image, children }: Props) => {
|
||||
() =>
|
||||
createSelector(
|
||||
[stateSelector],
|
||||
({ gallery, batch }) => {
|
||||
({ gallery }) => {
|
||||
const selectionCount = gallery.selection.length;
|
||||
const isInBatch = batch.imageNames.includes(image.image_name);
|
||||
|
||||
return { selectionCount, isInBatch };
|
||||
return { selectionCount };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[image.image_name]
|
||||
[]
|
||||
);
|
||||
const { selectionCount, isInBatch } = useAppSelector(selector);
|
||||
const { selectionCount } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const toaster = useAppToaster();
|
||||
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
|
||||
|
||||
@ -66,178 +57,17 @@ const ImageContextMenu = ({ image, children }: Props) => {
|
||||
dispatch(imageToDeleteSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||
useRecallParameters();
|
||||
|
||||
const [removeFromBoard] = useRemoveImageFromBoardMutation();
|
||||
|
||||
// Recall parameters handlers
|
||||
const handleRecallPrompt = useCallback(() => {
|
||||
recallBothPrompts(
|
||||
image.metadata?.positive_conditioning,
|
||||
image.metadata?.negative_conditioning
|
||||
);
|
||||
}, [
|
||||
image.metadata?.negative_conditioning,
|
||||
image.metadata?.positive_conditioning,
|
||||
recallBothPrompts,
|
||||
]);
|
||||
|
||||
const handleRecallSeed = useCallback(() => {
|
||||
recallSeed(image.metadata?.seed);
|
||||
}, [image, recallSeed]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
dispatch(sentImageToImg2Img());
|
||||
dispatch(initialImageSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
// const handleRecallInitialImage = useCallback(() => {
|
||||
// recallInitialImage(image.metadata.invokeai?.node?.image);
|
||||
// }, [image, recallInitialImage]);
|
||||
|
||||
const handleSendToCanvas = () => {
|
||||
dispatch(sentImageToCanvas());
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
dispatch(setActiveTab('unifiedCanvas'));
|
||||
|
||||
toaster({
|
||||
title: t('toast.sentToUnifiedCanvas'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
};
|
||||
|
||||
const handleUseAllParameters = useCallback(() => {
|
||||
recallAllParameters(image);
|
||||
}, [image, recallAllParameters]);
|
||||
|
||||
const handleLightBox = () => {
|
||||
// dispatch(setCurrentImage(image));
|
||||
// dispatch(setIsLightboxOpen(true));
|
||||
};
|
||||
|
||||
const handleAddToBoard = useCallback(() => {
|
||||
onClickAddToBoard(image);
|
||||
}, [image, onClickAddToBoard]);
|
||||
|
||||
const handleRemoveFromBoard = useCallback(() => {
|
||||
if (!image.board_id) {
|
||||
return;
|
||||
}
|
||||
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
|
||||
}, [image.board_id, image.image_name, removeFromBoard]);
|
||||
|
||||
const handleOpenInNewTab = () => {
|
||||
window.open(image.image_url, '_blank');
|
||||
};
|
||||
|
||||
const handleAddSelectionToBatch = useCallback(() => {
|
||||
dispatch(selectionAddedToBatch());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleAddToBatch = useCallback(() => {
|
||||
dispatch(imagesAddedToBatch([image.image_name]));
|
||||
}, [dispatch, image.image_name]);
|
||||
|
||||
return (
|
||||
<ContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
renderMenu={() => (
|
||||
<MenuList sx={{ visibility: 'visible !important' }}>
|
||||
{selectionCount === 1 ? (
|
||||
<>
|
||||
<MenuItem
|
||||
icon={<ExternalLinkIcon />}
|
||||
onClickCapture={handleOpenInNewTab}
|
||||
>
|
||||
{t('common.openInNewTab')}
|
||||
</MenuItem>
|
||||
{isLightboxEnabled && (
|
||||
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
|
||||
{t('parameters.openInViewer')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallPrompt}
|
||||
isDisabled={
|
||||
image?.metadata?.positive_conditioning === undefined
|
||||
}
|
||||
>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallSeed}
|
||||
isDisabled={image?.metadata?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
{/* <MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallInitialImage}
|
||||
isDisabled={image?.metadata?.type !== 'img2img'}
|
||||
>
|
||||
{t('parameters.useInitImg')}
|
||||
</MenuItem> */}
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={
|
||||
// what should these be
|
||||
!['t2l', 'l2l', 'inpaint'].includes(
|
||||
String(image?.metadata?.type)
|
||||
)
|
||||
}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<FaShare />}
|
||||
onClickCapture={handleSendToImageToImage}
|
||||
id="send-to-img2img"
|
||||
>
|
||||
{t('parameters.sendToImg2Img')}
|
||||
</MenuItem>
|
||||
{isCanvasEnabled && (
|
||||
<MenuItem
|
||||
icon={<FaShare />}
|
||||
onClickCapture={handleSendToCanvas}
|
||||
id="send-to-canvas"
|
||||
>
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
)}
|
||||
{/* <MenuItem
|
||||
icon={<FaFolder />}
|
||||
isDisabled={isInBatch}
|
||||
onClickCapture={handleAddToBatch}
|
||||
>
|
||||
Add to Batch
|
||||
</MenuItem> */}
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
|
||||
{image.board_id ? 'Change Board' : 'Add to Board'}
|
||||
</MenuItem>
|
||||
{image.board_id && (
|
||||
<MenuItem
|
||||
icon={<FaFolder />}
|
||||
onClickCapture={handleRemoveFromBoard}
|
||||
>
|
||||
Remove from Board
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
|
||||
icon={<FaTrash />}
|
||||
onClickCapture={handleDelete}
|
||||
>
|
||||
{t('gallery.deleteImage')}
|
||||
</MenuItem>
|
||||
</>
|
||||
<SingleSelectionMenuItems image={image} />
|
||||
) : (
|
||||
<>
|
||||
<MenuItem
|
||||
@ -271,3 +101,185 @@ const ImageContextMenu = ({ image, children }: Props) => {
|
||||
};
|
||||
|
||||
export default memo(ImageContextMenu);
|
||||
|
||||
type SingleSelectionMenuItemsProps = {
|
||||
image: ImageDTO;
|
||||
};
|
||||
|
||||
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const { image } = props;
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
[stateSelector],
|
||||
({ batch }) => {
|
||||
const isInBatch = batch.imageNames.includes(image.image_name);
|
||||
|
||||
return { isInBatch };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[image.image_name]
|
||||
);
|
||||
|
||||
const { isInBatch } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const toaster = useAppToaster();
|
||||
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
|
||||
|
||||
const { currentData } = useGetImageMetadataQuery(image.image_name);
|
||||
|
||||
const metadata = currentData?.metadata;
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
if (!image) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageToDeleteSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||
useRecallParameters();
|
||||
|
||||
const [removeFromBoard] = useRemoveImageFromBoardMutation();
|
||||
|
||||
// Recall parameters handlers
|
||||
const handleRecallPrompt = useCallback(() => {
|
||||
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
|
||||
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
|
||||
|
||||
const handleRecallSeed = useCallback(() => {
|
||||
recallSeed(metadata?.seed);
|
||||
}, [metadata?.seed, recallSeed]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
dispatch(sentImageToImg2Img());
|
||||
dispatch(initialImageSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
const handleSendToCanvas = () => {
|
||||
dispatch(sentImageToCanvas());
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
dispatch(setActiveTab('unifiedCanvas'));
|
||||
|
||||
toaster({
|
||||
title: t('toast.sentToUnifiedCanvas'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
};
|
||||
|
||||
const handleUseAllParameters = useCallback(() => {
|
||||
console.log(metadata);
|
||||
recallAllParameters(metadata);
|
||||
}, [metadata, recallAllParameters]);
|
||||
|
||||
const handleLightBox = () => {
|
||||
// dispatch(setCurrentImage(image));
|
||||
// dispatch(setIsLightboxOpen(true));
|
||||
};
|
||||
|
||||
const handleAddToBoard = useCallback(() => {
|
||||
onClickAddToBoard(image);
|
||||
}, [image, onClickAddToBoard]);
|
||||
|
||||
const handleRemoveFromBoard = useCallback(() => {
|
||||
if (!image.board_id) {
|
||||
return;
|
||||
}
|
||||
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
|
||||
}, [image.board_id, image.image_name, removeFromBoard]);
|
||||
|
||||
const handleOpenInNewTab = () => {
|
||||
window.open(image.image_url, '_blank');
|
||||
};
|
||||
|
||||
const handleAddToBatch = useCallback(() => {
|
||||
dispatch(imagesAddedToBatch([image.image_name]));
|
||||
}, [dispatch, image.image_name]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MenuItem icon={<ExternalLinkIcon />} onClickCapture={handleOpenInNewTab}>
|
||||
{t('common.openInNewTab')}
|
||||
</MenuItem>
|
||||
{isLightboxEnabled && (
|
||||
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
|
||||
{t('parameters.openInViewer')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallPrompt}
|
||||
isDisabled={
|
||||
metadata?.positive_prompt === undefined &&
|
||||
metadata?.negative_prompt === undefined
|
||||
}
|
||||
>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallSeed}
|
||||
isDisabled={metadata?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={!metadata}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<FaShare />}
|
||||
onClickCapture={handleSendToImageToImage}
|
||||
id="send-to-img2img"
|
||||
>
|
||||
{t('parameters.sendToImg2Img')}
|
||||
</MenuItem>
|
||||
{isCanvasEnabled && (
|
||||
<MenuItem
|
||||
icon={<FaShare />}
|
||||
onClickCapture={handleSendToCanvas}
|
||||
id="send-to-canvas"
|
||||
>
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
icon={<FaFolder />}
|
||||
isDisabled={isInBatch}
|
||||
onClickCapture={handleAddToBatch}
|
||||
>
|
||||
Add to Batch
|
||||
</MenuItem>
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
|
||||
{image.board_id ? 'Change Board' : 'Add to Board'}
|
||||
</MenuItem>
|
||||
{image.board_id && (
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleRemoveFromBoard}>
|
||||
Remove from Board
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
|
||||
icon={<FaTrash />}
|
||||
onClickCapture={handleDelete}
|
||||
>
|
||||
{t('gallery.deleteImage')}
|
||||
</MenuItem>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,212 @@
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { useCallback } from 'react';
|
||||
import { UnsafeImageMetadata } from 'services/api/endpoints/images';
|
||||
import MetadataItem from './MetadataItem';
|
||||
|
||||
type Props = {
|
||||
metadata?: UnsafeImageMetadata['metadata'];
|
||||
};
|
||||
|
||||
const ImageMetadataActions = (props: Props) => {
|
||||
const { metadata } = props;
|
||||
|
||||
const {
|
||||
recallBothPrompts,
|
||||
recallPositivePrompt,
|
||||
recallNegativePrompt,
|
||||
recallSeed,
|
||||
recallInitialImage,
|
||||
recallCfgScale,
|
||||
recallModel,
|
||||
recallScheduler,
|
||||
recallSteps,
|
||||
recallWidth,
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallAllParameters,
|
||||
} = useRecallParameters();
|
||||
|
||||
const handleRecallPositivePrompt = useCallback(() => {
|
||||
recallPositivePrompt(metadata?.positive_prompt);
|
||||
}, [metadata?.positive_prompt, recallPositivePrompt]);
|
||||
|
||||
const handleRecallNegativePrompt = useCallback(() => {
|
||||
recallNegativePrompt(metadata?.negative_prompt);
|
||||
}, [metadata?.negative_prompt, recallNegativePrompt]);
|
||||
|
||||
const handleRecallSeed = useCallback(() => {
|
||||
recallSeed(metadata?.seed);
|
||||
}, [metadata?.seed, recallSeed]);
|
||||
|
||||
const handleRecallModel = useCallback(() => {
|
||||
recallModel(metadata?.model);
|
||||
}, [metadata?.model, recallModel]);
|
||||
|
||||
const handleRecallWidth = useCallback(() => {
|
||||
recallWidth(metadata?.width);
|
||||
}, [metadata?.width, recallWidth]);
|
||||
|
||||
const handleRecallHeight = useCallback(() => {
|
||||
recallHeight(metadata?.height);
|
||||
}, [metadata?.height, recallHeight]);
|
||||
|
||||
const handleRecallScheduler = useCallback(() => {
|
||||
recallScheduler(metadata?.scheduler);
|
||||
}, [metadata?.scheduler, recallScheduler]);
|
||||
|
||||
const handleRecallSteps = useCallback(() => {
|
||||
recallSteps(metadata?.steps);
|
||||
}, [metadata?.steps, recallSteps]);
|
||||
|
||||
const handleRecallCfgScale = useCallback(() => {
|
||||
recallCfgScale(metadata?.cfg_scale);
|
||||
}, [metadata?.cfg_scale, recallCfgScale]);
|
||||
|
||||
const handleRecallStrength = useCallback(() => {
|
||||
recallStrength(metadata?.strength);
|
||||
}, [metadata?.strength, recallStrength]);
|
||||
|
||||
if (!metadata || Object.keys(metadata).length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{metadata.generation_mode && (
|
||||
<MetadataItem
|
||||
label="Generation Mode"
|
||||
value={metadata.generation_mode}
|
||||
/>
|
||||
)}
|
||||
{metadata.positive_prompt && (
|
||||
<MetadataItem
|
||||
label="Positive Prompt"
|
||||
labelPosition="top"
|
||||
value={metadata.positive_prompt}
|
||||
onClick={handleRecallPositivePrompt}
|
||||
/>
|
||||
)}
|
||||
{metadata.negative_prompt && (
|
||||
<MetadataItem
|
||||
label="Negative Prompt"
|
||||
labelPosition="top"
|
||||
value={metadata.negative_prompt}
|
||||
onClick={handleRecallNegativePrompt}
|
||||
/>
|
||||
)}
|
||||
{metadata.seed !== undefined && (
|
||||
<MetadataItem
|
||||
label="Seed"
|
||||
value={metadata.seed}
|
||||
onClick={handleRecallSeed}
|
||||
/>
|
||||
)}
|
||||
{metadata.model !== undefined && (
|
||||
<MetadataItem
|
||||
label="Model"
|
||||
value={metadata.model.model_name}
|
||||
onClick={handleRecallModel}
|
||||
/>
|
||||
)}
|
||||
{metadata.width && (
|
||||
<MetadataItem
|
||||
label="Width"
|
||||
value={metadata.width}
|
||||
onClick={handleRecallWidth}
|
||||
/>
|
||||
)}
|
||||
{metadata.height && (
|
||||
<MetadataItem
|
||||
label="Height"
|
||||
value={metadata.height}
|
||||
onClick={handleRecallHeight}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.threshold !== undefined && (
|
||||
<MetadataItem
|
||||
label="Noise Threshold"
|
||||
value={metadata.threshold}
|
||||
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
|
||||
/>
|
||||
)}
|
||||
{metadata.perlin !== undefined && (
|
||||
<MetadataItem
|
||||
label="Perlin Noise"
|
||||
value={metadata.perlin}
|
||||
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.scheduler && (
|
||||
<MetadataItem
|
||||
label="Scheduler"
|
||||
value={metadata.scheduler}
|
||||
onClick={handleRecallScheduler}
|
||||
/>
|
||||
)}
|
||||
{metadata.steps && (
|
||||
<MetadataItem
|
||||
label="Steps"
|
||||
value={metadata.steps}
|
||||
onClick={handleRecallSteps}
|
||||
/>
|
||||
)}
|
||||
{metadata.cfg_scale !== undefined && (
|
||||
<MetadataItem
|
||||
label="CFG scale"
|
||||
value={metadata.cfg_scale}
|
||||
onClick={handleRecallCfgScale}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.variations && metadata.variations.length > 0 && (
|
||||
<MetadataItem
|
||||
label="Seed-weight pairs"
|
||||
value={seedWeightsToString(metadata.variations)}
|
||||
onClick={() =>
|
||||
dispatch(
|
||||
setSeedWeights(seedWeightsToString(metadata.variations))
|
||||
)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{metadata.seamless && (
|
||||
<MetadataItem
|
||||
label="Seamless"
|
||||
value={metadata.seamless}
|
||||
onClick={() => dispatch(setSeamless(metadata.seamless))}
|
||||
/>
|
||||
)}
|
||||
{metadata.hires_fix && (
|
||||
<MetadataItem
|
||||
label="High Resolution Optimization"
|
||||
value={metadata.hires_fix}
|
||||
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
|
||||
/>
|
||||
)} */}
|
||||
|
||||
{/* {init_image_path && (
|
||||
<MetadataItem
|
||||
label="Initial image"
|
||||
value={init_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.strength && (
|
||||
<MetadataItem
|
||||
label="Image to image strength"
|
||||
value={metadata.strength}
|
||||
onClick={handleRecallStrength}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.fit && (
|
||||
<MetadataItem
|
||||
label="Image to image fit"
|
||||
value={metadata.fit}
|
||||
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
|
||||
/>
|
||||
)} */}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageMetadataActions;
|
@ -1,131 +1,74 @@
|
||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||
import {
|
||||
Box,
|
||||
Center,
|
||||
Flex,
|
||||
IconButton,
|
||||
Link,
|
||||
Tab,
|
||||
TabList,
|
||||
TabPanel,
|
||||
TabPanels,
|
||||
Tabs,
|
||||
Text,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaCopy } from 'react-icons/fa';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
type MetadataItemProps = {
|
||||
isLink?: boolean;
|
||||
label: string;
|
||||
onClick?: () => void;
|
||||
value: number | string | boolean;
|
||||
labelPosition?: string;
|
||||
withCopy?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Component to display an individual metadata item or parameter.
|
||||
*/
|
||||
const MetadataItem = ({
|
||||
label,
|
||||
value,
|
||||
onClick,
|
||||
isLink,
|
||||
labelPosition,
|
||||
withCopy = false,
|
||||
}: MetadataItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!value) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
{onClick && (
|
||||
<Tooltip label={`Recall ${label}`}>
|
||||
<IconButton
|
||||
aria-label={t('accessibility.useThisParameter')}
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={20}
|
||||
onClick={onClick}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
{withCopy && (
|
||||
<Tooltip label={`Copy ${label}`}>
|
||||
<IconButton
|
||||
aria-label={`Copy ${label}`}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(value.toString())}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
<Flex direction={labelPosition ? 'column' : 'row'}>
|
||||
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
|
||||
{label}:
|
||||
</Text>
|
||||
{isLink ? (
|
||||
<Link href={value.toString()} isExternal wordBreak="break-all">
|
||||
{value.toString()} <ExternalLinkIcon mx="2px" />
|
||||
</Link>
|
||||
) : (
|
||||
<Text overflowY="scroll" wordBreak="break-all">
|
||||
{value.toString()}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import ImageMetadataActions from './ImageMetadataActions';
|
||||
import MetadataJSONViewer from './MetadataJSONViewer';
|
||||
|
||||
type ImageMetadataViewerProps = {
|
||||
image: ImageDTO;
|
||||
};
|
||||
|
||||
/**
|
||||
* Image metadata viewer overlays currently selected image and provides
|
||||
* access to any of its metadata for use in processing.
|
||||
*/
|
||||
const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
recallBothPrompts,
|
||||
recallPositivePrompt,
|
||||
recallNegativePrompt,
|
||||
recallSeed,
|
||||
recallInitialImage,
|
||||
recallCfgScale,
|
||||
recallModel,
|
||||
recallScheduler,
|
||||
recallSteps,
|
||||
recallWidth,
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallAllParameters,
|
||||
} = useRecallParameters();
|
||||
// TODO: fix hotkeys
|
||||
// const dispatch = useAppDispatch();
|
||||
// useHotkeys('esc', () => {
|
||||
// dispatch(setShouldShowImageDetails(false));
|
||||
// });
|
||||
|
||||
useHotkeys('esc', () => {
|
||||
dispatch(setShouldShowImageDetails(false));
|
||||
});
|
||||
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
|
||||
image.image_name,
|
||||
500
|
||||
);
|
||||
|
||||
const sessionId = image?.session_id;
|
||||
const { currentData } = useGetImageMetadataQuery(
|
||||
debounceState.isPending()
|
||||
? skipToken
|
||||
: debouncedMetadataQueryArg ?? skipToken
|
||||
);
|
||||
const metadata = currentData?.metadata;
|
||||
const graph = currentData?.graph;
|
||||
|
||||
const metadata = image?.metadata;
|
||||
const tabData = useMemo(() => {
|
||||
const _tabData: { label: string; data: object; copyTooltip: string }[] = [];
|
||||
|
||||
const { t } = useTranslation();
|
||||
if (metadata) {
|
||||
_tabData.push({
|
||||
label: 'Core Metadata',
|
||||
data: metadata,
|
||||
copyTooltip: 'Copy Core Metadata JSON',
|
||||
});
|
||||
}
|
||||
|
||||
const metadataJSON = JSON.stringify(image, null, 2);
|
||||
if (image) {
|
||||
_tabData.push({
|
||||
label: 'Image Details',
|
||||
data: image,
|
||||
copyTooltip: 'Copy Image Details JSON',
|
||||
});
|
||||
}
|
||||
|
||||
if (graph) {
|
||||
_tabData.push({
|
||||
label: 'Graph',
|
||||
data: graph,
|
||||
copyTooltip: 'Copy Graph JSON',
|
||||
});
|
||||
}
|
||||
return _tabData;
|
||||
}, [metadata, graph, image]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -136,11 +79,13 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
backdropFilter: 'blur(20px)',
|
||||
bg: 'whiteAlpha.600',
|
||||
bg: 'baseAlpha.200',
|
||||
_dark: {
|
||||
bg: 'blackAlpha.600',
|
||||
},
|
||||
overflow: 'scroll',
|
||||
borderRadius: 'base',
|
||||
position: 'absolute',
|
||||
overflow: 'hidden',
|
||||
}}
|
||||
>
|
||||
<Flex gap={2}>
|
||||
@ -150,179 +95,42 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
<ExternalLinkIcon mx="2px" />
|
||||
</Link>
|
||||
</Flex>
|
||||
{metadata && Object.keys(metadata).length > 0 ? (
|
||||
<>
|
||||
{metadata.type && (
|
||||
<MetadataItem label="Invocation type" value={metadata.type} />
|
||||
)}
|
||||
{sessionId && <MetadataItem label="Session ID" value={sessionId} />}
|
||||
{metadata.positive_conditioning && (
|
||||
<MetadataItem
|
||||
label="Positive Prompt"
|
||||
labelPosition="top"
|
||||
value={metadata.positive_conditioning}
|
||||
onClick={() =>
|
||||
recallPositivePrompt(metadata.positive_conditioning)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{metadata.negative_conditioning && (
|
||||
<MetadataItem
|
||||
label="Negative Prompt"
|
||||
labelPosition="top"
|
||||
value={metadata.negative_conditioning}
|
||||
onClick={() =>
|
||||
recallNegativePrompt(metadata.negative_conditioning)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{metadata.seed !== undefined && (
|
||||
<MetadataItem
|
||||
label="Seed"
|
||||
value={metadata.seed}
|
||||
onClick={() => recallSeed(metadata.seed)}
|
||||
/>
|
||||
)}
|
||||
{metadata.model !== undefined && (
|
||||
<MetadataItem
|
||||
label="Model"
|
||||
value={metadata.model}
|
||||
onClick={() => recallModel(metadata.model)}
|
||||
/>
|
||||
)}
|
||||
{metadata.width && (
|
||||
<MetadataItem
|
||||
label="Width"
|
||||
value={metadata.width}
|
||||
onClick={() => recallWidth(metadata.width)}
|
||||
/>
|
||||
)}
|
||||
{metadata.height && (
|
||||
<MetadataItem
|
||||
label="Height"
|
||||
value={metadata.height}
|
||||
onClick={() => recallHeight(metadata.height)}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.threshold !== undefined && (
|
||||
<MetadataItem
|
||||
label="Noise Threshold"
|
||||
value={metadata.threshold}
|
||||
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
|
||||
/>
|
||||
)}
|
||||
{metadata.perlin !== undefined && (
|
||||
<MetadataItem
|
||||
label="Perlin Noise"
|
||||
value={metadata.perlin}
|
||||
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.scheduler && (
|
||||
<MetadataItem
|
||||
label="Scheduler"
|
||||
value={metadata.scheduler}
|
||||
onClick={() => recallScheduler(metadata.scheduler)}
|
||||
/>
|
||||
)}
|
||||
{metadata.steps && (
|
||||
<MetadataItem
|
||||
label="Steps"
|
||||
value={metadata.steps}
|
||||
onClick={() => recallSteps(metadata.steps)}
|
||||
/>
|
||||
)}
|
||||
{metadata.cfg_scale !== undefined && (
|
||||
<MetadataItem
|
||||
label="CFG scale"
|
||||
value={metadata.cfg_scale}
|
||||
onClick={() => recallCfgScale(metadata.cfg_scale)}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.variations && metadata.variations.length > 0 && (
|
||||
<MetadataItem
|
||||
label="Seed-weight pairs"
|
||||
value={seedWeightsToString(metadata.variations)}
|
||||
onClick={() =>
|
||||
dispatch(
|
||||
setSeedWeights(seedWeightsToString(metadata.variations))
|
||||
)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{metadata.seamless && (
|
||||
<MetadataItem
|
||||
label="Seamless"
|
||||
value={metadata.seamless}
|
||||
onClick={() => dispatch(setSeamless(metadata.seamless))}
|
||||
/>
|
||||
)}
|
||||
{metadata.hires_fix && (
|
||||
<MetadataItem
|
||||
label="High Resolution Optimization"
|
||||
value={metadata.hires_fix}
|
||||
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
|
||||
/>
|
||||
)} */}
|
||||
|
||||
{/* {init_image_path && (
|
||||
<MetadataItem
|
||||
label="Initial image"
|
||||
value={init_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.strength && (
|
||||
<MetadataItem
|
||||
label="Image to image strength"
|
||||
value={metadata.strength}
|
||||
onClick={() => recallStrength(metadata.strength)}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.fit && (
|
||||
<MetadataItem
|
||||
label="Image to image fit"
|
||||
value={metadata.fit}
|
||||
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
|
||||
/>
|
||||
)} */}
|
||||
</>
|
||||
) : (
|
||||
<Center width="100%" pt={10}>
|
||||
<Text fontSize="lg" fontWeight="semibold">
|
||||
No metadata available
|
||||
</Text>
|
||||
</Center>
|
||||
)}
|
||||
<Flex gap={2} direction="column" overflow="auto">
|
||||
<Flex gap={2}>
|
||||
<Tooltip label="Copy metadata JSON">
|
||||
<IconButton
|
||||
aria-label={t('accessibility.copyMetadataJson')}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||
</Flex>
|
||||
<OverlayScrollbarsComponent defer>
|
||||
<Box
|
||||
sx={{
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
bg: 'whiteAlpha.500',
|
||||
_dark: { bg: 'blackAlpha.500' },
|
||||
w: 'full',
|
||||
}}
|
||||
>
|
||||
<pre>{metadataJSON}</pre>
|
||||
</Box>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Flex>
|
||||
<ImageMetadataActions metadata={metadata} />
|
||||
|
||||
<Tabs
|
||||
variant="line"
|
||||
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
||||
>
|
||||
<TabList>
|
||||
{tabData.map((tab) => (
|
||||
<Tab
|
||||
key={tab.label}
|
||||
sx={{
|
||||
borderTopRadius: 'base',
|
||||
}}
|
||||
>
|
||||
<Text sx={{ color: 'base.700', _dark: { color: 'base.300' } }}>
|
||||
{tab.label}
|
||||
</Text>
|
||||
</Tab>
|
||||
))}
|
||||
</TabList>
|
||||
|
||||
<TabPanels sx={{ w: 'full', h: 'full' }}>
|
||||
{tabData.map((tab) => (
|
||||
<TabPanel
|
||||
key={tab.label}
|
||||
sx={{ w: 'full', h: 'full', p: 0, pt: 4 }}
|
||||
>
|
||||
<MetadataJSONViewer
|
||||
jsonObject={tab.data}
|
||||
copyTooltip={tab.copyTooltip}
|
||||
/>
|
||||
</TabPanel>
|
||||
))}
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,77 @@
|
||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||
import { Flex, IconButton, Link, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaCopy } from 'react-icons/fa';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
|
||||
type MetadataItemProps = {
|
||||
isLink?: boolean;
|
||||
label: string;
|
||||
onClick?: () => void;
|
||||
value: number | string | boolean;
|
||||
labelPosition?: string;
|
||||
withCopy?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Component to display an individual metadata item or parameter.
|
||||
*/
|
||||
const MetadataItem = ({
|
||||
label,
|
||||
value,
|
||||
onClick,
|
||||
isLink,
|
||||
labelPosition,
|
||||
withCopy = false,
|
||||
}: MetadataItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!value) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
{onClick && (
|
||||
<Tooltip label={`Recall ${label}`}>
|
||||
<IconButton
|
||||
aria-label={t('accessibility.useThisParameter')}
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={20}
|
||||
onClick={onClick}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
{withCopy && (
|
||||
<Tooltip label={`Copy ${label}`}>
|
||||
<IconButton
|
||||
aria-label={`Copy ${label}`}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(value.toString())}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
<Flex direction={labelPosition ? 'column' : 'row'}>
|
||||
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
|
||||
{label}:
|
||||
</Text>
|
||||
{isLink ? (
|
||||
<Link href={value.toString()} isExternal wordBreak="break-all">
|
||||
{value.toString()} <ExternalLinkIcon mx="2px" />
|
||||
</Link>
|
||||
) : (
|
||||
<Text overflowY="scroll" wordBreak="break-all">
|
||||
{value.toString()}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default MetadataItem;
|
@ -0,0 +1,70 @@
|
||||
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { useMemo } from 'react';
|
||||
import { FaCopy } from 'react-icons/fa';
|
||||
|
||||
type Props = {
|
||||
copyTooltip: string;
|
||||
jsonObject: object;
|
||||
};
|
||||
|
||||
const MetadataJSONViewer = (props: Props) => {
|
||||
const { copyTooltip, jsonObject } = props;
|
||||
const jsonString = useMemo(
|
||||
() => JSON.stringify(jsonObject, null, 2),
|
||||
[jsonObject]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
borderRadius: 'base',
|
||||
bg: 'whiteAlpha.500',
|
||||
_dark: { bg: 'blackAlpha.500' },
|
||||
flexGrow: 1,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
overflow: 'auto',
|
||||
p: 4,
|
||||
}}
|
||||
>
|
||||
<OverlayScrollbarsComponent
|
||||
defer
|
||||
style={{ height: '100%', width: '100%' }}
|
||||
options={{
|
||||
scrollbars: {
|
||||
visibility: 'auto',
|
||||
autoHide: 'move',
|
||||
autoHideDelay: 1300,
|
||||
theme: 'os-theme-dark',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<pre>{jsonString}</pre>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Box>
|
||||
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
|
||||
<Tooltip label={copyTooltip}>
|
||||
<IconButton
|
||||
aria-label={copyTooltip}
|
||||
icon={<FaCopy />}
|
||||
variant="ghost"
|
||||
onClick={() => navigator.clipboard.writeText(jsonString)}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default MetadataJSONViewer;
|
@ -1,18 +1,8 @@
|
||||
import { ChakraProps, Flex, Grid, IconButton, Spinner } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
imageSelected,
|
||||
selectFilteredImages,
|
||||
selectImagesById,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { clamp, isEqual } from 'lodash-es';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaAngleDoubleRight, FaAngleLeft, FaAngleRight } from 'react-icons/fa';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import { useNextPrevImage } from '../hooks/useNextPrevImage';
|
||||
|
||||
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
|
||||
height: '100%',
|
||||
@ -24,74 +14,18 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
|
||||
color: 'base.100',
|
||||
};
|
||||
|
||||
export const nextPrevImageButtonsSelector = createSelector(
|
||||
[stateSelector, selectFilteredImages],
|
||||
(state, filteredImages) => {
|
||||
const { total, isFetching } = state.gallery;
|
||||
const lastSelectedImage =
|
||||
state.gallery.selection[state.gallery.selection.length - 1];
|
||||
|
||||
if (!lastSelectedImage || filteredImages.length === 0) {
|
||||
return {
|
||||
isOnFirstImage: true,
|
||||
isOnLastImage: true,
|
||||
};
|
||||
}
|
||||
|
||||
const currentImageIndex = filteredImages.findIndex(
|
||||
(i) => i.image_name === lastSelectedImage
|
||||
);
|
||||
const nextImageIndex = clamp(
|
||||
currentImageIndex + 1,
|
||||
0,
|
||||
filteredImages.length - 1
|
||||
);
|
||||
|
||||
const prevImageIndex = clamp(
|
||||
currentImageIndex - 1,
|
||||
0,
|
||||
filteredImages.length - 1
|
||||
);
|
||||
|
||||
const nextImageId = filteredImages[nextImageIndex].image_name;
|
||||
const prevImageId = filteredImages[prevImageIndex].image_name;
|
||||
|
||||
const nextImage = selectImagesById(state, nextImageId);
|
||||
const prevImage = selectImagesById(state, prevImageId);
|
||||
|
||||
const imagesLength = filteredImages.length;
|
||||
|
||||
return {
|
||||
isOnFirstImage: currentImageIndex === 0,
|
||||
isOnLastImage:
|
||||
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
|
||||
areMoreImagesAvailable: total > imagesLength,
|
||||
isFetching,
|
||||
nextImage,
|
||||
prevImage,
|
||||
nextImageId,
|
||||
prevImageId,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const NextPrevImageButtons = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
handlePrevImage,
|
||||
handleNextImage,
|
||||
isOnFirstImage,
|
||||
isOnLastImage,
|
||||
nextImageId,
|
||||
prevImageId,
|
||||
handleLoadMoreImages,
|
||||
areMoreImagesAvailable,
|
||||
isFetching,
|
||||
} = useAppSelector(nextPrevImageButtonsSelector);
|
||||
} = useNextPrevImage();
|
||||
|
||||
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
|
||||
useState<boolean>(false);
|
||||
@ -104,50 +38,6 @@ const NextPrevImageButtons = () => {
|
||||
setShouldShowNextPrevButtons(false);
|
||||
}, []);
|
||||
|
||||
const handlePrevImage = useCallback(() => {
|
||||
prevImageId && dispatch(imageSelected(prevImageId));
|
||||
}, [dispatch, prevImageId]);
|
||||
|
||||
const handleNextImage = useCallback(() => {
|
||||
nextImageId && dispatch(imageSelected(nextImageId));
|
||||
}, [dispatch, nextImageId]);
|
||||
|
||||
const handleLoadMoreImages = useCallback(() => {
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}, [dispatch]);
|
||||
|
||||
useHotkeys(
|
||||
'left',
|
||||
() => {
|
||||
handlePrevImage();
|
||||
},
|
||||
[prevImageId]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'right',
|
||||
() => {
|
||||
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
|
||||
handleLoadMoreImages();
|
||||
return;
|
||||
}
|
||||
if (!isOnLastImage) {
|
||||
handleNextImage();
|
||||
}
|
||||
},
|
||||
[
|
||||
nextImageId,
|
||||
isOnLastImage,
|
||||
areMoreImagesAvailable,
|
||||
handleLoadMoreImages,
|
||||
isFetching,
|
||||
]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
|
@ -0,0 +1,108 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
imageSelected,
|
||||
selectFilteredImages,
|
||||
selectImagesById,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { clamp, isEqual } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
|
||||
export const nextPrevImageButtonsSelector = createSelector(
|
||||
[stateSelector, selectFilteredImages],
|
||||
(state, filteredImages) => {
|
||||
const { total, isFetching } = state.gallery;
|
||||
const lastSelectedImage =
|
||||
state.gallery.selection[state.gallery.selection.length - 1];
|
||||
|
||||
if (!lastSelectedImage || filteredImages.length === 0) {
|
||||
return {
|
||||
isOnFirstImage: true,
|
||||
isOnLastImage: true,
|
||||
};
|
||||
}
|
||||
|
||||
const currentImageIndex = filteredImages.findIndex(
|
||||
(i) => i.image_name === lastSelectedImage
|
||||
);
|
||||
const nextImageIndex = clamp(
|
||||
currentImageIndex + 1,
|
||||
0,
|
||||
filteredImages.length - 1
|
||||
);
|
||||
|
||||
const prevImageIndex = clamp(
|
||||
currentImageIndex - 1,
|
||||
0,
|
||||
filteredImages.length - 1
|
||||
);
|
||||
|
||||
const nextImageId = filteredImages[nextImageIndex].image_name;
|
||||
const prevImageId = filteredImages[prevImageIndex].image_name;
|
||||
|
||||
const nextImage = selectImagesById(state, nextImageId);
|
||||
const prevImage = selectImagesById(state, prevImageId);
|
||||
|
||||
const imagesLength = filteredImages.length;
|
||||
|
||||
return {
|
||||
isOnFirstImage: currentImageIndex === 0,
|
||||
isOnLastImage:
|
||||
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
|
||||
areMoreImagesAvailable: total > imagesLength,
|
||||
isFetching,
|
||||
nextImage,
|
||||
prevImage,
|
||||
nextImageId,
|
||||
prevImageId,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const useNextPrevImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const {
|
||||
isOnFirstImage,
|
||||
isOnLastImage,
|
||||
nextImageId,
|
||||
prevImageId,
|
||||
areMoreImagesAvailable,
|
||||
isFetching,
|
||||
} = useAppSelector(nextPrevImageButtonsSelector);
|
||||
|
||||
const handlePrevImage = useCallback(() => {
|
||||
prevImageId && dispatch(imageSelected(prevImageId));
|
||||
}, [dispatch, prevImageId]);
|
||||
|
||||
const handleNextImage = useCallback(() => {
|
||||
nextImageId && dispatch(imageSelected(nextImageId));
|
||||
}, [dispatch, nextImageId]);
|
||||
|
||||
const handleLoadMoreImages = useCallback(() => {
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}, [dispatch]);
|
||||
|
||||
return {
|
||||
handlePrevImage,
|
||||
handleNextImage,
|
||||
isOnFirstImage,
|
||||
isOnLastImage,
|
||||
nextImageId,
|
||||
prevImageId,
|
||||
areMoreImagesAvailable,
|
||||
handleLoadMoreImages,
|
||||
isFetching,
|
||||
};
|
||||
};
|
@ -37,7 +37,7 @@ const ParamLora = (props: Props) => {
|
||||
return (
|
||||
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
|
||||
<IAISlider
|
||||
label={lora.name}
|
||||
label={lora.model_name}
|
||||
value={lora.weight}
|
||||
onChange={handleChange}
|
||||
min={-1}
|
||||
|
@ -18,7 +18,7 @@ const selector = createSelector(
|
||||
const ParamLoraList = () => {
|
||||
const { loras } = useAppSelector(selector);
|
||||
|
||||
return map(loras, (lora) => <ParamLora key={lora.name} lora={lora} />);
|
||||
return map(loras, (lora) => <ParamLora key={lora.model_name} lora={lora} />);
|
||||
};
|
||||
|
||||
export default ParamLoraList;
|
||||
|
@ -45,7 +45,7 @@ const ParamLoraSelect = () => {
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: lora.name,
|
||||
label: lora.model_name,
|
||||
disabled,
|
||||
group: MODEL_TYPE_MAP[lora.base_model],
|
||||
tooltip: disabled
|
||||
|
@ -1,12 +1,8 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import { BaseModelType } from 'services/api/types';
|
||||
|
||||
export type Lora = {
|
||||
id: string;
|
||||
base_model: BaseModelType;
|
||||
name: string;
|
||||
export type Lora = LoRAModelParam & {
|
||||
weight: number;
|
||||
};
|
||||
|
||||
@ -27,8 +23,8 @@ export const loraSlice = createSlice({
|
||||
initialState: intialLoraState,
|
||||
reducers: {
|
||||
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
||||
const { name, id, base_model } = action.payload;
|
||||
state.loras[id] = { id, name, base_model, ...defaultLoRAConfig };
|
||||
const { model_name, id, base_model } = action.payload;
|
||||
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
|
@ -7,6 +7,7 @@ import {
|
||||
OnConnectEnd,
|
||||
OnConnectStart,
|
||||
OnEdgesChange,
|
||||
OnInit,
|
||||
OnNodesChange,
|
||||
ReactFlow,
|
||||
} from 'reactflow';
|
||||
@ -16,6 +17,7 @@ import {
|
||||
connectionStarted,
|
||||
edgesChanged,
|
||||
nodesChanged,
|
||||
setEditorInstance,
|
||||
} from '../store/nodesSlice';
|
||||
import { InvocationComponent } from './InvocationComponent';
|
||||
import ProgressImageNode from './ProgressImageNode';
|
||||
@ -67,6 +69,14 @@ export const Flow = () => {
|
||||
dispatch(connectionEnded());
|
||||
}, [dispatch]);
|
||||
|
||||
const onInit: OnInit = useCallback(
|
||||
(v) => {
|
||||
dispatch(setEditorInstance(v));
|
||||
if (v) v.fitView();
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
nodeTypes={nodeTypes}
|
||||
@ -77,6 +87,7 @@ export const Flow = () => {
|
||||
onConnectStart={onConnectStart}
|
||||
onConnect={onConnect}
|
||||
onConnectEnd={onConnectEnd}
|
||||
onInit={onInit}
|
||||
defaultEdgeOptions={{
|
||||
style: { strokeWidth: 2 },
|
||||
}}
|
||||
|
@ -45,7 +45,7 @@ const LoRAModelInputFieldComponent = (
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
label: model.model_name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
@ -38,7 +38,7 @@ const ModelInputFieldComponent = (
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
label: model.model_name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
@ -45,7 +45,7 @@ const VaeModelInputFieldComponent = (
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
label: model.model_name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
@ -1,25 +1,23 @@
|
||||
import { HStack } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import NodeInvokeButton from '../ui/NodeInvokeButton';
|
||||
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import LoadNodesButton from '../ui/LoadNodesButton';
|
||||
import NodeInvokeButton from '../ui/NodeInvokeButton';
|
||||
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
|
||||
import SaveNodesButton from '../ui/SaveNodesButton';
|
||||
import ClearNodesButton from '../ui/ClearNodesButton';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleReloadSchema = useCallback(() => {
|
||||
dispatch(receivedOpenAPISchema());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Panel position="top-center">
|
||||
<HStack>
|
||||
<NodeInvokeButton />
|
||||
<CancelButton />
|
||||
<IAIButton onClick={handleReloadSchema}>Reload Schema</IAIButton>
|
||||
<ReloadSchemaButton />
|
||||
<SaveNodesButton />
|
||||
<LoadNodesButton />
|
||||
<ClearNodesButton />
|
||||
</HStack>
|
||||
</Panel>
|
||||
);
|
||||
|
@ -0,0 +1,86 @@
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogContent,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogOverlay,
|
||||
Button,
|
||||
Text,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { memo, useRef, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
|
||||
const ClearNodesButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const cancelRef = useRef<HTMLButtonElement | null>(null);
|
||||
|
||||
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
|
||||
|
||||
const handleConfirmClear = useCallback(() => {
|
||||
dispatch(nodeEditorReset());
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.nodesCleared'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
onClose();
|
||||
}, [dispatch, t, onClose]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<IAIIconButton
|
||||
icon={<FaTrash />}
|
||||
tooltip={t('nodes.clearNodes')}
|
||||
aria-label={t('nodes.clearNodes')}
|
||||
onClick={onOpen}
|
||||
isDisabled={nodes.length === 0}
|
||||
/>
|
||||
|
||||
<AlertDialog
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
leastDestructiveRef={cancelRef}
|
||||
isCentered
|
||||
>
|
||||
<AlertDialogOverlay />
|
||||
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
{t('nodes.clearNodes')}
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
<Text>{t('common.clearNodes')}</Text>
|
||||
</AlertDialogBody>
|
||||
|
||||
<AlertDialogFooter>
|
||||
<Button ref={cancelRef} onClick={onClose}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button colorScheme="red" ml={3} onClick={handleConfirmClear}>
|
||||
{t('common.accept')}
|
||||
</Button>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ClearNodesButton);
|
@ -0,0 +1,79 @@
|
||||
import { FileButton } from '@mantine/core';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaUpload } from 'react-icons/fa';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
|
||||
const LoadNodesButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { fitView } = useReactFlow();
|
||||
|
||||
const uploadedFileRef = useRef<() => void>(null);
|
||||
|
||||
const restoreJSONToEditor = useCallback(
|
||||
(v: File | null) => {
|
||||
if (!v) return;
|
||||
const reader = new FileReader();
|
||||
reader.onload = async () => {
|
||||
const json = reader.result;
|
||||
const retrievedNodeTree = await JSON.parse(String(json));
|
||||
|
||||
if (!retrievedNodeTree) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.nodesLoadedFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (retrievedNodeTree) {
|
||||
dispatch(loadFileNodes(retrievedNodeTree.nodes));
|
||||
dispatch(loadFileEdges(retrievedNodeTree.edges));
|
||||
fitView();
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({ title: t('toast.nodesLoaded'), status: 'success' })
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
reader.abort();
|
||||
};
|
||||
|
||||
reader.readAsText(v);
|
||||
|
||||
// Cleanup
|
||||
uploadedFileRef.current?.();
|
||||
},
|
||||
[fitView, dispatch, t]
|
||||
);
|
||||
return (
|
||||
<FileButton
|
||||
resetRef={uploadedFileRef}
|
||||
accept="application/json"
|
||||
onChange={restoreJSONToEditor}
|
||||
>
|
||||
{(props) => (
|
||||
<IAIIconButton
|
||||
icon={<FaUpload />}
|
||||
tooltip={t('nodes.loadNodes')}
|
||||
aria-label={t('nodes.loadNodes')}
|
||||
{...props}
|
||||
/>
|
||||
)}
|
||||
</FileButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoadNodesButton);
|
@ -0,0 +1,24 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSyncAlt } from 'react-icons/fa';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
|
||||
export default function ReloadSchemaButton() {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleReloadSchema = useCallback(() => {
|
||||
dispatch(receivedOpenAPISchema());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IAIIconButton
|
||||
icon={<FaSyncAlt />}
|
||||
tooltip={t('nodes.reloadSchema')}
|
||||
aria-label={t('nodes.reloadSchema')}
|
||||
onClick={handleReloadSchema}
|
||||
/>
|
||||
);
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { map, omit } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSave } from 'react-icons/fa';
|
||||
|
||||
const SaveNodesButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const editorInstance = useAppSelector(
|
||||
(state: RootState) => state.nodes.editorInstance
|
||||
);
|
||||
|
||||
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
|
||||
|
||||
const saveEditorToJSON = useCallback(() => {
|
||||
if (editorInstance) {
|
||||
const editorState = editorInstance.toObject();
|
||||
|
||||
editorState.edges = map(editorState.edges, (edge) => {
|
||||
return omit(edge, ['style']);
|
||||
});
|
||||
|
||||
const nodeSetupJSON = new Blob([JSON.stringify(editorState)]);
|
||||
const nodeDownloadElement = document.createElement('a');
|
||||
nodeDownloadElement.href = URL.createObjectURL(nodeSetupJSON);
|
||||
nodeDownloadElement.download = 'MyNodes.json';
|
||||
document.body.appendChild(nodeDownloadElement);
|
||||
nodeDownloadElement.click();
|
||||
// Cleanup
|
||||
nodeDownloadElement.remove();
|
||||
}
|
||||
}, [editorInstance]);
|
||||
|
||||
return (
|
||||
<IAIIconButton
|
||||
icon={<FaSave />}
|
||||
fontSize={18}
|
||||
tooltip={t('nodes.saveNodes')}
|
||||
aria-label={t('nodes.saveNodes')}
|
||||
onClick={saveEditorToJSON}
|
||||
isDisabled={nodes.length === 0}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SaveNodesButton);
|
@ -13,6 +13,7 @@ import {
|
||||
Node,
|
||||
NodeChange,
|
||||
OnConnectStartParams,
|
||||
ReactFlowInstance,
|
||||
} from 'reactflow';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { ImageField } from 'services/api/types';
|
||||
@ -25,6 +26,7 @@ export type NodesState = {
|
||||
invocationTemplates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
shouldShowGraphOverlay: boolean;
|
||||
editorInstance: ReactFlowInstance | undefined;
|
||||
};
|
||||
|
||||
export const initialNodesState: NodesState = {
|
||||
@ -34,6 +36,7 @@ export const initialNodesState: NodesState = {
|
||||
invocationTemplates: {},
|
||||
connectionStartParams: null,
|
||||
shouldShowGraphOverlay: false,
|
||||
editorInstance: undefined,
|
||||
};
|
||||
|
||||
const nodesSlice = createSlice({
|
||||
@ -118,8 +121,18 @@ const nodesSlice = createSlice({
|
||||
) => {
|
||||
state.invocationTemplates = action.payload;
|
||||
},
|
||||
nodeEditorReset: () => {
|
||||
return { ...initialNodesState };
|
||||
nodeEditorReset: (state) => {
|
||||
state.nodes = [];
|
||||
state.edges = [];
|
||||
},
|
||||
setEditorInstance: (state, action) => {
|
||||
state.editorInstance = action.payload;
|
||||
},
|
||||
loadFileNodes: (state, action: PayloadAction<Node<InvocationValue>[]>) => {
|
||||
state.nodes = action.payload;
|
||||
},
|
||||
loadFileEdges: (state, action: PayloadAction<Edge[]>) => {
|
||||
state.edges = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
@ -141,6 +154,9 @@ export const {
|
||||
nodeTemplatesBuilt,
|
||||
nodeEditorReset,
|
||||
imageCollectionFieldValueChanged,
|
||||
setEditorInstance,
|
||||
loadFileNodes,
|
||||
loadFileEdges,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
|
@ -1,94 +0,0 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
||||
import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
|
||||
import { NonNullableGraph } from '../types/types';
|
||||
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
|
||||
|
||||
export const addControlNetToLinearGraph = (
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string,
|
||||
state: RootState
|
||||
): void => {
|
||||
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
||||
|
||||
const validControlNets = getValidControlNets(controlNets);
|
||||
|
||||
if (isControlNetEnabled && Boolean(validControlNets.length)) {
|
||||
if (validControlNets.length > 1) {
|
||||
// We have multiple controlnets, add ControlNet collector
|
||||
const controlNetIterateNode: CollectInvocation = {
|
||||
id: CONTROL_NET_COLLECT,
|
||||
type: 'collect',
|
||||
};
|
||||
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetIterateNode.id, field: 'collection' },
|
||||
destination: {
|
||||
node_id: baseNodeId,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
validControlNets.forEach((controlNet) => {
|
||||
const {
|
||||
controlNetId,
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
controlMode,
|
||||
model,
|
||||
processorType,
|
||||
weight,
|
||||
} = controlNet;
|
||||
|
||||
const controlNetNode: ControlNetInvocation = {
|
||||
id: `control_net_${controlNetId}`,
|
||||
type: 'controlnet',
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
control_mode: controlMode,
|
||||
control_model: model as ControlNetInvocation['control_model'],
|
||||
control_weight: weight,
|
||||
};
|
||||
|
||||
if (processedControlImage && processorType !== 'none') {
|
||||
// We've already processed the image in the app, so we can just use the processed image
|
||||
controlNetNode.image = {
|
||||
image_name: processedControlImage,
|
||||
};
|
||||
} else if (controlImage) {
|
||||
// The control image is preprocessed
|
||||
controlNetNode.image = {
|
||||
image_name: controlImage,
|
||||
};
|
||||
} else {
|
||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
||||
return;
|
||||
}
|
||||
|
||||
graph.nodes[controlNetNode.id] = controlNetNode;
|
||||
|
||||
if (validControlNets.length > 1) {
|
||||
// if we have multiple controlnets, link to the collector
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: CONTROL_NET_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// otherwise, link directly to the base node
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: baseNodeId,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
@ -1,40 +0,0 @@
|
||||
import {
|
||||
Edge,
|
||||
ImageToImageInvocation,
|
||||
InpaintInvocation,
|
||||
IterateInvocation,
|
||||
RandomRangeInvocation,
|
||||
RangeInvocation,
|
||||
TextToImageInvocation,
|
||||
} from 'services/api/types';
|
||||
|
||||
export const buildEdges = (
|
||||
baseNode: TextToImageInvocation | ImageToImageInvocation | InpaintInvocation,
|
||||
rangeNode: RangeInvocation | RandomRangeInvocation,
|
||||
iterateNode: IterateInvocation
|
||||
): Edge[] => {
|
||||
const edges: Edge[] = [
|
||||
{
|
||||
source: {
|
||||
node_id: rangeNode.id,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: iterateNode.id,
|
||||
field: 'collection',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: iterateNode.id,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: baseNode.id,
|
||||
field: 'seed',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
return edges;
|
||||
};
|
@ -0,0 +1,102 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
||||
import { omit } from 'lodash-es';
|
||||
import {
|
||||
CollectInvocation,
|
||||
ControlField,
|
||||
ControlNetInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import { CONTROL_NET_COLLECT, METADATA_ACCUMULATOR } from './constants';
|
||||
|
||||
export const addControlNetToLinearGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): void => {
|
||||
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
||||
|
||||
const validControlNets = getValidControlNets(controlNets);
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (isControlNetEnabled && Boolean(validControlNets.length)) {
|
||||
if (validControlNets.length) {
|
||||
// We have multiple controlnets, add ControlNet collector
|
||||
const controlNetIterateNode: CollectInvocation = {
|
||||
id: CONTROL_NET_COLLECT,
|
||||
type: 'collect',
|
||||
};
|
||||
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
|
||||
destination: {
|
||||
node_id: baseNodeId,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
|
||||
validControlNets.forEach((controlNet) => {
|
||||
const {
|
||||
controlNetId,
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
controlMode,
|
||||
model,
|
||||
processorType,
|
||||
weight,
|
||||
} = controlNet;
|
||||
|
||||
const controlNetNode: ControlNetInvocation = {
|
||||
id: `control_net_${controlNetId}`,
|
||||
type: 'controlnet',
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
control_mode: controlMode,
|
||||
control_model: model as ControlNetInvocation['control_model'],
|
||||
control_weight: weight,
|
||||
};
|
||||
|
||||
if (processedControlImage && processorType !== 'none') {
|
||||
// We've already processed the image in the app, so we can just use the processed image
|
||||
controlNetNode.image = {
|
||||
image_name: processedControlImage,
|
||||
};
|
||||
} else if (controlImage) {
|
||||
// The control image is preprocessed
|
||||
controlNetNode.image = {
|
||||
image_name: controlImage,
|
||||
};
|
||||
} else {
|
||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
||||
return;
|
||||
}
|
||||
|
||||
graph.nodes[controlNetNode.id] = controlNetNode;
|
||||
|
||||
if (metadataAccumulator) {
|
||||
// metadata accumulator only needs a control field - not the whole node
|
||||
// extract what we need and add to the accumulator
|
||||
const controlField = omit(controlNetNode, [
|
||||
'id',
|
||||
'type',
|
||||
]) as ControlField;
|
||||
metadataAccumulator.controlnets.push(controlField);
|
||||
}
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: CONTROL_NET_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
@ -1,8 +1,10 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { unset } from 'lodash-es';
|
||||
import {
|
||||
DynamicPromptInvocation,
|
||||
IterateInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
NoiseInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
@ -10,16 +12,16 @@ import {
|
||||
import {
|
||||
DYNAMIC_PROMPT,
|
||||
ITERATE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
} from './constants';
|
||||
import { unset } from 'lodash-es';
|
||||
|
||||
export const addDynamicPromptsToGraph = (
|
||||
graph: NonNullableGraph,
|
||||
state: RootState
|
||||
state: RootState,
|
||||
graph: NonNullableGraph
|
||||
): void => {
|
||||
const { positivePrompt, iterations, seed, shouldRandomizeSeed } =
|
||||
state.generation;
|
||||
@ -30,6 +32,10 @@ export const addDynamicPromptsToGraph = (
|
||||
maxPrompts,
|
||||
} = state.dynamicPrompts;
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (isDynamicPromptsEnabled) {
|
||||
// iteration is handled via dynamic prompts
|
||||
unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt');
|
||||
@ -74,6 +80,18 @@ export const addDynamicPromptsToGraph = (
|
||||
}
|
||||
);
|
||||
|
||||
// hook up positive prompt to metadata
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'positive_prompt',
|
||||
},
|
||||
});
|
||||
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
@ -88,11 +106,26 @@ export const addDynamicPromptsToGraph = (
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: { node_id: NOISE, field: 'seed' },
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: { node_id: METADATA_ACCUMULATOR, field: 'seed' },
|
||||
});
|
||||
} else {
|
||||
// User specified seed, so set the start of the range of size to the seed
|
||||
(graph.nodes[NOISE] as NoiseInvocation).seed = seed;
|
||||
|
||||
// hook up seed to metadata
|
||||
if (metadataAccumulator) {
|
||||
metadataAccumulator.seed = seed;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// no dynamic prompt - hook up positive prompt
|
||||
if (metadataAccumulator) {
|
||||
metadataAccumulator.positive_prompt = positivePrompt;
|
||||
}
|
||||
|
||||
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
||||
id: RANGE_OF_SIZE,
|
||||
type: 'range_of_size',
|
||||
@ -130,6 +163,18 @@ export const addDynamicPromptsToGraph = (
|
||||
},
|
||||
});
|
||||
|
||||
// hook up seed to metadata
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'seed',
|
||||
},
|
||||
});
|
||||
|
||||
// handle seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
|
@ -1,19 +1,23 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import { LoraLoaderInvocation } from 'services/api/types';
|
||||
import {
|
||||
LoraLoaderInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
} from 'services/api/types';
|
||||
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
LORA_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
} from './constants';
|
||||
|
||||
export const addLoRAsToGraph = (
|
||||
graph: NonNullableGraph,
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): void => {
|
||||
/**
|
||||
@ -26,6 +30,9 @@ export const addLoRAsToGraph = (
|
||||
|
||||
const { loras } = state.lora;
|
||||
const loraCount = size(loras);
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (loraCount > 0) {
|
||||
// Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
|
||||
@ -62,6 +69,12 @@ export const addLoRAsToGraph = (
|
||||
weight,
|
||||
};
|
||||
|
||||
// add the lora to the metadata accumulator
|
||||
if (metadataAccumulator) {
|
||||
metadataAccumulator.loras.push({ lora: loraField, weight });
|
||||
}
|
||||
|
||||
// add to graph
|
||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||
|
||||
if (currentLoraIndex === 0) {
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
||||
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||
import {
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
@ -8,18 +9,22 @@ import {
|
||||
INPAINT_GRAPH,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
|
||||
export const addVAEToGraph = (
|
||||
graph: NonNullableGraph,
|
||||
state: RootState
|
||||
state: RootState,
|
||||
graph: NonNullableGraph
|
||||
): void => {
|
||||
const { vae } = state.generation;
|
||||
const vae_model = modelIdToVAEModelField(vae?.id || '');
|
||||
|
||||
const isAutoVae = !vae;
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (!isAutoVae) {
|
||||
graph.nodes[VAE_LOADER] = {
|
||||
@ -67,4 +72,8 @@ export const addVAEToGraph = (
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (vae && metadataAccumulator) {
|
||||
metadataAccumulator.vae = vae_model;
|
||||
}
|
||||
};
|
||||
|
@ -7,8 +7,7 @@ import {
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
@ -19,6 +18,7 @@ import {
|
||||
LATENTS_TO_IMAGE,
|
||||
LATENTS_TO_LATENTS,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
@ -37,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: currentModel,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -50,7 +50,10 @@ export const buildCanvasImageToImageGraph = (
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||
if (!model) {
|
||||
moduleLog.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
? shouldUseCpuNoise
|
||||
@ -275,16 +278,51 @@ export const buildCanvasImageToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
seed: 0, // set in addDynamicPromptsToGraph
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
clip_skip: clipSkip,
|
||||
strength,
|
||||
init_image: initialImage.image_name,
|
||||
};
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// add dynamic prompts, mutating `graph`
|
||||
addDynamicPromptsToGraph(graph, state);
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph);
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
|
||||
addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -7,7 +7,6 @@ import {
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api/types';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
@ -35,7 +34,7 @@ export const buildCanvasInpaintGraph = (
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: currentModel,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -53,14 +52,17 @@ export const buildCanvasInpaintGraph = (
|
||||
clipSkip,
|
||||
} = state.generation;
|
||||
|
||||
if (!model) {
|
||||
moduleLog.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
// We may need to set the inpaint width and height to scale the image
|
||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||
|
||||
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: INPAINT_GRAPH,
|
||||
nodes: {
|
||||
@ -212,10 +214,10 @@ export const buildCanvasInpaintGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
addLoRAsToGraph(graph, state, INPAINT);
|
||||
addLoRAsToGraph(state, graph, INPAINT);
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
addVAEToGraph(state, graph);
|
||||
|
||||
// handle seed
|
||||
if (shouldRandomizeSeed) {
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
@ -10,6 +10,7 @@ import {
|
||||
CLIP_SKIP,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
@ -17,6 +18,8 @@ import {
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
*/
|
||||
@ -26,7 +29,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: currentModel,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -38,7 +41,10 @@ export const buildCanvasTextToImageGraph = (
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||
if (!model) {
|
||||
moduleLog.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
? shouldUseCpuNoise
|
||||
@ -180,16 +186,49 @@ export const buildCanvasTextToImageGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
seed: 0, // set in addDynamicPromptsToGraph
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
clip_skip: clipSkip,
|
||||
};
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// add dynamic prompts, mutating `graph`
|
||||
addDynamicPromptsToGraph(graph, state);
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph);
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
||||
addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -3,25 +3,21 @@ import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
ImageCollectionInvocation,
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
IterateInvocation,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
IMAGE_COLLECTION,
|
||||
IMAGE_COLLECTION_ITERATE,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
LATENTS_TO_LATENTS,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
@ -39,7 +35,7 @@ export const buildLinearImageToImageGraph = (
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: currentModel,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -53,14 +49,15 @@ export const buildLinearImageToImageGraph = (
|
||||
shouldUseNoiseSettings,
|
||||
} = state.generation;
|
||||
|
||||
const {
|
||||
isEnabled: isBatchEnabled,
|
||||
imageNames: batchImageNames,
|
||||
asInitialImage,
|
||||
} = state.batch;
|
||||
// TODO: add batch functionality
|
||||
// const {
|
||||
// isEnabled: isBatchEnabled,
|
||||
// imageNames: batchImageNames,
|
||||
// asInitialImage,
|
||||
// } = state.batch;
|
||||
|
||||
const shouldBatch =
|
||||
isBatchEnabled && batchImageNames.length > 0 && asInitialImage;
|
||||
// const shouldBatch =
|
||||
// isBatchEnabled && batchImageNames.length > 0 && asInitialImage;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
@ -71,12 +68,15 @@ export const buildLinearImageToImageGraph = (
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
if (!initialImage && !shouldBatch) {
|
||||
if (!initialImage) {
|
||||
moduleLog.error('No initial image found in state');
|
||||
throw new Error('No initial image found in state');
|
||||
}
|
||||
|
||||
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||
if (!model) {
|
||||
moduleLog.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
? shouldUseCpuNoise
|
||||
@ -295,51 +295,87 @@ export const buildLinearImageToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) {
|
||||
// we are going to connect an iterate up to the init image
|
||||
delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image;
|
||||
// TODO: add batch functionality
|
||||
// if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) {
|
||||
// // we are going to connect an iterate up to the init image
|
||||
// delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image;
|
||||
|
||||
const imageCollection: ImageCollectionInvocation = {
|
||||
id: IMAGE_COLLECTION,
|
||||
type: 'image_collection',
|
||||
images: batchImageNames.map((image_name) => ({ image_name })),
|
||||
};
|
||||
// const imageCollection: ImageCollectionInvocation = {
|
||||
// id: IMAGE_COLLECTION,
|
||||
// type: 'image_collection',
|
||||
// images: batchImageNames.map((image_name) => ({ image_name })),
|
||||
// };
|
||||
|
||||
const imageCollectionIterate: IterateInvocation = {
|
||||
id: IMAGE_COLLECTION_ITERATE,
|
||||
type: 'iterate',
|
||||
};
|
||||
// const imageCollectionIterate: IterateInvocation = {
|
||||
// id: IMAGE_COLLECTION_ITERATE,
|
||||
// type: 'iterate',
|
||||
// };
|
||||
|
||||
graph.nodes[IMAGE_COLLECTION] = imageCollection;
|
||||
graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate;
|
||||
// graph.nodes[IMAGE_COLLECTION] = imageCollection;
|
||||
// graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate;
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_COLLECTION, field: 'collection' },
|
||||
destination: {
|
||||
node_id: IMAGE_COLLECTION_ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
});
|
||||
// graph.edges.push({
|
||||
// source: { node_id: IMAGE_COLLECTION, field: 'collection' },
|
||||
// destination: {
|
||||
// node_id: IMAGE_COLLECTION_ITERATE,
|
||||
// field: 'collection',
|
||||
// },
|
||||
// });
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_COLLECTION_ITERATE, field: 'item' },
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
}
|
||||
// graph.edges.push({
|
||||
// source: { node_id: IMAGE_COLLECTION_ITERATE, field: 'item' },
|
||||
// destination: {
|
||||
// node_id: IMAGE_TO_LATENTS,
|
||||
// field: 'image',
|
||||
// },
|
||||
// });
|
||||
// }
|
||||
|
||||
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
seed: 0, // set in addDynamicPromptsToGraph
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
clip_skip: clipSkip,
|
||||
strength,
|
||||
init_image: initialImage.imageName,
|
||||
};
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// add dynamic prompts, mutating `graph`
|
||||
addDynamicPromptsToGraph(graph, state);
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph);
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
|
||||
addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
@ -10,6 +10,7 @@ import {
|
||||
CLIP_SKIP,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
@ -17,13 +18,15 @@ import {
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
export const buildLinearTextToImageGraph = (
|
||||
state: RootState
|
||||
): NonNullableGraph => {
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: currentModel,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -34,12 +37,15 @@ export const buildLinearTextToImageGraph = (
|
||||
shouldUseNoiseSettings,
|
||||
} = state.generation;
|
||||
|
||||
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
? shouldUseCpuNoise
|
||||
: initialGenerationState.shouldUseCpuNoise;
|
||||
|
||||
if (!model) {
|
||||
moduleLog.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
@ -176,16 +182,49 @@ export const buildLinearTextToImageGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
seed: 0, // set in addDynamicPromptsToGraph
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
clip_skip: clipSkip,
|
||||
};
|
||||
|
||||
// Add Custom VAE Support
|
||||
addVAEToGraph(graph, state);
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// add dynamic prompts, mutating `graph`
|
||||
addDynamicPromptsToGraph(graph, state);
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph);
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
||||
addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -19,6 +19,7 @@ export const CONTROL_NET_COLLECT = 'control_net_collect';
|
||||
export const DYNAMIC_PROMPT = 'dynamic_prompt';
|
||||
export const IMAGE_COLLECTION = 'image_collection';
|
||||
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
|
||||
export const METADATA_ACCUMULATOR = 'metadata_accumulator';
|
||||
|
||||
// friendly graph ids
|
||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||
|
@ -5,17 +5,21 @@ import {
|
||||
InputFieldTemplate,
|
||||
InvocationSchemaObject,
|
||||
InvocationTemplate,
|
||||
isInvocationSchemaObject,
|
||||
OutputFieldTemplate,
|
||||
isInvocationSchemaObject,
|
||||
} from '../types/types';
|
||||
import {
|
||||
buildInputFieldTemplate,
|
||||
buildOutputFieldTemplates,
|
||||
} from './fieldTemplateBuilders';
|
||||
|
||||
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
|
||||
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'metadata'];
|
||||
|
||||
const invocationDenylist = ['Graph', 'InvocationMeta'];
|
||||
const invocationDenylist = [
|
||||
'Graph',
|
||||
'InvocationMeta',
|
||||
'MetadataAccumulatorInvocation',
|
||||
];
|
||||
|
||||
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
||||
// filter out non-invocation schemas, plus some tricky invocations for now
|
||||
|
@ -2,22 +2,26 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import {
|
||||
canvasSelector,
|
||||
isStagingSelector,
|
||||
} from 'features/canvas/store/canvasSelectors';
|
||||
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[canvasSelector, isStagingSelector],
|
||||
(canvas, isStaging) => {
|
||||
[canvasSelector, isStagingSelector, uiSelector],
|
||||
(canvas, isStaging, ui) => {
|
||||
const { boundingBoxDimensions } = canvas;
|
||||
const { aspectRatio } = ui;
|
||||
return {
|
||||
boundingBoxDimensions,
|
||||
isStaging,
|
||||
aspectRatio,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -25,7 +29,8 @@ const selector = createSelector(
|
||||
|
||||
const ParamBoundingBoxWidth = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { boundingBoxDimensions, isStaging } = useAppSelector(selector);
|
||||
const { boundingBoxDimensions, isStaging, aspectRatio } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -36,6 +41,15 @@ const ParamBoundingBoxWidth = () => {
|
||||
height: Math.floor(v),
|
||||
})
|
||||
);
|
||||
if (aspectRatio) {
|
||||
const newWidth = roundToMultiple(v * aspectRatio, 64);
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: newWidth,
|
||||
height: Math.floor(v),
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const handleResetHeight = () => {
|
||||
@ -45,6 +59,15 @@ const ParamBoundingBoxWidth = () => {
|
||||
height: Math.floor(512),
|
||||
})
|
||||
);
|
||||
if (aspectRatio) {
|
||||
const newWidth = roundToMultiple(512 * aspectRatio, 64);
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: newWidth,
|
||||
height: Math.floor(512),
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
|
@ -0,0 +1,57 @@
|
||||
import { Flex, Spacer, Text } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { flipBoundingBoxAxes } from 'features/canvas/store/canvasSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { MdOutlineSwapVert } from 'react-icons/md';
|
||||
import ParamAspectRatio from '../../Core/ParamAspectRatio';
|
||||
import ParamBoundingBoxHeight from './ParamBoundingBoxHeight';
|
||||
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth';
|
||||
|
||||
export default function ParamBoundingBoxSize() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
gap: 2,
|
||||
p: 4,
|
||||
borderRadius: 4,
|
||||
flexDirection: 'column',
|
||||
w: 'full',
|
||||
bg: 'base.150',
|
||||
_dark: {
|
||||
bg: 'base.750',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
width: 'full',
|
||||
color: 'base.700',
|
||||
_dark: {
|
||||
color: 'base.300',
|
||||
},
|
||||
}}
|
||||
>
|
||||
{t('parameters.aspectRatio')}
|
||||
</Text>
|
||||
<Spacer />
|
||||
<ParamAspectRatio />
|
||||
<IAIIconButton
|
||||
tooltip={t('ui.swapSizes')}
|
||||
aria-label={t('ui.swapSizes')}
|
||||
size="sm"
|
||||
icon={<MdOutlineSwapVert />}
|
||||
fontSize={20}
|
||||
onClick={() => dispatch(flipBoundingBoxAxes())}
|
||||
/>
|
||||
</Flex>
|
||||
<ParamBoundingBoxWidth />
|
||||
<ParamBoundingBoxHeight />
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -2,22 +2,26 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import {
|
||||
canvasSelector,
|
||||
isStagingSelector,
|
||||
} from 'features/canvas/store/canvasSelectors';
|
||||
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[canvasSelector, isStagingSelector],
|
||||
(canvas, isStaging) => {
|
||||
[canvasSelector, isStagingSelector, uiSelector],
|
||||
(canvas, isStaging, ui) => {
|
||||
const { boundingBoxDimensions } = canvas;
|
||||
const { aspectRatio } = ui;
|
||||
return {
|
||||
boundingBoxDimensions,
|
||||
isStaging,
|
||||
aspectRatio,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -25,7 +29,8 @@ const selector = createSelector(
|
||||
|
||||
const ParamBoundingBoxWidth = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { boundingBoxDimensions, isStaging } = useAppSelector(selector);
|
||||
const { boundingBoxDimensions, isStaging, aspectRatio } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -36,6 +41,15 @@ const ParamBoundingBoxWidth = () => {
|
||||
width: Math.floor(v),
|
||||
})
|
||||
);
|
||||
if (aspectRatio) {
|
||||
const newHeight = roundToMultiple(v / aspectRatio, 64);
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: Math.floor(v),
|
||||
height: newHeight,
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const handleResetWidth = () => {
|
||||
@ -45,6 +59,15 @@ const ParamBoundingBoxWidth = () => {
|
||||
width: Math.floor(512),
|
||||
})
|
||||
);
|
||||
if (aspectRatio) {
|
||||
const newHeight = roundToMultiple(512 / aspectRatio, 64);
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: Math.floor(512),
|
||||
height: newHeight,
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
|
@ -27,6 +27,9 @@ const ParamNoiseCollapse = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled;
|
||||
const isPerlinNoiseEnabled = useFeatureStatus('perlinNoise').isFeatureEnabled;
|
||||
const isNoiseThresholdEnabled =
|
||||
useFeatureStatus('noiseThreshold').isFeatureEnabled;
|
||||
|
||||
const { activeLabel } = useAppSelector(selector);
|
||||
|
||||
@ -42,8 +45,8 @@ const ParamNoiseCollapse = () => {
|
||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||
<ParamNoiseToggle />
|
||||
<ParamCpuNoiseToggle />
|
||||
<ParamPerlinNoise />
|
||||
<ParamNoiseThreshold />
|
||||
{isPerlinNoiseEnabled && <ParamPerlinNoise />}
|
||||
{isNoiseThresholdEnabled && <ParamNoiseThreshold />}
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
|
@ -2,6 +2,7 @@ import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { UnsafeImageMetadata } from 'services/api/endpoints/images';
|
||||
import { isImageField } from 'services/api/guards';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
@ -162,7 +163,7 @@ export const useRecallParameters = () => {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
dispatch(modelSelected(model?.id || ''));
|
||||
dispatch(modelSelected(model));
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
@ -269,28 +270,24 @@ export const useRecallParameters = () => {
|
||||
);
|
||||
|
||||
const recallAllParameters = useCallback(
|
||||
(image: ImageDTO | undefined) => {
|
||||
if (!image || !image.metadata) {
|
||||
(metadata: UnsafeImageMetadata['metadata'] | undefined) => {
|
||||
if (!metadata) {
|
||||
allParameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
cfg_scale,
|
||||
height,
|
||||
model,
|
||||
positive_conditioning,
|
||||
negative_conditioning,
|
||||
positive_prompt,
|
||||
negative_prompt,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
width,
|
||||
strength,
|
||||
clip,
|
||||
extra,
|
||||
latents,
|
||||
unet,
|
||||
vae,
|
||||
} = image.metadata;
|
||||
} = metadata;
|
||||
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
@ -298,11 +295,11 @@ export const useRecallParameters = () => {
|
||||
if (isValidMainModel(model)) {
|
||||
dispatch(modelSelected(model));
|
||||
}
|
||||
if (isValidPositivePrompt(positive_conditioning)) {
|
||||
dispatch(setPositivePrompt(positive_conditioning));
|
||||
if (isValidPositivePrompt(positive_prompt)) {
|
||||
dispatch(setPositivePrompt(positive_prompt));
|
||||
}
|
||||
if (isValidNegativePrompt(negative_conditioning)) {
|
||||
dispatch(setNegativePrompt(negative_conditioning));
|
||||
if (isValidNegativePrompt(negative_prompt)) {
|
||||
dispatch(setNegativePrompt(negative_prompt));
|
||||
}
|
||||
if (isValidScheduler(scheduler)) {
|
||||
dispatch(setScheduler(scheduler));
|
||||
|
@ -1,8 +1,10 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { ImageDTO, MainModelField } from 'services/api/types';
|
||||
|
||||
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
|
||||
'generation/initialImageSelected'
|
||||
);
|
||||
|
||||
export const modelSelected = createAction<string>('generation/modelSelected');
|
||||
export const modelSelected = createAction<MainModelField>(
|
||||
'generation/modelSelected'
|
||||
);
|
||||
|
@ -8,12 +8,11 @@ import {
|
||||
setShouldShowAdvancedOptions,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { ImageDTO, MainModelField } from 'services/api/types';
|
||||
import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
|
||||
import {
|
||||
CfgScaleParam,
|
||||
HeightParam,
|
||||
MainModelParam,
|
||||
NegativePromptParam,
|
||||
PositivePromptParam,
|
||||
SchedulerParam,
|
||||
@ -54,7 +53,7 @@ export interface GenerationState {
|
||||
shouldUseSymmetry: boolean;
|
||||
horizontalSymmetrySteps: number;
|
||||
verticalSymmetrySteps: number;
|
||||
model: MainModelParam | null;
|
||||
model: MainModelField | null;
|
||||
vae: VaeModelParam | null;
|
||||
seamlessXAxis: boolean;
|
||||
seamlessYAxis: boolean;
|
||||
@ -227,23 +226,17 @@ export const generationSlice = createSlice({
|
||||
const { image_name, width, height } = action.payload;
|
||||
state.initialImage = { imageName: image_name, width, height };
|
||||
},
|
||||
modelSelected: (state, action: PayloadAction<string>) => {
|
||||
const [base_model, type, name] = action.payload.split('/');
|
||||
modelChanged: (state, action: PayloadAction<MainModelField | null>) => {
|
||||
if (!action.payload) {
|
||||
state.model = null;
|
||||
}
|
||||
|
||||
state.model = zMainModel.parse({
|
||||
id: action.payload,
|
||||
base_model,
|
||||
name,
|
||||
type,
|
||||
});
|
||||
state.model = zMainModel.parse(action.payload);
|
||||
|
||||
// Clamp ClipSkip Based On Selected Model
|
||||
const { maxClip } = clipSkipMap[state.model.base_model];
|
||||
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
||||
},
|
||||
modelChanged: (state, action: PayloadAction<MainModelParam>) => {
|
||||
state.model = action.payload;
|
||||
},
|
||||
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
|
||||
state.vae = action.payload;
|
||||
},
|
||||
|
@ -135,8 +135,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
export const zMainModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
model_name: z.string(),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
|
||||
@ -171,7 +170,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
|
||||
*/
|
||||
export const zLoRAModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
model_name: z.string(),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user