Merge branch 'development' of github.com:psychedelicious/stable-diffusion into psychedelicious-development
64
.github/workflows/cache-model.yml
vendored
@ -1,64 +0,0 @@
|
||||
name: Cache Model
|
||||
on:
|
||||
workflow_dispatch
|
||||
jobs:
|
||||
build:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ macos-12 ]
|
||||
name: Create Caches using ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
uses: actions/checkout@v3
|
||||
- name: Cache model
|
||||
id: cache-sd-v1-4
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-sd-v1-4
|
||||
with:
|
||||
path: models/ldm/stable-diffusion-v1/model.ckpt
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}
|
||||
- name: Download Stable Diffusion v1.4 model
|
||||
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
|
||||
mkdir -p models/ldm/stable-diffusion-v1
|
||||
fi
|
||||
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
|
||||
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
||||
fi
|
||||
# Uncomment this when we no longer make changes to environment-mac.yaml
|
||||
# - name: Cache environment
|
||||
# id: cache-conda-env-ldm
|
||||
# uses: actions/cache@v3
|
||||
# env:
|
||||
# cache-name: cache-conda-env-ldm
|
||||
# with:
|
||||
# path: ~/.conda/envs/ldm
|
||||
# key: ${{ env.cache-name }}
|
||||
# restore-keys: |
|
||||
# ${{ env.cache-name }}
|
||||
- name: Install dependencies
|
||||
# if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
conda env create -f environment-mac.yaml
|
||||
- name: Cache hugginface and torch models
|
||||
id: cache-hugginface-torch
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-hugginface-torch
|
||||
with:
|
||||
path: ~/.cache
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}
|
||||
- name: Download Huggingface and Torch models
|
||||
if: ${{ steps.cache-hugginface-torch.outputs.cache-hit != 'true' }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
export PYTHON_BIN=/usr/local/miniconda/envs/ldm/bin/python
|
||||
$PYTHON_BIN scripts/preload_models.py
|
70
.github/workflows/create-caches.yml
vendored
Normal file
@ -0,0 +1,70 @@
|
||||
name: Create Caches
|
||||
on:
|
||||
workflow_dispatch
|
||||
jobs:
|
||||
build:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ ubuntu-latest, macos-12 ]
|
||||
name: Create Caches on ${{ matrix.os }} conda
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Set platform variables
|
||||
id: vars
|
||||
run: |
|
||||
if [ "$RUNNER_OS" = "macOS" ]; then
|
||||
echo "::set-output name=ENV_FILE::environment-mac.yaml"
|
||||
echo "::set-output name=PYTHON_BIN::/usr/local/miniconda/envs/ldm/bin/python"
|
||||
elif [ "$RUNNER_OS" = "Linux" ]; then
|
||||
echo "::set-output name=ENV_FILE::environment.yaml"
|
||||
echo "::set-output name=PYTHON_BIN::/usr/share/miniconda/envs/ldm/bin/python"
|
||||
fi
|
||||
- name: Checkout sources
|
||||
uses: actions/checkout@v3
|
||||
- name: Use Cached Stable Diffusion v1.4 Model
|
||||
id: cache-sd-v1-4
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-sd-v1-4
|
||||
with:
|
||||
path: models/ldm/stable-diffusion-v1/model.ckpt
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}
|
||||
- name: Download Stable Diffusion v1.4 Model
|
||||
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
|
||||
mkdir -p models/ldm/stable-diffusion-v1
|
||||
fi
|
||||
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
|
||||
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
||||
fi
|
||||
- name: Use Cached Dependencies
|
||||
id: cache-conda-env-ldm
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-conda-env-ldm
|
||||
with:
|
||||
path: ~/.conda/envs/ldm
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(steps.vars.outputs.ENV_FILE) }}
|
||||
- name: Install Dependencies
|
||||
if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
conda env create -f ${{ steps.vars.outputs.ENV_FILE }}
|
||||
- name: Use Cached Huggingface and Torch models
|
||||
id: cache-huggingface-torch
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-huggingface-torch
|
||||
with:
|
||||
path: ~/.cache
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}-${{ hashFiles('scripts/preload_models.py') }}
|
||||
- name: Download Huggingface and Torch models
|
||||
if: ${{ steps.cache-huggingface-torch.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
${{ steps.vars.outputs.PYTHON_BIN }} scripts/preload_models.py
|
80
.github/workflows/macos12-miniconda.yml
vendored
@ -1,80 +0,0 @@
|
||||
name: Build
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
jobs:
|
||||
build:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ macos-12 ]
|
||||
name: Build on ${{ matrix.os }} miniconda
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
uses: actions/checkout@v3
|
||||
- name: Cache model
|
||||
id: cache-sd-v1-4
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-sd-v1-4
|
||||
with:
|
||||
path: models/ldm/stable-diffusion-v1/model.ckpt
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}
|
||||
- name: Download Stable Diffusion v1.4 model
|
||||
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
|
||||
mkdir -p models/ldm/stable-diffusion-v1
|
||||
fi
|
||||
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
|
||||
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
||||
fi
|
||||
# Uncomment this when we no longer make changes to environment-mac.yaml
|
||||
# - name: Cache environment
|
||||
# id: cache-conda-env-ldm
|
||||
# uses: actions/cache@v3
|
||||
# env:
|
||||
# cache-name: cache-conda-env-ldm
|
||||
# with:
|
||||
# path: ~/.conda/envs/ldm
|
||||
# key: ${{ env.cache-name }}
|
||||
# restore-keys: |
|
||||
# ${{ env.cache-name }}
|
||||
- name: Install dependencies
|
||||
# if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
conda env create -f environment-mac.yaml
|
||||
- name: Cache hugginface and torch models
|
||||
id: cache-hugginface-torch
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-hugginface-torch
|
||||
with:
|
||||
path: ~/.cache
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}
|
||||
- name: Download Huggingface and Torch models
|
||||
if: ${{ steps.cache-hugginface-torch.outputs.cache-hit != 'true' }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
export PYTHON_BIN=/usr/local/miniconda/envs/ldm/bin/python
|
||||
$PYTHON_BIN scripts/preload_models.py
|
||||
- name: Run the tests
|
||||
run: |
|
||||
# Note, can't "activate" via automation, and activation is just env vars and path
|
||||
export PYTHON_BIN=/usr/local/miniconda/envs/ldm/bin/python
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
$PYTHON_BIN scripts/preload_models.py
|
||||
mkdir -p outputs/img-samples
|
||||
time $PYTHON_BIN scripts/dream.py --from_file tests/prompts.txt </dev/null 2> outputs/img-samples/err.log > outputs/img-samples/out.log
|
||||
- name: Archive results
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: results
|
||||
path: outputs/img-samples
|
97
.github/workflows/test-dream-conda.yml
vendored
Normal file
@ -0,0 +1,97 @@
|
||||
name: Test Dream with Conda
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'development'
|
||||
jobs:
|
||||
os_matrix:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ ubuntu-latest, macos-12 ]
|
||||
name: Test dream.py on ${{ matrix.os }} with conda
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- run: |
|
||||
echo The PR was merged
|
||||
- name: Set platform variables
|
||||
id: vars
|
||||
run: |
|
||||
# Note, can't "activate" via github action; specifying the env's python has the same effect
|
||||
if [ "$RUNNER_OS" = "macOS" ]; then
|
||||
echo "::set-output name=ENV_FILE::environment-mac.yaml"
|
||||
echo "::set-output name=PYTHON_BIN::/usr/local/miniconda/envs/ldm/bin/python"
|
||||
elif [ "$RUNNER_OS" = "Linux" ]; then
|
||||
echo "::set-output name=ENV_FILE::environment.yaml"
|
||||
echo "::set-output name=PYTHON_BIN::/usr/share/miniconda/envs/ldm/bin/python"
|
||||
fi
|
||||
- name: Checkout sources
|
||||
uses: actions/checkout@v3
|
||||
- name: Use Cached Stable Diffusion v1.4 Model
|
||||
id: cache-sd-v1-4
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-sd-v1-4
|
||||
with:
|
||||
path: models/ldm/stable-diffusion-v1/model.ckpt
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}
|
||||
- name: Download Stable Diffusion v1.4 Model
|
||||
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
|
||||
mkdir -p models/ldm/stable-diffusion-v1
|
||||
fi
|
||||
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
|
||||
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
||||
fi
|
||||
- name: Use Cached Dependencies
|
||||
id: cache-conda-env-ldm
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-conda-env-ldm
|
||||
with:
|
||||
path: ~/.conda/envs/ldm
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(steps.vars.outputs.ENV_FILE) }}
|
||||
- name: Install Dependencies
|
||||
if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
conda env create -f ${{ steps.vars.outputs.ENV_FILE }}
|
||||
- name: Use Cached Huggingface and Torch models
|
||||
id: cache-hugginface-torch
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
cache-name: cache-hugginface-torch
|
||||
with:
|
||||
path: ~/.cache
|
||||
key: ${{ env.cache-name }}
|
||||
restore-keys: |
|
||||
${{ env.cache-name }}-${{ hashFiles('scripts/preload_models.py') }}
|
||||
- name: Download Huggingface and Torch models
|
||||
if: ${{ steps.cache-hugginface-torch.outputs.cache-hit != 'true' }}
|
||||
run: |
|
||||
${{ steps.vars.outputs.PYTHON_BIN }} scripts/preload_models.py
|
||||
# - name: Run tmate
|
||||
# uses: mxschmitt/action-tmate@v3
|
||||
# timeout-minutes: 30
|
||||
- name: Run the tests
|
||||
run: |
|
||||
# Note, can't "activate" via github action; specifying the env's python has the same effect
|
||||
if [ $(uname) = "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
# Utterly hacky, but I don't know how else to do this
|
||||
if [[ ${{ github.ref }} == 'refs/heads/master' ]]; then
|
||||
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/preflight_prompts.txt --full_precision
|
||||
elif [[ ${{ github.ref }} == 'refs/heads/development' ]]; then
|
||||
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/dev_prompts.txt --full_precision
|
||||
fi
|
||||
mkdir -p outputs/img-samples
|
||||
- name: Archive results
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: results
|
||||
path: outputs/img-samples
|
1
.gitignore
vendored
@ -1,6 +1,7 @@
|
||||
# ignore default image save location and model symbolic link
|
||||
outputs/
|
||||
models/ldm/stable-diffusion-v1/model.ckpt
|
||||
ldm/restoration/codeformer/weights
|
||||
|
||||
# ignore a directory which serves as a place for initial images
|
||||
inputs/
|
||||
|
131
README.md
@ -1,16 +1,41 @@
|
||||
<h1 align='center'><b>Stable Diffusion Dream Script</b></h1>
|
||||
<div align="center">
|
||||
|
||||
# Stable Diffusion Dream Script
|
||||
|
||||
![project logo](docs/assets/logo.png)
|
||||
|
||||
<p align='center'>
|
||||
<img src="docs/assets/logo.png"/>
|
||||
<a href="https://discord.gg/ZmtBAhwWhy"><img src="docs/assets/join-us-on-discord-image.png"/></a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/github/last-commit/lstein/stable-diffusion?logo=Python&logoColor=green&style=for-the-badge" alt="last-commit"/>
|
||||
<img src="https://img.shields.io/github/stars/lstein/stable-diffusion?logo=GitHub&style=for-the-badge" alt="stars"/>
|
||||
<br>
|
||||
<img src="https://img.shields.io/github/issues/lstein/stable-diffusion?logo=GitHub&style=for-the-badge" alt="issues"/>
|
||||
<img src="https://img.shields.io/github/issues-pr/lstein/stable-diffusion?logo=GitHub&style=for-the-badge" alt="pull-requests"/>
|
||||
</p>
|
||||
# **Stable Diffusion Dream Script**
|
||||
[![discord badge]][discord link]
|
||||
|
||||
[![latest release badge]][latest release link] [![github stars badge]][github stars link] [![github forks badge]][github forks link]
|
||||
|
||||
[![CI checks on main badge]][CI checks on main link] [![CI checks on dev badge]][CI checks on dev link] [![latest commit to dev badge]][latest commit to dev link]
|
||||
|
||||
[![github open issues badge]][github open issues link] [![github open prs badge]][github open prs link]
|
||||
|
||||
[CI checks on dev badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/development?label=CI%20status%20on%20dev&cache=900&icon=github
|
||||
[CI checks on dev link]: https://github.com/lstein/stable-diffusion/actions?query=branch%3Adevelopment
|
||||
[CI checks on main badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/main?label=CI%20status%20on%20main&cache=900&icon=github
|
||||
[CI checks on main link]: https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml
|
||||
[discord badge]: https://flat.badgen.net/discord/members/htRgbc7e?icon=discord
|
||||
[discord link]: https://discord.com/invite/htRgbc7e
|
||||
[github forks badge]: https://flat.badgen.net/github/forks/lstein/stable-diffusion?icon=github
|
||||
[github forks link]: https://useful-forks.github.io/?repo=lstein%2Fstable-diffusion
|
||||
[github open issues badge]: https://flat.badgen.net/github/open-issues/lstein/stable-diffusion?icon=github
|
||||
[github open issues link]: https://github.com/lstein/stable-diffusion/issues?q=is%3Aissue+is%3Aopen
|
||||
[github open prs badge]: https://flat.badgen.net/github/open-prs/lstein/stable-diffusion?icon=github
|
||||
[github open prs link]: https://github.com/lstein/stable-diffusion/pulls?q=is%3Apr+is%3Aopen
|
||||
[github stars badge]: https://flat.badgen.net/github/stars/lstein/stable-diffusion?icon=github
|
||||
[github stars link]: https://github.com/lstein/stable-diffusion/stargazers
|
||||
[latest commit to dev badge]: https://flat.badgen.net/github/last-commit/lstein/stable-diffusion/development?icon=github&color=yellow&label=last%20dev%20commit&cache=900
|
||||
[latest commit to dev link]: https://github.com/lstein/stable-diffusion/commits/development
|
||||
[latest release badge]: https://flat.badgen.net/github/release/lstein/stable-diffusion/development?icon=github
|
||||
[latest release link]: https://github.com/lstein/stable-diffusion/releases
|
||||
</div>
|
||||
|
||||
This is a fork of [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion), the open
|
||||
source text-to-image generator. It provides a streamlined process with various new features and
|
||||
@ -21,7 +46,7 @@ _Note: This fork is rapidly evolving. Please use the
|
||||
[Issues](https://github.com/lstein/stable-diffusion/issues) tab to report bugs and make feature
|
||||
requests. Be sure to use the provided templates. They will help aid diagnose issues faster._
|
||||
|
||||
**Table of Contents**
|
||||
## Table of Contents
|
||||
|
||||
1. [Installation](#installation)
|
||||
2. [Hardware Requirements](#hardware-requirements)
|
||||
@ -33,38 +58,38 @@ requests. Be sure to use the provided templates. They will help aid diagnose iss
|
||||
8. [Support](#support)
|
||||
9. [Further Reading](#further-reading)
|
||||
|
||||
## Installation
|
||||
### Installation
|
||||
|
||||
This fork is supported across multiple platforms. You can find individual installation instructions
|
||||
below.
|
||||
|
||||
- ### [Linux](docs/installation/INSTALL_LINUX.md)
|
||||
- #### [Linux](docs/installation/INSTALL_LINUX.md)
|
||||
|
||||
- ### [Windows](docs/installation/INSTALL_WINDOWS.md)
|
||||
- #### [Windows](docs/installation/INSTALL_WINDOWS.md)
|
||||
|
||||
- ### [Macintosh](docs/installation/INSTALL_MAC.md)
|
||||
- #### [Macintosh](docs/installation/INSTALL_MAC.md)
|
||||
|
||||
## Hardware Requirements
|
||||
### Hardware Requirements
|
||||
|
||||
**System**
|
||||
#### System
|
||||
|
||||
You wil need one of the following:
|
||||
|
||||
- An NVIDIA-based graphics card with 4 GB or more VRAM memory.
|
||||
- An Apple computer with an M1 chip.
|
||||
|
||||
**Memory**
|
||||
#### Memory
|
||||
|
||||
- At least 12 GB Main Memory RAM.
|
||||
|
||||
**Disk**
|
||||
#### Disk
|
||||
|
||||
- At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies.
|
||||
|
||||
**Note**
|
||||
|
||||
If you are have a Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
|
||||
full-precision mode as shown below.
|
||||
> Note
|
||||
>
|
||||
> If you have an Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
|
||||
> full-precision mode as shown below.
|
||||
|
||||
Similarly, specify full-precision mode on Apple M1 hardware.
|
||||
|
||||
@ -74,43 +99,31 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
|
||||
(ldm) ~/stable-diffusion$ python scripts/dream.py --full_precision
|
||||
```
|
||||
|
||||
## Features
|
||||
### Features
|
||||
|
||||
### Major Features
|
||||
#### Major Features
|
||||
|
||||
- #### [Interactive Command Line Interface](docs/features/CLI.md)
|
||||
- [Interactive Command Line Interface](docs/features/CLI.md)
|
||||
- [Image To Image](docs/features/IMG2IMG.md)
|
||||
- [Inpainting Support](docs/features/INPAINTING.md)
|
||||
- [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
|
||||
- [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
|
||||
- [Google Colab](docs/features/OTHER.md#google-colab)
|
||||
- [Web Server](docs/features/WEB.md)
|
||||
- [Reading Prompts From File](docs/features/PROMPTS.md#reading-prompts-from-a-file)
|
||||
- [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
|
||||
- [Weighted Prompts](docs/features/PROMPTS.md#weighted-prompts)
|
||||
- [Negative/Unconditioned Prompts](docs/features/PROMPTS.md#negative-and-unconditioned-prompts)
|
||||
- [Variations](docs/features/VARIATIONS.md)
|
||||
- [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
|
||||
- [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)
|
||||
|
||||
- #### [Image To Image](docs/features/IMG2IMG.md)
|
||||
#### Other Features
|
||||
|
||||
- #### [Inpainting Support](docs/features/INPAINTING.md)
|
||||
- [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
|
||||
- [Preload Models](docs/features/OTHER.md#preload-models)
|
||||
|
||||
- #### [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
|
||||
|
||||
- #### [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
|
||||
|
||||
- #### [Google Colab](docs/features/OTHER.md#google-colab)
|
||||
|
||||
- #### [Web Server](docs/features/WEB.md)
|
||||
|
||||
- #### [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file)
|
||||
|
||||
- #### [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
|
||||
|
||||
- #### [Weighted Prompts](docs/features/OTHER.md#weighted-prompts)
|
||||
|
||||
- #### [Variations](docs/features/VARIATIONS.md)
|
||||
|
||||
- #### [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
|
||||
|
||||
- #### [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)
|
||||
|
||||
### Other Features
|
||||
|
||||
- #### [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
|
||||
|
||||
- #### [Preload Models](docs/features/OTHER.md#preload-models)
|
||||
|
||||
## Latest Changes
|
||||
### Latest Changes
|
||||
|
||||
- v1.14 (11 September 2022)
|
||||
|
||||
@ -142,12 +155,12 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
|
||||
|
||||
For older changelogs, please visit the **[CHANGELOG](docs/features/CHANGELOG.md)**.
|
||||
|
||||
## Troubleshooting
|
||||
### Troubleshooting
|
||||
|
||||
Please check out our **[Q&A](docs/help/TROUBLESHOOT.md)** to get solutions for common installation
|
||||
problems and other issues.
|
||||
|
||||
## Contributing
|
||||
### Contributing
|
||||
|
||||
Anyone who wishes to contribute to this project, whether documentation, features, bug fixes, code
|
||||
cleanup, testing, or code reviews, is very much encouraged to do so. If you are unfamiliar with how
|
||||
@ -159,13 +172,13 @@ important thing is to **make your pull request against the "development" branch*
|
||||
"main". This will help keep public breakage to a minimum and will allow you to propose more radical
|
||||
changes.
|
||||
|
||||
## Contributors
|
||||
### Contributors
|
||||
|
||||
This fork is a combined effort of various people from across the world.
|
||||
[Check out the list of all these amazing people](docs/other/CONTRIBUTORS.md). We thank them for
|
||||
their time, hard work and effort.
|
||||
|
||||
## Support
|
||||
### Support
|
||||
|
||||
For support, please use this repository's GitHub Issues tracking service. Feel free to send me an
|
||||
email if you use and like the script.
|
||||
@ -173,7 +186,7 @@ email if you use and like the script.
|
||||
Original portions of the software are Copyright (c) 2020
|
||||
[Lincoln D. Stein](https://github.com/lstein)
|
||||
|
||||
## Further Reading
|
||||
### Further Reading
|
||||
|
||||
Please see the original README for more information on this software and underlying algorithm,
|
||||
located in the file [README-CompViz.md](docs/other/README-CompViz.md).
|
||||
|
@ -40,6 +40,8 @@ def parameters_to_command(params):
|
||||
switches.append(f'-I {params["init_img"]}')
|
||||
if 'init_mask' in params and len(params['init_mask']) > 0:
|
||||
switches.append(f'-M {params["init_mask"]}')
|
||||
if 'init_color' in params and len(params['init_color']) > 0:
|
||||
switches.append(f'--init_color {params["init_color"]}')
|
||||
if 'strength' in params and 'init_img' in params:
|
||||
switches.append(f'-f {params["strength"]}')
|
||||
if 'fit' in params and params["fit"] == True:
|
||||
@ -129,6 +131,11 @@ def create_cmd_parser():
|
||||
type=str,
|
||||
help='Path to input mask for inpainting mode (supersedes width and height)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--init_color',
|
||||
type=str,
|
||||
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-T',
|
||||
'-fit',
|
||||
|
@ -7,6 +7,8 @@ import eventlet
|
||||
import glob
|
||||
import shlex
|
||||
import argparse
|
||||
import math
|
||||
import shutil
|
||||
|
||||
from flask_socketio import SocketIO
|
||||
from flask import Flask, send_from_directory, url_for, jsonify
|
||||
@ -15,6 +17,7 @@ from PIL import Image
|
||||
from pytorch_lightning import logging
|
||||
from threading import Event
|
||||
from uuid import uuid4
|
||||
from send2trash import send2trash
|
||||
|
||||
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
|
||||
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
||||
@ -118,15 +121,15 @@ result_path = os.path.join(output_dir, 'img-samples/')
|
||||
intermediate_path = os.path.join(result_path, 'intermediates/')
|
||||
|
||||
# path for user-uploaded init images and masks
|
||||
init_path = os.path.join(result_path, 'init-images/')
|
||||
mask_path = os.path.join(result_path, 'mask-images/')
|
||||
init_image_path = os.path.join(result_path, 'init-images/')
|
||||
mask_image_path = os.path.join(result_path, 'mask-images/')
|
||||
|
||||
# txt log
|
||||
log_path = os.path.join(result_path, 'dream_log.txt')
|
||||
|
||||
# make all output paths
|
||||
[os.makedirs(path, exist_ok=True)
|
||||
for path in [result_path, intermediate_path, init_path, mask_path]]
|
||||
for path in [result_path, intermediate_path, init_image_path, mask_image_path]]
|
||||
|
||||
|
||||
"""
|
||||
@ -154,7 +157,8 @@ def handle_request_all_images():
|
||||
else:
|
||||
metadata = all_metadata['sd-metadata']
|
||||
image_array.append({'path': path, 'metadata': metadata})
|
||||
return make_response("OK", data=image_array)
|
||||
socketio.emit('galleryImages', {'images': image_array})
|
||||
eventlet.sleep(0)
|
||||
|
||||
|
||||
@socketio.on('generateImage')
|
||||
@ -165,16 +169,32 @@ def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan
|
||||
esrgan_parameters,
|
||||
gfpgan_parameters
|
||||
)
|
||||
return make_response("OK")
|
||||
|
||||
|
||||
@socketio.on('runESRGAN')
|
||||
def handle_run_esrgan_event(original_image, esrgan_parameters):
|
||||
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}')
|
||||
progress = {
|
||||
'currentStep': 1,
|
||||
'totalSteps': 1,
|
||||
'currentIteration': 1,
|
||||
'totalIterations': 1,
|
||||
'currentStatus': 'Preparing',
|
||||
'isProcessing': True,
|
||||
'currentStatusHasSteps': False
|
||||
}
|
||||
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = Image.open(original_image["url"])
|
||||
|
||||
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
|
||||
|
||||
progress['currentStatus'] = 'Upscaling'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = real_esrgan_upscale(
|
||||
image=image,
|
||||
upsampler_scale=esrgan_parameters['upscale'][0],
|
||||
@ -182,24 +202,54 @@ def handle_run_esrgan_event(original_image, esrgan_parameters):
|
||||
seed=seed
|
||||
)
|
||||
|
||||
progress['currentStatus'] = 'Saving image'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
esrgan_parameters['seed'] = seed
|
||||
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan')
|
||||
command = parameters_to_command(esrgan_parameters)
|
||||
|
||||
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
|
||||
|
||||
progress['currentStatus'] = 'Finished'
|
||||
progress['currentStep'] = 0
|
||||
progress['totalSteps'] = 0
|
||||
progress['currentIteration'] = 0
|
||||
progress['totalIterations'] = 0
|
||||
progress['isProcessing'] = False
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
socketio.emit(
|
||||
'result', {'url': os.path.relpath(path), 'type': 'esrgan', 'uuid': original_image['uuid'],'metadata': esrgan_parameters})
|
||||
'esrganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': esrgan_parameters})
|
||||
|
||||
|
||||
|
||||
@socketio.on('runGFPGAN')
|
||||
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
|
||||
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}')
|
||||
progress = {
|
||||
'currentStep': 1,
|
||||
'totalSteps': 1,
|
||||
'currentIteration': 1,
|
||||
'totalIterations': 1,
|
||||
'currentStatus': 'Preparing',
|
||||
'isProcessing': True,
|
||||
'currentStatusHasSteps': False
|
||||
}
|
||||
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = Image.open(original_image["url"])
|
||||
|
||||
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
|
||||
|
||||
progress['currentStatus'] = 'Fixing faces'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = run_gfpgan(
|
||||
image=image,
|
||||
strength=gfpgan_parameters['gfpgan_strength'],
|
||||
@ -207,29 +257,42 @@ def handle_run_gfpgan_event(original_image, gfpgan_parameters):
|
||||
upsampler_scale=1
|
||||
)
|
||||
|
||||
progress['currentStatus'] = 'Saving image'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
gfpgan_parameters['seed'] = seed
|
||||
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan')
|
||||
command = parameters_to_command(gfpgan_parameters)
|
||||
|
||||
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
|
||||
|
||||
progress['currentStatus'] = 'Finished'
|
||||
progress['currentStep'] = 0
|
||||
progress['totalSteps'] = 0
|
||||
progress['currentIteration'] = 0
|
||||
progress['totalIterations'] = 0
|
||||
progress['isProcessing'] = False
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
socketio.emit(
|
||||
'result', {'url': os.path.relpath(path), 'type': 'gfpgan', 'uuid': original_image['uuid'],'metadata': gfpgan_parameters})
|
||||
'gfpganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': gfpgan_parameters})
|
||||
|
||||
|
||||
@socketio.on('cancel')
|
||||
def handle_cancel():
|
||||
print(f'>> Cancel processing requested')
|
||||
canceled.set()
|
||||
return make_response("OK")
|
||||
socketio.emit('processingCanceled')
|
||||
|
||||
|
||||
# TODO: I think this needs a safety mechanism.
|
||||
@socketio.on('deleteImage')
|
||||
def handle_delete_image(path):
|
||||
def handle_delete_image(path, uuid):
|
||||
print(f'>> Delete requested "{path}"')
|
||||
Path(path).unlink()
|
||||
return make_response("OK")
|
||||
send2trash(path)
|
||||
socketio.emit('imageDeleted', {'url': path, 'uuid': uuid})
|
||||
|
||||
|
||||
# TODO: I think this needs a safety mechanism.
|
||||
@ -239,11 +302,11 @@ def handle_upload_initial_image(bytes, name):
|
||||
uuid = uuid4().hex
|
||||
split = os.path.splitext(name)
|
||||
name = f'{split[0]}.{uuid}{split[1]}'
|
||||
file_path = os.path.join(init_path, name)
|
||||
file_path = os.path.join(init_image_path, name)
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
newFile = open(file_path, "wb")
|
||||
newFile.write(bytes)
|
||||
return make_response("OK", data=file_path)
|
||||
socketio.emit('initialImageUploaded', {'url': file_path, 'uuid': ''})
|
||||
|
||||
|
||||
# TODO: I think this needs a safety mechanism.
|
||||
@ -253,11 +316,11 @@ def handle_upload_mask_image(bytes, name):
|
||||
uuid = uuid4().hex
|
||||
split = os.path.splitext(name)
|
||||
name = f'{split[0]}.{uuid}{split[1]}'
|
||||
file_path = os.path.join(mask_path, name)
|
||||
file_path = os.path.join(mask_image_path, name)
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
newFile = open(file_path, "wb")
|
||||
newFile.write(bytes)
|
||||
return make_response("OK", data=file_path)
|
||||
socketio.emit('maskImageUploaded', {'url': file_path, 'uuid': ''})
|
||||
|
||||
|
||||
|
||||
@ -272,6 +335,13 @@ ADDITIONAL FUNCTIONS
|
||||
"""
|
||||
|
||||
|
||||
def make_unique_init_image_filename(name):
|
||||
uuid = uuid4().hex
|
||||
split = os.path.splitext(name)
|
||||
name = f'{split[0]}.{uuid}{split[1]}'
|
||||
return name
|
||||
|
||||
|
||||
def write_log_message(message, log_path=log_path):
|
||||
"""Logs the filename and parameters used to generate or process that image to log file"""
|
||||
message = f'{message}\n'
|
||||
@ -279,15 +349,6 @@ def write_log_message(message, log_path=log_path):
|
||||
file.writelines(message)
|
||||
|
||||
|
||||
def make_response(status, message=None, data=None):
|
||||
response = {'status': status}
|
||||
if message is not None:
|
||||
response['message'] = message
|
||||
if data is not None:
|
||||
response['data'] = data
|
||||
return response
|
||||
|
||||
|
||||
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False):
|
||||
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed'
|
||||
|
||||
@ -309,16 +370,69 @@ def save_image(image, parameters, output_dir, step_index=None, postprocessing=Fa
|
||||
|
||||
return path
|
||||
|
||||
def calculate_real_steps(steps, strength, has_init_image):
|
||||
return math.floor(strength * steps) if has_init_image else steps
|
||||
|
||||
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
|
||||
canceled.clear()
|
||||
|
||||
step_index = 1
|
||||
|
||||
"""
|
||||
If a result image is used as an init image, and then deleted, we will want to be
|
||||
able to use it as an init image in the future. Need to copy it.
|
||||
|
||||
If the init/mask image doesn't exist in the init_image_path/mask_image_path,
|
||||
make a unique filename for it and copy it there.
|
||||
"""
|
||||
if ('init_img' in generation_parameters):
|
||||
filename = os.path.basename(generation_parameters['init_img'])
|
||||
if not os.path.exists(os.path.join(init_image_path, filename)):
|
||||
unique_filename = make_unique_init_image_filename(filename)
|
||||
new_path = os.path.join(init_image_path, unique_filename)
|
||||
shutil.copy(generation_parameters['init_img'], new_path)
|
||||
generation_parameters['init_img'] = new_path
|
||||
if ('init_mask' in generation_parameters):
|
||||
filename = os.path.basename(generation_parameters['init_mask'])
|
||||
if not os.path.exists(os.path.join(mask_image_path, filename)):
|
||||
unique_filename = make_unique_init_image_filename(filename)
|
||||
new_path = os.path.join(init_image_path, unique_filename)
|
||||
shutil.copy(generation_parameters['init_img'], new_path)
|
||||
generation_parameters['init_mask'] = new_path
|
||||
|
||||
|
||||
|
||||
totalSteps = calculate_real_steps(
|
||||
steps=generation_parameters['steps'],
|
||||
strength=generation_parameters['strength'] if 'strength' in generation_parameters else None,
|
||||
has_init_image='init_img' in generation_parameters
|
||||
)
|
||||
|
||||
progress = {
|
||||
'currentStep': 1,
|
||||
'totalSteps': totalSteps,
|
||||
'currentIteration': 1,
|
||||
'totalIterations': generation_parameters['iterations'],
|
||||
'currentStatus': 'Preparing',
|
||||
'isProcessing': True,
|
||||
'currentStatusHasSteps': False
|
||||
}
|
||||
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
def image_progress(sample, step):
|
||||
if canceled.is_set():
|
||||
raise CanceledException
|
||||
|
||||
nonlocal step_index
|
||||
nonlocal generation_parameters
|
||||
nonlocal progress
|
||||
|
||||
progress['currentStep'] = step + 1
|
||||
progress['currentStatus'] = 'Generating'
|
||||
progress['currentStatusHasSteps'] = True
|
||||
|
||||
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1:
|
||||
image = model.sample_to_image(sample)
|
||||
path = save_image(image, generation_parameters, intermediate_path, step_index)
|
||||
@ -326,18 +440,30 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||
step_index += 1
|
||||
socketio.emit('intermediateResult', {
|
||||
'url': os.path.relpath(path), 'metadata': generation_parameters})
|
||||
socketio.emit('progress', {'step': step + 1})
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
def image_done(image, seed):
|
||||
nonlocal generation_parameters
|
||||
nonlocal esrgan_parameters
|
||||
nonlocal gfpgan_parameters
|
||||
nonlocal progress
|
||||
|
||||
step_index = 1
|
||||
|
||||
progress['currentStatus'] = 'Generation complete'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
all_parameters = generation_parameters
|
||||
postprocessing = False
|
||||
|
||||
if esrgan_parameters:
|
||||
progress['currentStatus'] = 'Upscaling'
|
||||
progress['currentStatusHasSteps'] = False
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = real_esrgan_upscale(
|
||||
image=image,
|
||||
strength=esrgan_parameters['strength'],
|
||||
@ -348,6 +474,11 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']]
|
||||
|
||||
if gfpgan_parameters:
|
||||
progress['currentStatus'] = 'Fixing faces'
|
||||
progress['currentStatusHasSteps'] = False
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = run_gfpgan(
|
||||
image=image,
|
||||
strength=gfpgan_parameters['strength'],
|
||||
@ -358,6 +489,9 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength']
|
||||
|
||||
all_parameters['seed'] = seed
|
||||
progress['currentStatus'] = 'Saving image'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing)
|
||||
command = parameters_to_command(all_parameters)
|
||||
@ -365,8 +499,24 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||
print(f'Image generated: "{path}"')
|
||||
write_log_message(f'[Generated] "{path}": {command}')
|
||||
|
||||
if (progress['totalIterations'] > progress['currentIteration']):
|
||||
progress['currentStep'] = 1
|
||||
progress['currentIteration'] +=1
|
||||
progress['currentStatus'] = 'Iteration finished'
|
||||
progress['currentStatusHasSteps'] = False
|
||||
else:
|
||||
progress['currentStep'] = 0
|
||||
progress['totalSteps'] = 0
|
||||
progress['currentIteration'] = 0
|
||||
progress['totalIterations'] = 0
|
||||
progress['currentStatus'] = 'Finished'
|
||||
progress['isProcessing'] = False
|
||||
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
socketio.emit(
|
||||
'result', {'url': os.path.relpath(path), 'type': 'generation', 'metadata': all_parameters})
|
||||
'generationResult', {'url': os.path.relpath(path), 'metadata': all_parameters})
|
||||
eventlet.sleep(0)
|
||||
|
||||
try:
|
||||
@ -381,7 +531,7 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||
except CanceledException:
|
||||
pass
|
||||
except Exception as e:
|
||||
socketio.emit('error', (str(e)))
|
||||
socketio.emit('error', {'message': (str(e))})
|
||||
print("\n")
|
||||
traceback.print_exc()
|
||||
print("\n")
|
||||
|
BIN
docs/assets/join-us-on-discord-image.png
Normal file
After Width: | Height: | Size: 25 KiB |
BIN
docs/assets/negative_prompt_walkthru/step1.png
Normal file
After Width: | Height: | Size: 451 KiB |
BIN
docs/assets/negative_prompt_walkthru/step2.png
Normal file
After Width: | Height: | Size: 453 KiB |
BIN
docs/assets/negative_prompt_walkthru/step3.png
Normal file
After Width: | Height: | Size: 463 KiB |
BIN
docs/assets/negative_prompt_walkthru/step4.png
Normal file
After Width: | Height: | Size: 435 KiB |
BIN
docs/assets/step1.png
Normal file
After Width: | Height: | Size: 503 KiB |
BIN
docs/assets/step2.png
Normal file
After Width: | Height: | Size: 1.4 KiB |
BIN
docs/assets/step4.png
Normal file
After Width: | Height: | Size: 1.3 KiB |
BIN
docs/assets/step5.png
Normal file
After Width: | Height: | Size: 5.6 KiB |
BIN
docs/assets/step6.png
Normal file
After Width: | Height: | Size: 395 KiB |
BIN
docs/assets/step7.png
Normal file
After Width: | Height: | Size: 1014 KiB |
@ -154,13 +154,19 @@ vary greatly depending on what is in the image. We also ask to --fit the image i
|
||||
than 640x480. Otherwise the image size will be identical to the provided photo and you may run out
|
||||
of memory if it is large.
|
||||
|
||||
Repeated chaining of img2img on an image can result in significant color shifts
|
||||
in the output, especially if run with lower strength. Color correction can be
|
||||
run against a reference image to fix this issue. Use the original input image to the
|
||||
chain as the the reference image for each step in the chain.
|
||||
|
||||
In addition to the command-line options recognized by txt2img, img2img accepts additional options:
|
||||
|
||||
| Argument | Shortcut | Default | Description |
|
||||
| ------------------ | --------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| --init_img <path> | -I<path> | None | Path to the initialization image |
|
||||
| --fit | -F | False | Scale the image to fit into the specified -H and -W dimensions |
|
||||
| --strength <float> | -s<float> | 0.75 | How hard to try to match the prompt to the initial image. Ranges from 0.0-0.99, with higher values replacing the initial image completely. |
|
||||
| --init_img <path> | -I<path> | None | Path to the initialization image |
|
||||
| --init_color <path> | | None | Path to reference image for color correction |
|
||||
| --fit | -F | False | Scale the image to fit into the specified -H and -W dimensions |
|
||||
| --strength <float> | -s<float> | 0.75 | How hard to try to match the prompt to the initial image. Ranges from 0.0-0.99, with higher values replacing the initial image completely. |
|
||||
|
||||
### This is an example of inpainting
|
||||
|
||||
|
@ -37,5 +37,44 @@ We are hoping to get rid of the need for this workaround in an upcoming release.
|
||||
5. Open the Layers toolbar (^L) and select "Floating Selection"
|
||||
6. Set opacity to 0%
|
||||
7. Export as PNG
|
||||
8. In the export dialogue, Make sure the "Save colour values from
|
||||
transparent pixels" checkbox is selected.
|
||||
|
||||
|
||||
## Recipe for Adobe Photoshop
|
||||
|
||||
1. Open image in Photoshop
|
||||
<p align='left'>
|
||||
<img src="../assets/step1.png"/>
|
||||
</p>
|
||||
|
||||
2. Use any of the selection tools (Marquee, Lasso, or Wand) to select the area you desire to inpaint.
|
||||
<p align='left'>
|
||||
<img src="../assets/step2.png"/>
|
||||
</p>
|
||||
|
||||
3. Because we'll be applying a mask over the area we want to preserve, you should now select the inverse by using the Shift + Ctrl + I shortcut, or right clicking and using the "Select Inverse" option.
|
||||
|
||||
4. You'll now create a mask by selecting the image layer, and Masking the selection. Make sure that you don't delete any of the underlying image, or your inpainting results will be dramatically impacted.
|
||||
<p align='left'>
|
||||
<img src="../assets/step4.png"/>
|
||||
</p>
|
||||
|
||||
5. Make sure to hide any background layers that are present. You should see the mask applied to your image layer, and the image on your canvas should display the checkered background.
|
||||
<p align='left'>
|
||||
<img src="../assets/step5.png"/>
|
||||
</p>
|
||||
|
||||
<p align='left'>
|
||||
<img src="../assets/step6.png"/>
|
||||
</p>
|
||||
|
||||
6. Save the image as a transparent PNG by using the "Save a Copy" option in the File menu, or using the Alt + Ctrl + S keyboard shortcut.
|
||||
|
||||
7. After following the inpainting instructions above (either through the CLI or the Web UI), marvel at your newfound ability to selectively dream. Lookin' good!
|
||||
<p align='left'>
|
||||
<img src="../assets/step7.png"/>
|
||||
</p>
|
||||
|
||||
8. In the export dialogue, Make sure the "Save colour values from transparent pixels" checkbox is
|
||||
selected.
|
||||
|
@ -28,32 +28,6 @@ dream> "pond garden with lotus by claude monet" --seamless -s100 -n4
|
||||
|
||||
---
|
||||
|
||||
## **Reading Prompts from a File**
|
||||
|
||||
You can automate `dream.py` by providing a text file with the prompts you want to run, one line per
|
||||
prompt. The text file must be composed with a text editor (e.g. Notepad) and not a word processor.
|
||||
Each line should look like what you would type at the dream> prompt:
|
||||
|
||||
```bash
|
||||
a beautiful sunny day in the park, children playing -n4 -C10
|
||||
stormy weather on a mountain top, goats grazing -s100
|
||||
innovative packaging for a squid's dinner -S137038382
|
||||
```
|
||||
|
||||
Then pass this file's name to `dream.py` when you invoke it:
|
||||
|
||||
```bash
|
||||
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --from_file "path/to/prompts.txt"
|
||||
```
|
||||
|
||||
You may read a series of prompts from standard input by providing a filename of `-`:
|
||||
|
||||
```bash
|
||||
(ldm) ~/stable-diffusion$ echo "a beautiful day" | python3 scripts/dream.py --from_file -
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## **Shortcuts: Reusing Seeds**
|
||||
|
||||
Since it is so common to reuse seeds while refining a prompt, there is now a shortcut as of version
|
||||
@ -79,22 +53,6 @@ outputs/img-samples/000040.3498014304.png: "a cute child playing hopscotch" -G1.
|
||||
|
||||
---
|
||||
|
||||
## **Weighted Prompts**
|
||||
|
||||
You may weight different sections of the prompt to tell the sampler to attach different levels of
|
||||
priority to them, by adding `:(number)` to the end of the section you wish to up- or downweight. For
|
||||
example consider this prompt:
|
||||
|
||||
```bash
|
||||
tabby cat:0.25 white duck:0.75 hybrid
|
||||
```
|
||||
|
||||
This will tell the sampler to invest 25% of its effort on the tabby cat aspect of the image and 75%
|
||||
on the white duck aspect (surprisingly, this example actually works). The prompt weights can use any
|
||||
combination of integers and floating point numbers, and they do not need to add up to 1.
|
||||
|
||||
---
|
||||
|
||||
## **Simplified API**
|
||||
|
||||
For programmers who wish to incorporate stable-diffusion into other products, this repository
|
||||
|
96
docs/features/PROMPTS.md
Normal file
@ -0,0 +1,96 @@
|
||||
# Prompting Features
|
||||
|
||||
## **Reading Prompts from a File**
|
||||
|
||||
You can automate `dream.py` by providing a text file with the prompts you want to run, one line per
|
||||
prompt. The text file must be composed with a text editor (e.g. Notepad) and not a word processor.
|
||||
Each line should look like what you would type at the dream> prompt:
|
||||
|
||||
```bash
|
||||
a beautiful sunny day in the park, children playing -n4 -C10
|
||||
stormy weather on a mountain top, goats grazing -s100
|
||||
innovative packaging for a squid's dinner -S137038382
|
||||
```
|
||||
|
||||
Then pass this file's name to `dream.py` when you invoke it:
|
||||
|
||||
```bash
|
||||
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --from_file "path/to/prompts.txt"
|
||||
```
|
||||
|
||||
You may read a series of prompts from standard input by providing a filename of `-`:
|
||||
|
||||
```bash
|
||||
(ldm) ~/stable-diffusion$ echo "a beautiful day" | python3 scripts/dream.py --from_file -
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## **Weighted Prompts**
|
||||
|
||||
You may weight different sections of the prompt to tell the sampler to attach different levels of
|
||||
priority to them, by adding `:(number)` to the end of the section you wish to up- or downweight. For
|
||||
example consider this prompt:
|
||||
|
||||
```bash
|
||||
tabby cat:0.25 white duck:0.75 hybrid
|
||||
```
|
||||
|
||||
This will tell the sampler to invest 25% of its effort on the tabby cat aspect of the image and 75%
|
||||
on the white duck aspect (surprisingly, this example actually works). The prompt weights can use any
|
||||
combination of integers and floating point numbers, and they do not need to add up to 1.
|
||||
|
||||
---
|
||||
|
||||
## **Negative and Unconditioned Prompts**
|
||||
|
||||
Any words between a pair of square brackets will try and be ignored by Stable Diffusion's model during generation of images.
|
||||
|
||||
```bash
|
||||
this is a test prompt [not really] to make you understand [cool] how this works.
|
||||
```
|
||||
|
||||
In the above statement, the words 'not really cool` will be ignored by Stable Diffusion.
|
||||
|
||||
Here's a prompt that depicts what it does.
|
||||
|
||||
original prompt:
|
||||
|
||||
```bash
|
||||
"A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180
|
||||
```
|
||||
|
||||
![step1](../assets/variation_walkthru/step1.png)
|
||||
|
||||
That image has a woman, so if we want the horse without a rider, we can influence the image not to have a woman by putting [woman] in the prompt, like this:
|
||||
|
||||
```bash
|
||||
"A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve [woman]" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180
|
||||
```
|
||||
|
||||
![step2](../assets/variation_walkthru/step2.png)
|
||||
|
||||
That's nice - but say we also don't want the image to be quite so blue. We can add "blue" to the list of negative prompts, so it's now [woman blue]:
|
||||
|
||||
```bash
|
||||
"A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve [woman blue]" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180
|
||||
```
|
||||
|
||||
![step3](../assets/variation_walkthru/step3.png)
|
||||
|
||||
|
||||
Getting close - but there's no sense in having a saddle when our horse doesn't have a rider, so we'll add one more negative prompt: [woman blue saddle].
|
||||
|
||||
```bash
|
||||
"A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve [woman blue saddle]" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180
|
||||
```
|
||||
|
||||
![step4](../assets/variation_walkthru/step4.png)
|
||||
|
||||
|
||||
Notes about this feature:
|
||||
|
||||
* The only requirement for words to be ignored is that they are in between a pair of square brackets.
|
||||
* You can provide multiple words within the same bracket.
|
||||
* You can provide multiple brackets with multiple words in different places of your prompt. That works just fine.
|
||||
* To improve typical anatomy problems, you can add negative prompts like [bad anatomy, extra legs, extra arms, extra fingers, poorly drawn hands, poorly drawn feet, disfigured, out of frame, tiling, bad art, deformed, mutated].
|
@ -97,3 +97,39 @@ the base images.
|
||||
If you wish to stop during the image generation but want to upscale or face restore a particular
|
||||
generated image, pass it again with the same prompt and generated seed along with the `-U` and `-G`
|
||||
prompt arguments to perform those actions.
|
||||
|
||||
## CodeFormer Support
|
||||
|
||||
This repo also allows you to perform face restoration using
|
||||
[CodeFormer](https://github.com/sczhou/CodeFormer).
|
||||
|
||||
In order to setup CodeFormer to work, you need to download the models like with GFPGAN. You can do
|
||||
this either by running `preload_models.py` or by manually downloading the
|
||||
[model file](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth) and
|
||||
saving it to `ldm/restoration/codeformer/weights` folder.
|
||||
|
||||
You can use `-ft` prompt argument to swap between CodeFormer and the default GFPGAN. The above
|
||||
mentioned `-G` prompt argument will allow you to control the strength of the restoration effect.
|
||||
|
||||
### **Usage:**
|
||||
|
||||
The following command will perform face restoration with CodeFormer instead of the default gfpgan.
|
||||
|
||||
`<prompt> -G 0.8 -ft codeformer`
|
||||
|
||||
**Other Options:**
|
||||
|
||||
- `-cf` - cf or CodeFormer Fidelity takes values between `0` and `1`. 0 produces high quality
|
||||
results but low accuracy and 1 produces lower quality results but higher accuacy to your original
|
||||
face.
|
||||
|
||||
The following command will perform face restoration with CodeFormer. CodeFormer will output a result
|
||||
that is closely matching to the input face.
|
||||
|
||||
`<prompt> -G 1.0 -ft codeformer -cf 0.9`
|
||||
|
||||
The following command will perform face restoration with CodeFormer. CodeFormer will output a result
|
||||
that is the best restoration possible. This may deviate slightly from the original face. This is an
|
||||
excellent option to use in situations when there is very little facial data to work with.
|
||||
|
||||
`<prompt> -G 1.0 -ft codeformer -cf 0.1`
|
||||
|
@ -102,6 +102,7 @@ generate more variations around the almost-but-not-quite image. We do the
|
||||
latter, using both the `-V` (combining) and `-v` (variation strength) options.
|
||||
Note that we use `-n6` to generate 6 variations:
|
||||
|
||||
```bash
|
||||
dream> "prompt" -S3357757885 -V3647897225,0.1,1614299449,0.1 -v0.05 -n6
|
||||
Outputs:
|
||||
./outputs/Xena/000004.3279757577.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.1,1614299449:0.1,3279757577:0.05 -S3357757885
|
||||
|
@ -7,10 +7,7 @@ title: macOS
|
||||
- macOS 12.3 Monterey or later
|
||||
- Python
|
||||
- Patience
|
||||
- Apple Silicon\*
|
||||
|
||||
\*I haven't tested any of this on Intel Macs but I have read that one person got
|
||||
it to work, so Apple Silicon might not be requried.
|
||||
- Apple Silicon or Intel Mac
|
||||
|
||||
Things have moved really fast and so these instructions change often and are
|
||||
often out-of-date. One of the problems is that there are so many different ways
|
||||
@ -59,9 +56,13 @@ First get the weights checkpoint download started - it's big:
|
||||
# install python 3, git, cmake, protobuf:
|
||||
brew install cmake protobuf rust
|
||||
|
||||
# install miniconda (M1 arm64 version):
|
||||
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o Miniconda3-latest-MacOSX-arm64.sh
|
||||
/bin/bash Miniconda3-latest-MacOSX-arm64.sh
|
||||
# install miniconda for M1 arm64:
|
||||
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o Miniconda3-latest-MacOSX-arm64.sh
|
||||
/bin/bash Miniconda3-latest-MacOSX-arm64.sh
|
||||
|
||||
# OR install miniconda for Intel:
|
||||
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o Miniconda3-latest-MacOSX-x86_64.sh
|
||||
/bin/bash Miniconda3-latest-MacOSX-x86_64.sh
|
||||
|
||||
|
||||
# EITHER WAY,
|
||||
@ -82,15 +83,22 @@ brew install cmake protobuf rust
|
||||
|
||||
ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" models/ldm/stable-diffusion-v1/model.ckpt
|
||||
|
||||
# install packages
|
||||
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
|
||||
conda activate ldm
|
||||
# install packages for arm64
|
||||
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
|
||||
conda activate ldm
|
||||
|
||||
# OR install packages for x86_64
|
||||
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-x86_64 conda env create -f environment-mac.yaml
|
||||
conda activate ldm
|
||||
|
||||
# only need to do this once
|
||||
python scripts/preload_models.py
|
||||
|
||||
# run SD!
|
||||
python scripts/dream.py --full_precision # half-precision requires autocast and won't work
|
||||
|
||||
# or run the web interface!
|
||||
python scripts/dream.py --web
|
||||
```
|
||||
|
||||
The original scripts should work as well.
|
||||
@ -181,7 +189,12 @@ There are several causes of these errors.
|
||||
- Third, if it says you're missing taming you need to rebuild your virtual
|
||||
environment.
|
||||
|
||||
`conda env remove -n ldm conda env create -f environment-mac.yaml`
|
||||
````bash
|
||||
conda deactivate
|
||||
|
||||
conda env remove -n ldm
|
||||
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
|
||||
```
|
||||
|
||||
Fourth, If you have activated the ldm virtual environment and tried rebuilding
|
||||
it, maybe the problem could be that I have something installed that you don't
|
||||
|
@ -2,15 +2,16 @@
|
||||
title: Contributors
|
||||
---
|
||||
|
||||
The list of all the amazing people who have contributed to the various features that you get to experience in this fork.
|
||||
The list of all the amazing people who have contributed to the various features that you get to
|
||||
experience in this fork.
|
||||
|
||||
We thank them for all of their time and hard work.
|
||||
|
||||
## __Original Author:__
|
||||
## **Original Author:**
|
||||
|
||||
- [Lincoln D. Stein](mailto:lincoln.stein@gmail.com)
|
||||
|
||||
## __Contributions by:__
|
||||
## **Contributions by:**
|
||||
|
||||
- [Sean McLellan](https://github.com/Oceanswave)
|
||||
- [Kevin Gibbons](https://github.com/bakkot)
|
||||
@ -52,8 +53,9 @@ We thank them for all of their time and hard work.
|
||||
- [Doggettx](https://github.com/doggettx)
|
||||
- [Matthias Wild](https://github.com/mauwii)
|
||||
- [Kyle Schouviller](https://github.com/kyle0654)
|
||||
- [rabidcopy](https://github.com/rabidcopy)
|
||||
|
||||
## __Original CompVis Authors:__
|
||||
## **Original CompVis Authors:**
|
||||
|
||||
- [Robin Rombach](https://github.com/rromb)
|
||||
- [Patrick von Platen](https://github.com/patrickvonplaten)
|
||||
@ -65,4 +67,5 @@ We thank them for all of their time and hard work.
|
||||
|
||||
---
|
||||
|
||||
_If you have contributed and don't see your name on the list of contributors, please let one of the collaborators know about the omission, or feel free to make a pull request._
|
||||
_If you have contributed and don't see your name on the list of contributors, please let one of the
|
||||
collaborators know about the omission, or feel free to make a pull request._
|
||||
|
@ -48,6 +48,7 @@ dependencies:
|
||||
- opencv-python==4.6.0
|
||||
- protobuf==3.20.1
|
||||
- realesrgan==0.2.5.0
|
||||
- send2trash==1.8.0
|
||||
- test-tube==0.7.5
|
||||
- transformers==4.21.2
|
||||
- torch-fidelity==0.3.0
|
||||
|
@ -20,7 +20,8 @@ dependencies:
|
||||
- realesrgan==0.2.5.0
|
||||
- test-tube>=0.7.5
|
||||
- streamlit==1.12.0
|
||||
- pillow==6.2.0
|
||||
- send2trash==1.8.0
|
||||
- pillow==9.2.0
|
||||
- einops==0.3.0
|
||||
- torch-fidelity==0.3.0
|
||||
- transformers==4.19.2
|
||||
|
@ -1,85 +1,37 @@
|
||||
# Stable Diffusion Web UI
|
||||
|
||||
Demo at https://peaceful-otter-7a427f.netlify.app/ (not connected to back end)
|
||||
## Run
|
||||
|
||||
much of this readme is just notes for myself during dev work
|
||||
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
|
||||
|
||||
numpy rand: 0 to 4294967295
|
||||
## Evironment
|
||||
|
||||
## Test and Build
|
||||
Install [node](https://nodejs.org/en/download/) (includes npm) and optionally
|
||||
[yarn](https://yarnpkg.com/getting-started/install).
|
||||
|
||||
from `frontend/`:
|
||||
From `frontend/` run `npm install` / `yarn install` to install the frontend packages.
|
||||
|
||||
- `yarn dev` runs `tsc-watch`, which runs `vite build` on successful `tsc` transpilation
|
||||
## Dev
|
||||
|
||||
from `.`:
|
||||
1. From `frontend/`, run `npm dev` / `yarn dev` to start the dev server.
|
||||
2. Note the address it starts up on (probably `http://localhost:5173/`).
|
||||
3. Edit `backend/server.py`'s `additional_allowed_origins` to include this address, e.g.
|
||||
`additional_allowed_origins = ['http://localhost:5173']`.
|
||||
4. Leaving the dev server running, open a new terminal and go to the project root.
|
||||
5. Run `python backend/server.py`.
|
||||
6. Navigate to the dev server address e.g. `http://localhost:5173/`.
|
||||
|
||||
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
|
||||
To build for dev: `npm build-dev` / `yarn build-dev`
|
||||
|
||||
## API
|
||||
|
||||
`backend/server.py` serves the UI and provides a [socket.io](https://github.com/socketio/socket.io) API via [flask-socketio](https://github.com/miguelgrinberg/flask-socketio).
|
||||
|
||||
### Server Listeners
|
||||
|
||||
The server listens for these socket.io events:
|
||||
|
||||
`cancel`
|
||||
|
||||
- Cancels in-progress image generation
|
||||
- Returns ack only
|
||||
|
||||
`generateImage`
|
||||
|
||||
- Accepts object of image parameters
|
||||
- Generates an image
|
||||
- Returns ack only (image generation function sends progress and result via separate events)
|
||||
|
||||
`deleteImage`
|
||||
|
||||
- Accepts file path to image
|
||||
- Deletes image
|
||||
- Returns ack only
|
||||
|
||||
`deleteAllImages` WIP
|
||||
|
||||
- Deletes all images in `outputs/`
|
||||
- Returns ack only
|
||||
|
||||
`requestAllImages`
|
||||
|
||||
- Returns array of all images in `outputs/`
|
||||
|
||||
`requestCapabilities` WIP
|
||||
|
||||
- Returns capabilities of server (torch device, GFPGAN and ESRGAN availability, ???)
|
||||
|
||||
`sendImage` WIP
|
||||
|
||||
- Accepts a File and attributes
|
||||
- Saves image
|
||||
- Used to save init images which are not generated images
|
||||
|
||||
### Server Emitters
|
||||
|
||||
`progress`
|
||||
|
||||
- Emitted during each step in generation
|
||||
- Sends a number from 0 to 1 representing percentage of steps completed
|
||||
|
||||
`result` WIP
|
||||
|
||||
- Emitted when an image generation has completed
|
||||
- Sends a object:
|
||||
|
||||
```
|
||||
{
|
||||
url: relative_file_path,
|
||||
metadata: image_metadata_object
|
||||
}
|
||||
```
|
||||
To build for production: `npm build` / `yarn build`
|
||||
|
||||
## TODO
|
||||
|
||||
- Search repo for "TODO"
|
||||
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on `framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on this. Need to check in on this issue periodically.
|
||||
- Search repo for "TODO"
|
||||
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on
|
||||
`framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on
|
||||
the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on
|
||||
this. Need to check in on this issue periodically.
|
||||
- Mobile friendly layout
|
||||
- Proper image gallery/viewer/manager
|
||||
- Help tooltips and such
|
||||
|
695
frontend/dist/assets/index.cc5cde43.js
vendored
694
frontend/dist/assets/index.de730902.js
vendored
Normal file
2
frontend/dist/index.html
vendored
@ -4,7 +4,7 @@
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Stable Diffusion Dream Server</title>
|
||||
<script type="module" crossorigin src="/assets/index.cc5cde43.js"></script>
|
||||
<script type="module" crossorigin src="/assets/index.de730902.js"></script>
|
||||
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
|
||||
</head>
|
||||
<body>
|
||||
|
@ -4,8 +4,7 @@
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "tsc-watch --onSuccess 'yarn run vite build -m development'",
|
||||
"hmr": "vite dev",
|
||||
"dev": "vite dev",
|
||||
"build": "tsc && vite build",
|
||||
"build-dev": "tsc && vite build -m development",
|
||||
"preview": "vite preview"
|
||||
|
@ -1,60 +0,0 @@
|
||||
import { Grid, GridItem } from '@chakra-ui/react';
|
||||
import CurrentImage from './features/gallery/CurrentImage';
|
||||
import LogViewer from './features/system/LogViewer';
|
||||
import PromptInput from './features/sd/PromptInput';
|
||||
import ProgressBar from './features/header/ProgressBar';
|
||||
import { useEffect } from 'react';
|
||||
import { useAppDispatch } from './app/hooks';
|
||||
import { requestAllImages } from './app/socketio';
|
||||
import ProcessButtons from './features/sd/ProcessButtons';
|
||||
import ImageRoll from './features/gallery/ImageRoll';
|
||||
import SiteHeader from './features/header/SiteHeader';
|
||||
import OptionsAccordion from './features/sd/OptionsAccordion';
|
||||
|
||||
const App = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
useEffect(() => {
|
||||
dispatch(requestAllImages());
|
||||
}, [dispatch]);
|
||||
return (
|
||||
<>
|
||||
<Grid
|
||||
width='100vw'
|
||||
height='100vh'
|
||||
templateAreas={`
|
||||
"header header header header"
|
||||
"progressBar progressBar progressBar progressBar"
|
||||
"menu prompt processButtons imageRoll"
|
||||
"menu currentImage currentImage imageRoll"`}
|
||||
gridTemplateRows={'36px 10px 100px auto'}
|
||||
gridTemplateColumns={'350px auto 100px 388px'}
|
||||
gap={2}
|
||||
>
|
||||
<GridItem area={'header'} pt={1}>
|
||||
<SiteHeader />
|
||||
</GridItem>
|
||||
<GridItem area={'progressBar'}>
|
||||
<ProgressBar />
|
||||
</GridItem>
|
||||
<GridItem pl='2' area={'menu'} overflowY='scroll'>
|
||||
<OptionsAccordion />
|
||||
</GridItem>
|
||||
<GridItem area={'prompt'}>
|
||||
<PromptInput />
|
||||
</GridItem>
|
||||
<GridItem area={'processButtons'}>
|
||||
<ProcessButtons />
|
||||
</GridItem>
|
||||
<GridItem area={'currentImage'}>
|
||||
<CurrentImage />
|
||||
</GridItem>
|
||||
<GridItem pr='2' area={'imageRoll'} overflowY='scroll'>
|
||||
<ImageRoll />
|
||||
</GridItem>
|
||||
</Grid>
|
||||
<LogViewer />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default App;
|
68
frontend/src/app/App.tsx
Normal file
@ -0,0 +1,68 @@
|
||||
import { Grid, GridItem } from '@chakra-ui/react';
|
||||
import { useEffect, useState } from 'react';
|
||||
import CurrentImageDisplay from '../features/gallery/CurrentImageDisplay';
|
||||
import ImageGallery from '../features/gallery/ImageGallery';
|
||||
import ProgressBar from '../features/header/ProgressBar';
|
||||
import SiteHeader from '../features/header/SiteHeader';
|
||||
import OptionsAccordion from '../features/sd/OptionsAccordion';
|
||||
import ProcessButtons from '../features/sd/ProcessButtons';
|
||||
import PromptInput from '../features/sd/PromptInput';
|
||||
import LogViewer from '../features/system/LogViewer';
|
||||
import Loading from '../Loading';
|
||||
import { useAppDispatch } from './store';
|
||||
import { requestAllImages } from './socketio/actions';
|
||||
|
||||
const App = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [isReady, setIsReady] = useState<boolean>(false);
|
||||
|
||||
// Load images from the gallery once
|
||||
useEffect(() => {
|
||||
dispatch(requestAllImages());
|
||||
setIsReady(true);
|
||||
}, [dispatch]);
|
||||
|
||||
return isReady ? (
|
||||
<>
|
||||
<Grid
|
||||
width="100vw"
|
||||
height="100vh"
|
||||
templateAreas={`
|
||||
"header header header header"
|
||||
"progressBar progressBar progressBar progressBar"
|
||||
"menu prompt processButtons imageRoll"
|
||||
"menu currentImage currentImage imageRoll"`}
|
||||
gridTemplateRows={'36px 10px 100px auto'}
|
||||
gridTemplateColumns={'350px auto 100px 388px'}
|
||||
gap={2}
|
||||
>
|
||||
<GridItem area={'header'} pt={1}>
|
||||
<SiteHeader />
|
||||
</GridItem>
|
||||
<GridItem area={'progressBar'}>
|
||||
<ProgressBar />
|
||||
</GridItem>
|
||||
<GridItem pl="2" area={'menu'} overflowY="scroll">
|
||||
<OptionsAccordion />
|
||||
</GridItem>
|
||||
<GridItem area={'prompt'}>
|
||||
<PromptInput />
|
||||
</GridItem>
|
||||
<GridItem area={'processButtons'}>
|
||||
<ProcessButtons />
|
||||
</GridItem>
|
||||
<GridItem area={'currentImage'}>
|
||||
<CurrentImageDisplay />
|
||||
</GridItem>
|
||||
<GridItem pr="2" area={'imageRoll'} overflowY="scroll">
|
||||
<ImageGallery />
|
||||
</GridItem>
|
||||
</Grid>
|
||||
<LogViewer />
|
||||
</>
|
||||
) : (
|
||||
<Loading />
|
||||
);
|
||||
};
|
||||
|
||||
export default App;
|
@ -2,52 +2,52 @@
|
||||
|
||||
// Valid samplers
|
||||
export const SAMPLERS: Array<string> = [
|
||||
'ddim',
|
||||
'plms',
|
||||
'k_lms',
|
||||
'k_dpm_2',
|
||||
'k_dpm_2_a',
|
||||
'k_euler',
|
||||
'k_euler_a',
|
||||
'k_heun',
|
||||
'ddim',
|
||||
'plms',
|
||||
'k_lms',
|
||||
'k_dpm_2',
|
||||
'k_dpm_2_a',
|
||||
'k_euler',
|
||||
'k_euler_a',
|
||||
'k_heun',
|
||||
];
|
||||
|
||||
// Valid image widths
|
||||
export const WIDTHS: Array<number> = [
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024,
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024,
|
||||
];
|
||||
|
||||
// Valid image heights
|
||||
export const HEIGHTS: Array<number> = [
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024,
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024,
|
||||
];
|
||||
|
||||
// Valid upscaling levels
|
||||
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
|
||||
{ key: '2x', value: 2 },
|
||||
{ key: '4x', value: 4 },
|
||||
{ key: '2x', value: 2 },
|
||||
{ key: '4x', value: 4 },
|
||||
];
|
||||
|
||||
// Internal to human-readable parameters
|
||||
export const PARAMETERS: { [key: string]: string } = {
|
||||
prompt: 'Prompt',
|
||||
iterations: 'Iterations',
|
||||
steps: 'Steps',
|
||||
cfgScale: 'CFG Scale',
|
||||
height: 'Height',
|
||||
width: 'Width',
|
||||
sampler: 'Sampler',
|
||||
seed: 'Seed',
|
||||
img2imgStrength: 'img2img Strength',
|
||||
gfpganStrength: 'GFPGAN Strength',
|
||||
upscalingLevel: 'Upscaling Level',
|
||||
upscalingStrength: 'Upscaling Strength',
|
||||
initialImagePath: 'Initial Image',
|
||||
maskPath: 'Initial Image Mask',
|
||||
shouldFitToWidthHeight: 'Fit Initial Image',
|
||||
seamless: 'Seamless Tiling',
|
||||
prompt: 'Prompt',
|
||||
iterations: 'Iterations',
|
||||
steps: 'Steps',
|
||||
cfgScale: 'CFG Scale',
|
||||
height: 'Height',
|
||||
width: 'Width',
|
||||
sampler: 'Sampler',
|
||||
seed: 'Seed',
|
||||
img2imgStrength: 'img2img Strength',
|
||||
gfpganStrength: 'GFPGAN Strength',
|
||||
upscalingLevel: 'Upscaling Level',
|
||||
upscalingStrength: 'Upscaling Strength',
|
||||
initialImagePath: 'Initial Image',
|
||||
maskPath: 'Initial Image Mask',
|
||||
shouldFitToWidthHeight: 'Fit Initial Image',
|
||||
seamless: 'Seamless Tiling',
|
||||
};
|
||||
|
||||
export const NUMPY_RAND_MIN = 0;
|
||||
|
@ -1,7 +0,0 @@
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import type { TypedUseSelectorHook } from 'react-redux';
|
||||
import type { RootState, AppDispatch } from './store';
|
||||
|
||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||
export const useAppDispatch: () => AppDispatch = useDispatch;
|
||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
@ -1,393 +0,0 @@
|
||||
import { createAction, Middleware } from '@reduxjs/toolkit';
|
||||
import { io } from 'socket.io-client';
|
||||
import {
|
||||
addImage,
|
||||
clearIntermediateImage,
|
||||
removeImage,
|
||||
SDImage,
|
||||
SDMetadata,
|
||||
setGalleryImages,
|
||||
setIntermediateImage,
|
||||
} from '../features/gallery/gallerySlice';
|
||||
import {
|
||||
addLogEntry,
|
||||
setCurrentStep,
|
||||
setIsConnected,
|
||||
setIsProcessing,
|
||||
} from '../features/system/systemSlice';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { setInitialImagePath, setMaskPath } from '../features/sd/sdSlice';
|
||||
import {
|
||||
backendToFrontendParameters,
|
||||
frontendToBackendParameters,
|
||||
} from './parameterTranslation';
|
||||
|
||||
export interface SocketIOResponse {
|
||||
status: 'OK' | 'ERROR';
|
||||
message?: string;
|
||||
data?: any;
|
||||
}
|
||||
|
||||
export const socketioMiddleware = () => {
|
||||
const { hostname, port } = new URL(window.location.href);
|
||||
|
||||
const socketio = io(`http://${hostname}:9090`);
|
||||
|
||||
let areListenersSet = false;
|
||||
|
||||
const middleware: Middleware = (store) => (next) => (action) => {
|
||||
const { dispatch, getState } = store;
|
||||
if (!areListenersSet) {
|
||||
// CONNECT
|
||||
socketio.on('connect', () => {
|
||||
try {
|
||||
dispatch(setIsConnected(true));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
});
|
||||
|
||||
// DISCONNECT
|
||||
socketio.on('disconnect', () => {
|
||||
try {
|
||||
dispatch(setIsConnected(false));
|
||||
dispatch(setIsProcessing(false));
|
||||
dispatch(addLogEntry(`Disconnected from server`));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
});
|
||||
|
||||
// PROCESSING RESULT
|
||||
socketio.on(
|
||||
'result',
|
||||
(data: {
|
||||
url: string;
|
||||
type: 'generation' | 'esrgan' | 'gfpgan';
|
||||
uuid?: string;
|
||||
metadata: { [key: string]: any };
|
||||
}) => {
|
||||
try {
|
||||
const newUuid = uuidv4();
|
||||
const { type, url, uuid, metadata } = data;
|
||||
switch (type) {
|
||||
case 'generation': {
|
||||
const translatedMetadata =
|
||||
backendToFrontendParameters(metadata);
|
||||
dispatch(
|
||||
addImage({
|
||||
uuid: newUuid,
|
||||
url,
|
||||
metadata: translatedMetadata,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
addLogEntry(`Image generated: ${url}`)
|
||||
);
|
||||
|
||||
break;
|
||||
}
|
||||
case 'esrgan': {
|
||||
const originalImage =
|
||||
getState().gallery.images.find(
|
||||
(i: SDImage) => i.uuid === uuid
|
||||
);
|
||||
const newMetadata = {
|
||||
...originalImage.metadata,
|
||||
};
|
||||
newMetadata.shouldRunESRGAN = true;
|
||||
newMetadata.upscalingLevel =
|
||||
metadata.upscale[0];
|
||||
newMetadata.upscalingStrength =
|
||||
metadata.upscale[1];
|
||||
dispatch(
|
||||
addImage({
|
||||
uuid: newUuid,
|
||||
url,
|
||||
metadata: newMetadata,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
addLogEntry(`ESRGAN upscaled: ${url}`)
|
||||
);
|
||||
|
||||
break;
|
||||
}
|
||||
case 'gfpgan': {
|
||||
const originalImage =
|
||||
getState().gallery.images.find(
|
||||
(i: SDImage) => i.uuid === uuid
|
||||
);
|
||||
const newMetadata = {
|
||||
...originalImage.metadata,
|
||||
};
|
||||
newMetadata.shouldRunGFPGAN = true;
|
||||
newMetadata.gfpganStrength =
|
||||
metadata.gfpgan_strength;
|
||||
dispatch(
|
||||
addImage({
|
||||
uuid: newUuid,
|
||||
url,
|
||||
metadata: newMetadata,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
addLogEntry(`GFPGAN fixed faces: ${url}`)
|
||||
);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
dispatch(setIsProcessing(false));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// PROGRESS UPDATE
|
||||
socketio.on('progress', (data: { step: number }) => {
|
||||
try {
|
||||
dispatch(setIsProcessing(true));
|
||||
dispatch(setCurrentStep(data.step));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
});
|
||||
|
||||
// INTERMEDIATE IMAGE
|
||||
socketio.on(
|
||||
'intermediateResult',
|
||||
(data: { url: string; metadata: SDMetadata }) => {
|
||||
try {
|
||||
const uuid = uuidv4();
|
||||
const { url, metadata } = data;
|
||||
dispatch(
|
||||
setIntermediateImage({
|
||||
uuid,
|
||||
url,
|
||||
metadata,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
addLogEntry(`Intermediate image generated: ${url}`)
|
||||
);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// ERROR FROM BACKEND
|
||||
socketio.on('error', (message) => {
|
||||
try {
|
||||
dispatch(addLogEntry(`Server error: ${message}`));
|
||||
dispatch(setIsProcessing(false));
|
||||
dispatch(clearIntermediateImage());
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
});
|
||||
|
||||
areListenersSet = true;
|
||||
}
|
||||
|
||||
// HANDLE ACTIONS
|
||||
|
||||
switch (action.type) {
|
||||
// GENERATE IMAGE
|
||||
case 'socketio/generateImage': {
|
||||
dispatch(setIsProcessing(true));
|
||||
dispatch(setCurrentStep(-1));
|
||||
|
||||
const {
|
||||
generationParameters,
|
||||
esrganParameters,
|
||||
gfpganParameters,
|
||||
} = frontendToBackendParameters(
|
||||
getState().sd,
|
||||
getState().system
|
||||
);
|
||||
|
||||
socketio.emit(
|
||||
'generateImage',
|
||||
generationParameters,
|
||||
esrganParameters,
|
||||
gfpganParameters
|
||||
);
|
||||
|
||||
dispatch(
|
||||
addLogEntry(
|
||||
`Image generation requested: ${JSON.stringify({
|
||||
...generationParameters,
|
||||
...esrganParameters,
|
||||
...gfpganParameters,
|
||||
})}`
|
||||
)
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// RUN ESRGAN (UPSCALING)
|
||||
case 'socketio/runESRGAN': {
|
||||
const imageToProcess = action.payload;
|
||||
dispatch(setIsProcessing(true));
|
||||
dispatch(setCurrentStep(-1));
|
||||
const { upscalingLevel, upscalingStrength } = getState().sd;
|
||||
const esrganParameters = {
|
||||
upscale: [upscalingLevel, upscalingStrength],
|
||||
};
|
||||
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
|
||||
dispatch(
|
||||
addLogEntry(
|
||||
`ESRGAN upscale requested: ${JSON.stringify({
|
||||
file: imageToProcess.url,
|
||||
...esrganParameters,
|
||||
})}`
|
||||
)
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// RUN GFPGAN (FIX FACES)
|
||||
case 'socketio/runGFPGAN': {
|
||||
const imageToProcess = action.payload;
|
||||
dispatch(setIsProcessing(true));
|
||||
dispatch(setCurrentStep(-1));
|
||||
const { gfpganStrength } = getState().sd;
|
||||
|
||||
const gfpganParameters = {
|
||||
gfpgan_strength: gfpganStrength,
|
||||
};
|
||||
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
|
||||
dispatch(
|
||||
addLogEntry(
|
||||
`GFPGAN fix faces requested: ${JSON.stringify({
|
||||
file: imageToProcess.url,
|
||||
...gfpganParameters,
|
||||
})}`
|
||||
)
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// DELETE IMAGE
|
||||
case 'socketio/deleteImage': {
|
||||
const imageToDelete = action.payload;
|
||||
const { url } = imageToDelete;
|
||||
socketio.emit(
|
||||
'deleteImage',
|
||||
url,
|
||||
(response: SocketIOResponse) => {
|
||||
if (response.status === 'OK') {
|
||||
dispatch(removeImage(imageToDelete));
|
||||
dispatch(addLogEntry(`Image deleted: ${url}`));
|
||||
}
|
||||
}
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// GET ALL IMAGES FOR GALLERY
|
||||
case 'socketio/requestAllImages': {
|
||||
socketio.emit(
|
||||
'requestAllImages',
|
||||
(response: SocketIOResponse) => {
|
||||
dispatch(setGalleryImages(response.data));
|
||||
dispatch(
|
||||
addLogEntry(`Loaded ${response.data.length} images`)
|
||||
);
|
||||
}
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// CANCEL PROCESSING
|
||||
case 'socketio/cancelProcessing': {
|
||||
socketio.emit('cancel', (response: SocketIOResponse) => {
|
||||
const { intermediateImage } = getState().gallery;
|
||||
if (response.status === 'OK') {
|
||||
dispatch(setIsProcessing(false));
|
||||
if (intermediateImage) {
|
||||
dispatch(addImage(intermediateImage));
|
||||
dispatch(
|
||||
addLogEntry(
|
||||
`Intermediate image saved: ${intermediateImage.url}`
|
||||
)
|
||||
);
|
||||
|
||||
dispatch(clearIntermediateImage());
|
||||
}
|
||||
dispatch(addLogEntry(`Processing canceled`));
|
||||
}
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
// UPLOAD INITIAL IMAGE
|
||||
case 'socketio/uploadInitialImage': {
|
||||
const file = action.payload;
|
||||
|
||||
socketio.emit(
|
||||
'uploadInitialImage',
|
||||
file,
|
||||
file.name,
|
||||
(response: SocketIOResponse) => {
|
||||
if (response.status === 'OK') {
|
||||
dispatch(setInitialImagePath(response.data));
|
||||
dispatch(
|
||||
addLogEntry(
|
||||
`Initial image uploaded: ${response.data}`
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// UPLOAD MASK IMAGE
|
||||
case 'socketio/uploadMaskImage': {
|
||||
const file = action.payload;
|
||||
|
||||
socketio.emit(
|
||||
'uploadMaskImage',
|
||||
file,
|
||||
file.name,
|
||||
(response: SocketIOResponse) => {
|
||||
if (response.status === 'OK') {
|
||||
dispatch(setMaskPath(response.data));
|
||||
dispatch(
|
||||
addLogEntry(
|
||||
`Mask image uploaded: ${response.data}`
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
next(action);
|
||||
};
|
||||
|
||||
return middleware;
|
||||
};
|
||||
|
||||
// Actions to be used by app
|
||||
|
||||
export const generateImage = createAction<undefined>('socketio/generateImage');
|
||||
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
|
||||
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
|
||||
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
|
||||
export const requestAllImages = createAction<undefined>(
|
||||
'socketio/requestAllImages'
|
||||
);
|
||||
export const cancelProcessing = createAction<undefined>(
|
||||
'socketio/cancelProcessing'
|
||||
);
|
||||
export const uploadInitialImage = createAction<File>(
|
||||
'socketio/uploadInitialImage'
|
||||
);
|
||||
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');
|
24
frontend/src/app/socketio/actions.ts
Normal file
@ -0,0 +1,24 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { SDImage } from '../../features/gallery/gallerySlice';
|
||||
|
||||
/**
|
||||
* We can't use redux-toolkit's createSlice() to make these actions,
|
||||
* because they have no associated reducer. They only exist to dispatch
|
||||
* requests to the server via socketio. These actions will be handled
|
||||
* by the middleware.
|
||||
*/
|
||||
|
||||
export const generateImage = createAction<undefined>('socketio/generateImage');
|
||||
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
|
||||
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
|
||||
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
|
||||
export const requestAllImages = createAction<undefined>(
|
||||
'socketio/requestAllImages'
|
||||
);
|
||||
export const cancelProcessing = createAction<undefined>(
|
||||
'socketio/cancelProcessing'
|
||||
);
|
||||
export const uploadInitialImage = createAction<File>(
|
||||
'socketio/uploadInitialImage'
|
||||
);
|
||||
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');
|
101
frontend/src/app/socketio/emitters.ts
Normal file
@ -0,0 +1,101 @@
|
||||
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import dateFormat from 'dateformat';
|
||||
import { Socket } from 'socket.io-client';
|
||||
import { frontendToBackendParameters } from '../../common/util/parameterTranslation';
|
||||
import { SDImage } from '../../features/gallery/gallerySlice';
|
||||
import {
|
||||
addLogEntry,
|
||||
setIsProcessing,
|
||||
} from '../../features/system/systemSlice';
|
||||
|
||||
/**
|
||||
* Returns an object containing all functions which use `socketio.emit()`.
|
||||
* i.e. those which make server requests.
|
||||
*/
|
||||
const makeSocketIOEmitters = (
|
||||
store: MiddlewareAPI<Dispatch<AnyAction>, any>,
|
||||
socketio: Socket
|
||||
) => {
|
||||
// We need to dispatch actions to redux and get pieces of state from the store.
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
return {
|
||||
emitGenerateImage: () => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const { generationParameters, esrganParameters, gfpganParameters } =
|
||||
frontendToBackendParameters(getState().sd, getState().system);
|
||||
|
||||
socketio.emit(
|
||||
'generateImage',
|
||||
generationParameters,
|
||||
esrganParameters,
|
||||
gfpganParameters
|
||||
);
|
||||
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Image generation requested: ${JSON.stringify({
|
||||
...generationParameters,
|
||||
...esrganParameters,
|
||||
...gfpganParameters,
|
||||
})}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunESRGAN: (imageToProcess: SDImage) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
const { upscalingLevel, upscalingStrength } = getState().sd;
|
||||
const esrganParameters = {
|
||||
upscale: [upscalingLevel, upscalingStrength],
|
||||
};
|
||||
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `ESRGAN upscale requested: ${JSON.stringify({
|
||||
file: imageToProcess.url,
|
||||
...esrganParameters,
|
||||
})}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunGFPGAN: (imageToProcess: SDImage) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
const { gfpganStrength } = getState().sd;
|
||||
|
||||
const gfpganParameters = {
|
||||
gfpgan_strength: gfpganStrength,
|
||||
};
|
||||
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `GFPGAN fix faces requested: ${JSON.stringify({
|
||||
file: imageToProcess.url,
|
||||
...gfpganParameters,
|
||||
})}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
emitDeleteImage: (imageToDelete: SDImage) => {
|
||||
const { url, uuid } = imageToDelete;
|
||||
socketio.emit('deleteImage', url, uuid);
|
||||
},
|
||||
emitRequestAllImages: () => {
|
||||
socketio.emit('requestAllImages');
|
||||
},
|
||||
emitCancelProcessing: () => {
|
||||
socketio.emit('cancel');
|
||||
},
|
||||
emitUploadInitialImage: (file: File) => {
|
||||
socketio.emit('uploadInitialImage', file, file.name);
|
||||
},
|
||||
emitUploadMaskImage: (file: File) => {
|
||||
socketio.emit('uploadMaskImage', file, file.name);
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
export default makeSocketIOEmitters;
|
338
frontend/src/app/socketio/listeners.ts
Normal file
@ -0,0 +1,338 @@
|
||||
import { AnyAction, MiddlewareAPI, Dispatch } from '@reduxjs/toolkit';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import dateFormat from 'dateformat';
|
||||
|
||||
import {
|
||||
addLogEntry,
|
||||
setIsConnected,
|
||||
setIsProcessing,
|
||||
SystemStatus,
|
||||
setSystemStatus,
|
||||
setCurrentStatus,
|
||||
} from '../../features/system/systemSlice';
|
||||
|
||||
import type {
|
||||
ServerGenerationResult,
|
||||
ServerESRGANResult,
|
||||
ServerGFPGANResult,
|
||||
ServerIntermediateResult,
|
||||
ServerError,
|
||||
ServerGalleryImages,
|
||||
ServerImageUrlAndUuid,
|
||||
ServerImageUrl,
|
||||
} from './types';
|
||||
|
||||
import { backendToFrontendParameters } from '../../common/util/parameterTranslation';
|
||||
|
||||
import {
|
||||
addImage,
|
||||
clearIntermediateImage,
|
||||
removeImage,
|
||||
SDImage,
|
||||
setGalleryImages,
|
||||
setIntermediateImage,
|
||||
} from '../../features/gallery/gallerySlice';
|
||||
|
||||
import { setInitialImagePath, setMaskPath } from '../../features/sd/sdSlice';
|
||||
|
||||
/**
|
||||
* Returns an object containing listener callbacks for socketio events.
|
||||
* TODO: This file is large, but simple. Should it be split up further?
|
||||
*/
|
||||
const makeSocketIOListeners = (
|
||||
store: MiddlewareAPI<Dispatch<AnyAction>, any>
|
||||
) => {
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
return {
|
||||
/**
|
||||
* Callback to run when we receive a 'connect' event.
|
||||
*/
|
||||
onConnect: () => {
|
||||
try {
|
||||
dispatch(setIsConnected(true));
|
||||
dispatch(setCurrentStatus('Connected'));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'disconnect' event.
|
||||
*/
|
||||
onDisconnect: () => {
|
||||
try {
|
||||
dispatch(setIsConnected(false));
|
||||
dispatch(setIsProcessing(false));
|
||||
dispatch(setCurrentStatus('Disconnected'));
|
||||
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Disconnected from server`,
|
||||
level: 'warning',
|
||||
})
|
||||
);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'generationResult' event.
|
||||
*/
|
||||
onGenerationResult: (data: ServerGenerationResult) => {
|
||||
try {
|
||||
const { url, metadata } = data;
|
||||
const newUuid = uuidv4();
|
||||
|
||||
const translatedMetadata = backendToFrontendParameters(metadata);
|
||||
|
||||
dispatch(
|
||||
addImage({
|
||||
uuid: newUuid,
|
||||
url,
|
||||
metadata: translatedMetadata,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Image generated: ${url}`,
|
||||
})
|
||||
);
|
||||
dispatch(setIsProcessing(false));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'intermediateResult' event.
|
||||
*/
|
||||
onIntermediateResult: (data: ServerIntermediateResult) => {
|
||||
try {
|
||||
const uuid = uuidv4();
|
||||
const { url, metadata } = data;
|
||||
dispatch(
|
||||
setIntermediateImage({
|
||||
uuid,
|
||||
url,
|
||||
metadata,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Intermediate image generated: ${url}`,
|
||||
})
|
||||
);
|
||||
dispatch(setIsProcessing(false));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive an 'esrganResult' event.
|
||||
*/
|
||||
onESRGANResult: (data: ServerESRGANResult) => {
|
||||
try {
|
||||
const { url, uuid, metadata } = data;
|
||||
const newUuid = uuidv4();
|
||||
|
||||
// This image was only ESRGAN'd, grab the original image's metadata
|
||||
const originalImage = getState().gallery.images.find(
|
||||
(i: SDImage) => i.uuid === uuid
|
||||
);
|
||||
|
||||
// Retain the original metadata
|
||||
const newMetadata = {
|
||||
...originalImage.metadata,
|
||||
};
|
||||
|
||||
// Update the ESRGAN-related fields
|
||||
newMetadata.shouldRunESRGAN = true;
|
||||
newMetadata.upscalingLevel = metadata.upscale[0];
|
||||
newMetadata.upscalingStrength = metadata.upscale[1];
|
||||
|
||||
dispatch(
|
||||
addImage({
|
||||
uuid: newUuid,
|
||||
url,
|
||||
metadata: newMetadata,
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Upscaled: ${url}`,
|
||||
})
|
||||
);
|
||||
dispatch(setIsProcessing(false));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'gfpganResult' event.
|
||||
*/
|
||||
onGFPGANResult: (data: ServerGFPGANResult) => {
|
||||
try {
|
||||
const { url, uuid, metadata } = data;
|
||||
const newUuid = uuidv4();
|
||||
|
||||
// This image was only GFPGAN'd, grab the original image's metadata
|
||||
const originalImage = getState().gallery.images.find(
|
||||
(i: SDImage) => i.uuid === uuid
|
||||
);
|
||||
|
||||
// Retain the original metadata
|
||||
const newMetadata = {
|
||||
...originalImage.metadata,
|
||||
};
|
||||
|
||||
// Update the GFPGAN-related fields
|
||||
newMetadata.shouldRunGFPGAN = true;
|
||||
newMetadata.gfpganStrength = metadata.gfpgan_strength;
|
||||
|
||||
dispatch(
|
||||
addImage({
|
||||
uuid: newUuid,
|
||||
url,
|
||||
metadata: newMetadata,
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Fixed faces: ${url}`,
|
||||
})
|
||||
);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'progressUpdate' event.
|
||||
* TODO: Add additional progress phases
|
||||
*/
|
||||
onProgressUpdate: (data: SystemStatus) => {
|
||||
try {
|
||||
dispatch(setIsProcessing(true));
|
||||
dispatch(setSystemStatus(data));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'progressUpdate' event.
|
||||
*/
|
||||
onError: (data: ServerError) => {
|
||||
const { message, additionalData } = data;
|
||||
|
||||
if (additionalData) {
|
||||
// TODO: handle more data than short message
|
||||
}
|
||||
|
||||
try {
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Server error: ${message}`,
|
||||
level: 'error',
|
||||
})
|
||||
);
|
||||
dispatch(setIsProcessing(false));
|
||||
dispatch(clearIntermediateImage());
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'galleryImages' event.
|
||||
*/
|
||||
onGalleryImages: (data: ServerGalleryImages) => {
|
||||
const { images } = data;
|
||||
const preparedImages = images.map((image): SDImage => {
|
||||
return {
|
||||
uuid: uuidv4(),
|
||||
url: image.path,
|
||||
metadata: backendToFrontendParameters(image.metadata),
|
||||
};
|
||||
});
|
||||
dispatch(setGalleryImages(preparedImages));
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Loaded ${images.length} images`,
|
||||
})
|
||||
);
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'processingCanceled' event.
|
||||
*/
|
||||
onProcessingCanceled: () => {
|
||||
dispatch(setIsProcessing(false));
|
||||
|
||||
const { intermediateImage } = getState().gallery;
|
||||
|
||||
if (intermediateImage) {
|
||||
dispatch(addImage(intermediateImage));
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Intermediate image saved: ${intermediateImage.url}`,
|
||||
})
|
||||
);
|
||||
dispatch(clearIntermediateImage());
|
||||
}
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Processing canceled`,
|
||||
level: 'warning',
|
||||
})
|
||||
);
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'imageDeleted' event.
|
||||
*/
|
||||
onImageDeleted: (data: ServerImageUrlAndUuid) => {
|
||||
const { url, uuid } = data;
|
||||
dispatch(removeImage(uuid));
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Image deleted: ${url}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'initialImageUploaded' event.
|
||||
*/
|
||||
onInitialImageUploaded: (data: ServerImageUrl) => {
|
||||
const { url } = data;
|
||||
dispatch(setInitialImagePath(url));
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Initial image uploaded: ${url}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'maskImageUploaded' event.
|
||||
*/
|
||||
onMaskImageUploaded: (data: ServerImageUrl) => {
|
||||
const { url } = data;
|
||||
dispatch(setMaskPath(url));
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Mask image uploaded: ${url}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
export default makeSocketIOListeners;
|
157
frontend/src/app/socketio/middleware.ts
Normal file
@ -0,0 +1,157 @@
|
||||
import { Middleware } from '@reduxjs/toolkit';
|
||||
import { io } from 'socket.io-client';
|
||||
|
||||
import makeSocketIOListeners from './listeners';
|
||||
import makeSocketIOEmitters from './emitters';
|
||||
|
||||
import type {
|
||||
ServerGenerationResult,
|
||||
ServerESRGANResult,
|
||||
ServerGFPGANResult,
|
||||
ServerIntermediateResult,
|
||||
ServerError,
|
||||
ServerGalleryImages,
|
||||
ServerImageUrlAndUuid,
|
||||
ServerImageUrl,
|
||||
} from './types';
|
||||
import { SystemStatus } from '../../features/system/systemSlice';
|
||||
|
||||
export const socketioMiddleware = () => {
|
||||
const { hostname, port } = new URL(window.location.href);
|
||||
|
||||
const socketio = io(`http://${hostname}:9090`);
|
||||
|
||||
let areListenersSet = false;
|
||||
|
||||
const middleware: Middleware = (store) => (next) => (action) => {
|
||||
const {
|
||||
onConnect,
|
||||
onDisconnect,
|
||||
onError,
|
||||
onESRGANResult,
|
||||
onGFPGANResult,
|
||||
onGenerationResult,
|
||||
onIntermediateResult,
|
||||
onProgressUpdate,
|
||||
onGalleryImages,
|
||||
onProcessingCanceled,
|
||||
onImageDeleted,
|
||||
onInitialImageUploaded,
|
||||
onMaskImageUploaded,
|
||||
} = makeSocketIOListeners(store);
|
||||
|
||||
const {
|
||||
emitGenerateImage,
|
||||
emitRunESRGAN,
|
||||
emitRunGFPGAN,
|
||||
emitDeleteImage,
|
||||
emitRequestAllImages,
|
||||
emitCancelProcessing,
|
||||
emitUploadInitialImage,
|
||||
emitUploadMaskImage,
|
||||
} = makeSocketIOEmitters(store, socketio);
|
||||
|
||||
/**
|
||||
* If this is the first time the middleware has been called (e.g. during store setup),
|
||||
* initialize all our socket.io listeners.
|
||||
*/
|
||||
if (!areListenersSet) {
|
||||
socketio.on('connect', () => onConnect());
|
||||
|
||||
socketio.on('disconnect', () => onDisconnect());
|
||||
|
||||
socketio.on('error', (data: ServerError) => onError(data));
|
||||
|
||||
socketio.on('generationResult', (data: ServerGenerationResult) =>
|
||||
onGenerationResult(data)
|
||||
);
|
||||
|
||||
socketio.on('esrganResult', (data: ServerESRGANResult) =>
|
||||
onESRGANResult(data)
|
||||
);
|
||||
|
||||
socketio.on('gfpganResult', (data: ServerGFPGANResult) =>
|
||||
onGFPGANResult(data)
|
||||
);
|
||||
|
||||
socketio.on('intermediateResult', (data: ServerIntermediateResult) =>
|
||||
onIntermediateResult(data)
|
||||
);
|
||||
|
||||
socketio.on('progressUpdate', (data: SystemStatus) =>
|
||||
onProgressUpdate(data)
|
||||
);
|
||||
|
||||
socketio.on('galleryImages', (data: ServerGalleryImages) =>
|
||||
onGalleryImages(data)
|
||||
);
|
||||
|
||||
socketio.on('processingCanceled', () => {
|
||||
onProcessingCanceled();
|
||||
});
|
||||
|
||||
socketio.on('imageDeleted', (data: ServerImageUrlAndUuid) => {
|
||||
onImageDeleted(data);
|
||||
});
|
||||
|
||||
socketio.on('initialImageUploaded', (data: ServerImageUrl) => {
|
||||
onInitialImageUploaded(data);
|
||||
});
|
||||
|
||||
socketio.on('maskImageUploaded', (data: ServerImageUrl) => {
|
||||
onMaskImageUploaded(data);
|
||||
});
|
||||
|
||||
areListenersSet = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle redux actions caught by middleware.
|
||||
*/
|
||||
switch (action.type) {
|
||||
case 'socketio/generateImage': {
|
||||
emitGenerateImage();
|
||||
break;
|
||||
}
|
||||
|
||||
case 'socketio/runESRGAN': {
|
||||
emitRunESRGAN(action.payload);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'socketio/runGFPGAN': {
|
||||
emitRunGFPGAN(action.payload);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'socketio/deleteImage': {
|
||||
emitDeleteImage(action.payload);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'socketio/requestAllImages': {
|
||||
emitRequestAllImages();
|
||||
break;
|
||||
}
|
||||
|
||||
case 'socketio/cancelProcessing': {
|
||||
emitCancelProcessing();
|
||||
break;
|
||||
}
|
||||
|
||||
case 'socketio/uploadInitialImage': {
|
||||
emitUploadInitialImage(action.payload);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'socketio/uploadMaskImage': {
|
||||
emitUploadMaskImage(action.payload);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
next(action);
|
||||
};
|
||||
|
||||
return middleware;
|
||||
};
|
46
frontend/src/app/socketio/types.d.ts
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Interfaces used by the socketio middleware.
|
||||
*/
|
||||
|
||||
export declare interface ServerGenerationResult {
|
||||
url: string;
|
||||
metadata: { [key: string]: any };
|
||||
}
|
||||
|
||||
export declare interface ServerESRGANResult {
|
||||
url: string;
|
||||
uuid: string;
|
||||
metadata: { [key: string]: any };
|
||||
}
|
||||
|
||||
export declare interface ServerGFPGANResult {
|
||||
url: string;
|
||||
uuid: string;
|
||||
metadata: { [key: string]: any };
|
||||
}
|
||||
|
||||
export declare interface ServerIntermediateResult {
|
||||
url: string;
|
||||
metadata: { [key: string]: any };
|
||||
}
|
||||
|
||||
export declare interface ServerError {
|
||||
message: string;
|
||||
additionalData?: string;
|
||||
}
|
||||
|
||||
export declare interface ServerGalleryImages {
|
||||
images: Array<{
|
||||
path: string;
|
||||
metadata: { [key: string]: any };
|
||||
}>;
|
||||
}
|
||||
|
||||
export declare interface ServerImageUrlAndUuid {
|
||||
uuid: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export declare interface ServerImageUrl {
|
||||
url: string;
|
||||
}
|
@ -1,53 +1,78 @@
|
||||
import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import type { TypedUseSelectorHook } from 'react-redux';
|
||||
|
||||
import { persistReducer } from 'redux-persist';
|
||||
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
||||
|
||||
import sdReducer from '../features/sd/sdSlice';
|
||||
import galleryReducer from '../features/gallery/gallerySlice';
|
||||
import systemReducer from '../features/system/systemSlice';
|
||||
import { socketioMiddleware } from './socketio';
|
||||
import { socketioMiddleware } from './socketio/middleware';
|
||||
|
||||
/**
|
||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||
*
|
||||
* While we definitely want generation parameters to be persisted, there are a number
|
||||
* of things we do *not* want to be persisted across reloads:
|
||||
* - Gallery/selected image (user may add/delete images from disk between page loads)
|
||||
* - Connection/processing status
|
||||
* - Availability of external libraries like ESRGAN/GFPGAN
|
||||
*
|
||||
* These can be blacklisted in redux-persist.
|
||||
*
|
||||
* The necesssary nested persistors with blacklists are configured below.
|
||||
*
|
||||
* TODO: Do we blacklist initialImagePath? If the image is deleted from disk we get an
|
||||
* ugly 404. But if we blacklist it, then this is a valuable parameter that is lost
|
||||
* on reload. Need to figure out a good way to handle this.
|
||||
*/
|
||||
|
||||
const rootPersistConfig = {
|
||||
key: 'root',
|
||||
storage,
|
||||
blacklist: ['gallery', 'system'],
|
||||
};
|
||||
|
||||
const systemPersistConfig = {
|
||||
key: 'system',
|
||||
storage,
|
||||
blacklist: [
|
||||
'isConnected',
|
||||
'isProcessing',
|
||||
'currentStep',
|
||||
'socketId',
|
||||
'isESRGANAvailable',
|
||||
'isGFPGANAvailable',
|
||||
'currentStep',
|
||||
'totalSteps',
|
||||
'currentIteration',
|
||||
'totalIterations',
|
||||
'currentStatus',
|
||||
],
|
||||
};
|
||||
|
||||
const reducers = combineReducers({
|
||||
sd: sdReducer,
|
||||
gallery: galleryReducer,
|
||||
system: systemReducer,
|
||||
system: persistReducer(systemPersistConfig, systemReducer),
|
||||
});
|
||||
|
||||
const persistConfig = {
|
||||
key: 'root',
|
||||
storage,
|
||||
};
|
||||
|
||||
const persistedReducer = persistReducer(persistConfig, reducers);
|
||||
|
||||
/*
|
||||
The frontend needs to be distributed as a production build, so
|
||||
we cannot reasonably ask users to edit the JS and specify the
|
||||
host and port on which the socket.io server will run.
|
||||
|
||||
The solution is to allow server script to be run with arguments
|
||||
(or just edited) providing the host and port. Then, the server
|
||||
serves a route `/socketio_config` which responds with the host
|
||||
and port.
|
||||
|
||||
When the frontend loads, it synchronously requests that route
|
||||
and thus gets the host and port. This requires a suspicious
|
||||
fetch somewhere, and the store setup seems like as good a place
|
||||
as any to make this fetch request.
|
||||
*/
|
||||
|
||||
const persistedReducer = persistReducer(rootPersistConfig, reducers);
|
||||
|
||||
// Continue with store setup
|
||||
export const store = configureStore({
|
||||
reducer: persistedReducer,
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
// redux-persist sometimes needs to have a function in redux, need to disable this check
|
||||
// redux-persist sometimes needs to temporarily put a function in redux state, need to disable this check
|
||||
serializableCheck: false,
|
||||
}).concat(socketioMiddleware()),
|
||||
});
|
||||
|
||||
// Infer the `RootState` and `AppDispatch` types from the store itself
|
||||
export type RootState = ReturnType<typeof store.getState>;
|
||||
// Inferred type: {posts: PostsState, comments: CommentsState, users: UsersState}
|
||||
export type AppDispatch = typeof store.dispatch;
|
||||
|
||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||
export const useAppDispatch: () => AppDispatch = useDispatch;
|
||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
||||
|
171
frontend/src/app/wip_types.ts
Normal file
@ -0,0 +1,171 @@
|
||||
|
||||
/**
|
||||
* Defines common parameters required to generate an image.
|
||||
* See #266 for the eventual maturation of this interface.
|
||||
*/
|
||||
interface CommonParameters {
|
||||
/**
|
||||
* The "txt2img" prompt. String. Minimum one character. No maximum.
|
||||
*/
|
||||
prompt: string;
|
||||
/**
|
||||
* The number of sampler steps. Integer. Minimum value 1. No maximum.
|
||||
*/
|
||||
steps: number;
|
||||
/**
|
||||
* Classifier-free guidance scale. Float. Minimum value 0. Maximum?
|
||||
*/
|
||||
cfgScale: number;
|
||||
/**
|
||||
* Height of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
|
||||
*/
|
||||
height: number;
|
||||
/**
|
||||
* Width of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
|
||||
*/
|
||||
width: number;
|
||||
/**
|
||||
* Name of the sampler to use. String. Restricted values.
|
||||
*/
|
||||
sampler:
|
||||
| 'ddim'
|
||||
| 'plms'
|
||||
| 'k_lms'
|
||||
| 'k_dpm_2'
|
||||
| 'k_dpm_2_a'
|
||||
| 'k_euler'
|
||||
| 'k_euler_a'
|
||||
| 'k_heun';
|
||||
/**
|
||||
* Seed used for randomness. Integer. 0 --> 4294967295, inclusive.
|
||||
*/
|
||||
seed: number;
|
||||
/**
|
||||
* Flag to enable seamless tiling image generation. Boolean.
|
||||
*/
|
||||
seamless: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines parameters needed to use the "img2img" generation method.
|
||||
*/
|
||||
interface ImageToImageParameters {
|
||||
/**
|
||||
* Folder path to the image used as the initial image. String.
|
||||
*/
|
||||
initialImagePath: string;
|
||||
/**
|
||||
* Flag to enable the use of a mask image during "img2img" generations.
|
||||
* Requires valid ImageToImageParameters. Boolean.
|
||||
*/
|
||||
shouldUseMaskImage: boolean;
|
||||
/**
|
||||
* Folder path to the image used as a mask image. String.
|
||||
*/
|
||||
maskImagePath: string;
|
||||
/**
|
||||
* Strength of adherance to initial image. Float. 0 --> 1, exclusive.
|
||||
*/
|
||||
img2imgStrength: number;
|
||||
/**
|
||||
* Flag to enable the stretching of init image to desired output. Boolean.
|
||||
*/
|
||||
shouldFit: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines the parameters needed to generate variations.
|
||||
*/
|
||||
interface VariationParameters {
|
||||
/**
|
||||
* Variation amount. Float. 0 --> 1, exclusive.
|
||||
* TODO: What does this really do?
|
||||
*/
|
||||
variationAmount: number;
|
||||
/**
|
||||
* List of seed-weight pairs formatted as "seed:weight,...".
|
||||
* Seed is a valid seed. Weight is a float, 0 --> 1, exclusive.
|
||||
* String, must be parseable into [[seed,weight],...] format.
|
||||
*/
|
||||
seedWeights: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines the parameters needed to use GFPGAN postprocessing.
|
||||
*/
|
||||
interface GFPGANParameters {
|
||||
/**
|
||||
* GFPGAN strength. Strength to apply face-fixing processing. Float. 0 --> 1, exclusive.
|
||||
*/
|
||||
gfpganStrength: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines the parameters needed to use ESRGAN postprocessing.
|
||||
*/
|
||||
interface ESRGANParameters {
|
||||
/**
|
||||
* ESRGAN strength. Strength to apply upscaling. Float. 0 --> 1, exclusive.
|
||||
*/
|
||||
esrganStrength: number;
|
||||
/**
|
||||
* ESRGAN upscaling scale. One of 2x | 4x. Represented as integer.
|
||||
*/
|
||||
esrganScale: 2 | 4;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extends the generation and processing method parameters, adding flags to enable each.
|
||||
*/
|
||||
interface ProcessingParameters extends CommonParameters {
|
||||
/**
|
||||
* Flag to enable the generation of variations. Requires valid VariationParameters. Boolean.
|
||||
*/
|
||||
shouldGenerateVariations: boolean;
|
||||
/**
|
||||
* Variation parameters.
|
||||
*/
|
||||
variationParameters: VariationParameters;
|
||||
/**
|
||||
* Flag to enable the use of an initial image, i.e. to use "img2img" generation.
|
||||
* Requires valid ImageToImageParameters. Boolean.
|
||||
*/
|
||||
shouldUseImageToImage: boolean;
|
||||
/**
|
||||
* ImageToImage parameters.
|
||||
*/
|
||||
imageToImageParameters: ImageToImageParameters;
|
||||
/**
|
||||
* Flag to enable GFPGAN postprocessing. Requires valid GFPGANParameters. Boolean.
|
||||
*/
|
||||
shouldRunGFPGAN: boolean;
|
||||
/**
|
||||
* GFPGAN parameters.
|
||||
*/
|
||||
gfpganParameters: GFPGANParameters;
|
||||
/**
|
||||
* Flag to enable ESRGAN postprocessing. Requires valid ESRGANParameters. Boolean.
|
||||
*/
|
||||
shouldRunESRGAN: boolean;
|
||||
/**
|
||||
* ESRGAN parameters.
|
||||
*/
|
||||
esrganParameters: GFPGANParameters;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extends ProcessingParameters, adding items needed to request processing.
|
||||
*/
|
||||
interface ProcessingState extends ProcessingParameters {
|
||||
/**
|
||||
* Number of images to generate. Integer. Minimum 1.
|
||||
*/
|
||||
iterations: number;
|
||||
/**
|
||||
* Flag to enable the randomization of the seed on each generation. Boolean.
|
||||
*/
|
||||
shouldRandomizeSeed: boolean;
|
||||
}
|
||||
|
||||
|
||||
export {}
|
21
frontend/src/common/components/SDButton.tsx
Normal file
@ -0,0 +1,21 @@
|
||||
import { Button, ButtonProps } from '@chakra-ui/react';
|
||||
|
||||
interface Props extends ButtonProps {
|
||||
label: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reusable customized button component. Originally was more customized - now probably unecessary.
|
||||
*
|
||||
* TODO: Get rid of this.
|
||||
*/
|
||||
const SDButton = (props: Props) => {
|
||||
const { label, size = 'sm', ...rest } = props;
|
||||
return (
|
||||
<Button size={size} {...rest}>
|
||||
{label}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
export default SDButton;
|
@ -16,6 +16,9 @@ interface Props extends NumberInputProps {
|
||||
width?: string | number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Customized Chakra FormControl + NumberInput multi-part component.
|
||||
*/
|
||||
const SDNumberInput = (props: Props) => {
|
||||
const {
|
||||
label,
|
||||
@ -31,7 +34,7 @@ const SDNumberInput = (props: Props) => {
|
||||
<Flex gap={2} justifyContent={'space-between'} alignItems={'center'}>
|
||||
{label && (
|
||||
<FormLabel marginBottom={1}>
|
||||
<Text fontSize={fontSize} whiteSpace='nowrap'>
|
||||
<Text fontSize={fontSize} whiteSpace="nowrap">
|
||||
{label}
|
||||
</Text>
|
||||
</FormLabel>
|
||||
@ -42,7 +45,7 @@ const SDNumberInput = (props: Props) => {
|
||||
keepWithinRange={false}
|
||||
clampValueOnBlur={true}
|
||||
>
|
||||
<NumberInputField fontSize={'md'}/>
|
||||
<NumberInputField fontSize={'md'} />
|
||||
<NumberInputStepper>
|
||||
<NumberIncrementStepper />
|
||||
<NumberDecrementStepper />
|
56
frontend/src/common/components/SDSelect.tsx
Normal file
@ -0,0 +1,56 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Select,
|
||||
SelectProps,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
interface Props extends SelectProps {
|
||||
label: string;
|
||||
validValues:
|
||||
| Array<number | string>
|
||||
| Array<{ key: string; value: string | number }>;
|
||||
}
|
||||
/**
|
||||
* Customized Chakra FormControl + Select multi-part component.
|
||||
*/
|
||||
const SDSelect = (props: Props) => {
|
||||
const {
|
||||
label,
|
||||
isDisabled,
|
||||
validValues,
|
||||
size = 'sm',
|
||||
fontSize = 'md',
|
||||
marginBottom = 1,
|
||||
whiteSpace = 'nowrap',
|
||||
...rest
|
||||
} = props;
|
||||
return (
|
||||
<FormControl isDisabled={isDisabled}>
|
||||
<Flex justifyContent={'space-between'} alignItems={'center'}>
|
||||
<FormLabel marginBottom={marginBottom}>
|
||||
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
|
||||
{label}
|
||||
</Text>
|
||||
</FormLabel>
|
||||
<Select fontSize={fontSize} size={size} {...rest}>
|
||||
{validValues.map((opt) => {
|
||||
return typeof opt === 'string' || typeof opt === 'number' ? (
|
||||
<option key={opt} value={opt}>
|
||||
{opt}
|
||||
</option>
|
||||
) : (
|
||||
<option key={opt.value} value={opt.value}>
|
||||
{opt.key}
|
||||
</option>
|
||||
);
|
||||
})}
|
||||
</Select>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default SDSelect;
|
@ -11,6 +11,9 @@ interface Props extends SwitchProps {
|
||||
width?: string | number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Customized Chakra FormControl + Switch multi-part component.
|
||||
*/
|
||||
const SDSwitch = (props: Props) => {
|
||||
const {
|
||||
label,
|
||||
@ -28,7 +31,7 @@ const SDSwitch = (props: Props) => {
|
||||
fontSize={fontSize}
|
||||
marginBottom={1}
|
||||
flexGrow={2}
|
||||
whiteSpace='nowrap'
|
||||
whiteSpace="nowrap"
|
||||
>
|
||||
{label}
|
||||
</FormLabel>
|
104
frontend/src/common/hooks/useCheckParameters.ts
Normal file
@ -0,0 +1,104 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { useMemo } from 'react';
|
||||
import { useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import { SDState } from '../../features/sd/sdSlice';
|
||||
import { SystemState } from '../../features/system/systemSlice';
|
||||
import { validateSeedWeights } from '../util/seedWeightPairs';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
prompt: sd.prompt,
|
||||
shouldGenerateVariations: sd.shouldGenerateVariations,
|
||||
seedWeights: sd.seedWeights,
|
||||
maskPath: sd.maskPath,
|
||||
initialImagePath: sd.initialImagePath,
|
||||
seed: sd.seed,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isProcessing: system.isProcessing,
|
||||
isConnected: system.isConnected,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Checks relevant pieces of state to confirm generation will not deterministically fail.
|
||||
* This is used to prevent the 'Generate' button from being clicked.
|
||||
*/
|
||||
const useCheckParameters = (): boolean => {
|
||||
const {
|
||||
prompt,
|
||||
shouldGenerateVariations,
|
||||
seedWeights,
|
||||
maskPath,
|
||||
initialImagePath,
|
||||
seed,
|
||||
} = useAppSelector(sdSelector);
|
||||
|
||||
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
||||
|
||||
return useMemo(() => {
|
||||
// Cannot generate without a prompt
|
||||
if (!prompt) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Cannot generate with a mask without img2img
|
||||
if (maskPath && !initialImagePath) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: job queue
|
||||
// Cannot generate if already processing an image
|
||||
if (isProcessing) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Cannot generate if not connected
|
||||
if (!isConnected) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Cannot generate variations without valid seed weights
|
||||
if (
|
||||
shouldGenerateVariations &&
|
||||
(!(validateSeedWeights(seedWeights) || seedWeights === '') || seed === -1)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// All good
|
||||
return true;
|
||||
}, [
|
||||
prompt,
|
||||
maskPath,
|
||||
initialImagePath,
|
||||
isProcessing,
|
||||
isConnected,
|
||||
shouldGenerateVariations,
|
||||
seedWeights,
|
||||
seed,
|
||||
]);
|
||||
};
|
||||
|
||||
export default useCheckParameters;
|
@ -1,17 +1,15 @@
|
||||
import { SDState } from '../features/sd/sdSlice';
|
||||
import randomInt from '../features/sd/util/randomInt';
|
||||
import {
|
||||
seedWeightsToString,
|
||||
stringToSeedWeights,
|
||||
} from '../features/sd/util/seedWeightPairs';
|
||||
import { SystemState } from '../features/system/systemSlice';
|
||||
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from './constants';
|
||||
|
||||
/*
|
||||
These functions translate frontend state into parameters
|
||||
suitable for consumption by the backend, and vice-versa.
|
||||
*/
|
||||
|
||||
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from "../../app/constants";
|
||||
import { SDState } from "../../features/sd/sdSlice";
|
||||
import { SystemState } from "../../features/system/systemSlice";
|
||||
import randomInt from "./randomInt";
|
||||
import { seedWeightsToString, stringToSeedWeights } from "./seedWeightPairs";
|
||||
|
||||
export const frontendToBackendParameters = (
|
||||
sdState: SDState,
|
||||
systemState: SystemState
|
||||
@ -32,7 +30,7 @@ export const frontendToBackendParameters = (
|
||||
maskPath,
|
||||
shouldFitToWidthHeight,
|
||||
shouldGenerateVariations,
|
||||
variantAmount,
|
||||
variationAmount,
|
||||
seedWeights,
|
||||
shouldRunESRGAN,
|
||||
upscalingLevel,
|
||||
@ -71,7 +69,7 @@ export const frontendToBackendParameters = (
|
||||
}
|
||||
|
||||
if (shouldGenerateVariations) {
|
||||
generationParameters.variation_amount = variantAmount;
|
||||
generationParameters.variation_amount = variationAmount;
|
||||
if (seedWeights) {
|
||||
generationParameters.with_variations =
|
||||
stringToSeedWeights(seedWeights);
|
||||
@ -138,7 +136,7 @@ export const backendToFrontendParameters = (parameters: {
|
||||
|
||||
if (variation_amount > 0) {
|
||||
sd.shouldGenerateVariations = true;
|
||||
sd.variantAmount = variation_amount;
|
||||
sd.variationAmount = variation_amount;
|
||||
if (with_variations) {
|
||||
sd.seedWeights = seedWeightsToString(with_variations);
|
||||
}
|
@ -1,16 +0,0 @@
|
||||
import { Button, ButtonProps } from '@chakra-ui/react';
|
||||
|
||||
interface Props extends ButtonProps {
|
||||
label: string;
|
||||
}
|
||||
|
||||
const SDButton = (props: Props) => {
|
||||
const { label, size = 'sm', ...rest } = props;
|
||||
return (
|
||||
<Button size={size} {...rest}>
|
||||
{label}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
export default SDButton;
|
@ -1,57 +0,0 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Select,
|
||||
SelectProps,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
interface Props extends SelectProps {
|
||||
label: string;
|
||||
validValues:
|
||||
| Array<number | string>
|
||||
| Array<{ key: string; value: string | number }>;
|
||||
}
|
||||
|
||||
const SDSelect = (props: Props) => {
|
||||
const {
|
||||
label,
|
||||
isDisabled,
|
||||
validValues,
|
||||
size = 'sm',
|
||||
fontSize = 'md',
|
||||
marginBottom = 1,
|
||||
whiteSpace = 'nowrap',
|
||||
...rest
|
||||
} = props;
|
||||
return (
|
||||
<FormControl isDisabled={isDisabled}>
|
||||
<Flex justifyContent={'space-between'} alignItems={'center'}>
|
||||
<FormLabel
|
||||
marginBottom={marginBottom}
|
||||
>
|
||||
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
|
||||
{label}
|
||||
</Text>
|
||||
</FormLabel>
|
||||
<Select fontSize={fontSize} size={size} {...rest}>
|
||||
{validValues.map((opt) => {
|
||||
return typeof opt === 'string' ||
|
||||
typeof opt === 'number' ? (
|
||||
<option key={opt} value={opt}>
|
||||
{opt}
|
||||
</option>
|
||||
) : (
|
||||
<option key={opt.value} value={opt.value}>
|
||||
{opt.key}
|
||||
</option>
|
||||
);
|
||||
})}
|
||||
</Select>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default SDSelect;
|
@ -1,161 +0,0 @@
|
||||
import { Center, Flex, Image, useColorModeValue } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { RootState } from '../../app/store';
|
||||
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
|
||||
import { useState } from 'react';
|
||||
import ImageMetadataViewer from './ImageMetadataViewer';
|
||||
import DeleteImageModalButton from './DeleteImageModalButton';
|
||||
import SDButton from '../../components/SDButton';
|
||||
import { runESRGAN, runGFPGAN } from '../../app/socketio';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { SystemState } from '../system/systemSlice';
|
||||
import { isEqual } from 'lodash';
|
||||
|
||||
const height = 'calc(100vh - 238px)';
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isProcessing: system.isProcessing,
|
||||
isConnected: system.isConnected,
|
||||
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||
isESRGANAvailable: system.isESRGANAvailable,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const CurrentImage = () => {
|
||||
const { currentImage, intermediateImage } = useAppSelector(
|
||||
(state: RootState) => state.gallery
|
||||
);
|
||||
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
|
||||
useAppSelector(systemSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const bgColor = useColorModeValue(
|
||||
'rgba(255, 255, 255, 0.85)',
|
||||
'rgba(0, 0, 0, 0.8)'
|
||||
);
|
||||
|
||||
const [shouldShowImageDetails, setShouldShowImageDetails] =
|
||||
useState<boolean>(false);
|
||||
|
||||
const imageToDisplay = intermediateImage || currentImage;
|
||||
|
||||
return (
|
||||
<Flex direction={'column'} rounded={'md'} borderWidth={1} p={2} gap={2}>
|
||||
{imageToDisplay && (
|
||||
<Flex gap={2}>
|
||||
<SDButton
|
||||
label='Use as initial image'
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
onClick={() =>
|
||||
dispatch(setInitialImagePath(imageToDisplay.url))
|
||||
}
|
||||
/>
|
||||
|
||||
<SDButton
|
||||
label='Use all'
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
onClick={() =>
|
||||
dispatch(setAllParameters(imageToDisplay.metadata))
|
||||
}
|
||||
/>
|
||||
|
||||
<SDButton
|
||||
label='Use seed'
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={!imageToDisplay.metadata.seed}
|
||||
onClick={() =>
|
||||
dispatch(setSeed(imageToDisplay.metadata.seed!))
|
||||
}
|
||||
/>
|
||||
|
||||
<SDButton
|
||||
label='Upscale'
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={
|
||||
!isESRGANAvailable ||
|
||||
Boolean(intermediateImage) ||
|
||||
!(isConnected && !isProcessing)
|
||||
}
|
||||
onClick={() => dispatch(runESRGAN(imageToDisplay))}
|
||||
/>
|
||||
<SDButton
|
||||
label='Fix faces'
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={
|
||||
!isGFPGANAvailable ||
|
||||
Boolean(intermediateImage) ||
|
||||
!(isConnected && !isProcessing)
|
||||
}
|
||||
onClick={() => dispatch(runGFPGAN(imageToDisplay))}
|
||||
/>
|
||||
<SDButton
|
||||
label='Details'
|
||||
colorScheme={'gray'}
|
||||
variant={shouldShowImageDetails ? 'solid' : 'outline'}
|
||||
borderWidth={1}
|
||||
flexGrow={1}
|
||||
onClick={() =>
|
||||
setShouldShowImageDetails(!shouldShowImageDetails)
|
||||
}
|
||||
/>
|
||||
<DeleteImageModalButton image={imageToDisplay}>
|
||||
<SDButton
|
||||
label='Delete'
|
||||
colorScheme={'red'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={Boolean(intermediateImage)}
|
||||
/>
|
||||
</DeleteImageModalButton>
|
||||
</Flex>
|
||||
)}
|
||||
<Center height={height} position={'relative'}>
|
||||
{imageToDisplay && (
|
||||
<Image
|
||||
src={imageToDisplay.url}
|
||||
fit='contain'
|
||||
maxWidth={'100%'}
|
||||
maxHeight={'100%'}
|
||||
/>
|
||||
)}
|
||||
{imageToDisplay && shouldShowImageDetails && (
|
||||
<Flex
|
||||
width={'100%'}
|
||||
height={'100%'}
|
||||
position={'absolute'}
|
||||
top={0}
|
||||
left={0}
|
||||
p={3}
|
||||
boxSizing='border-box'
|
||||
backgroundColor={bgColor}
|
||||
overflow='scroll'
|
||||
>
|
||||
<ImageMetadataViewer image={imageToDisplay} />
|
||||
</Flex>
|
||||
)}
|
||||
</Center>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default CurrentImage;
|
149
frontend/src/features/gallery/CurrentImageButtons.tsx
Normal file
@ -0,0 +1,149 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
|
||||
import DeleteImageModal from './DeleteImageModal';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { SystemState } from '../system/systemSlice';
|
||||
import { isEqual } from 'lodash';
|
||||
import { SDImage } from './gallerySlice';
|
||||
import SDButton from '../../common/components/SDButton';
|
||||
import { runESRGAN, runGFPGAN } from '../../app/socketio/actions';
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isProcessing: system.isProcessing,
|
||||
isConnected: system.isConnected,
|
||||
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||
isESRGANAvailable: system.isESRGANAvailable,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
type CurrentImageButtonsProps = {
|
||||
image: SDImage;
|
||||
shouldShowImageDetails: boolean;
|
||||
setShouldShowImageDetails: (b: boolean) => void;
|
||||
};
|
||||
|
||||
/**
|
||||
* Row of buttons for common actions:
|
||||
* Use as init image, use all params, use seed, upscale, fix faces, details, delete.
|
||||
*/
|
||||
const CurrentImageButtons = ({
|
||||
image,
|
||||
shouldShowImageDetails,
|
||||
setShouldShowImageDetails,
|
||||
}: CurrentImageButtonsProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { intermediateImage } = useAppSelector(
|
||||
(state: RootState) => state.gallery
|
||||
);
|
||||
|
||||
const { upscalingLevel, gfpganStrength } = useAppSelector(
|
||||
(state: RootState) => state.sd
|
||||
);
|
||||
|
||||
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
|
||||
useAppSelector(systemSelector);
|
||||
|
||||
const handleClickUseAsInitialImage = () =>
|
||||
dispatch(setInitialImagePath(image.url));
|
||||
|
||||
const handleClickUseAllParameters = () =>
|
||||
dispatch(setAllParameters(image.metadata));
|
||||
|
||||
// Non-null assertion: this button is disabled if there is no seed.
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
const handleClickUseSeed = () => dispatch(setSeed(image.metadata.seed!));
|
||||
|
||||
const handleClickUpscale = () => dispatch(runESRGAN(image));
|
||||
|
||||
const handleClickFixFaces = () => dispatch(runGFPGAN(image));
|
||||
|
||||
const handleClickShowImageDetails = () =>
|
||||
setShouldShowImageDetails(!shouldShowImageDetails);
|
||||
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
<SDButton
|
||||
label="Use as initial image"
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
onClick={handleClickUseAsInitialImage}
|
||||
/>
|
||||
|
||||
<SDButton
|
||||
label="Use all"
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
onClick={handleClickUseAllParameters}
|
||||
/>
|
||||
|
||||
<SDButton
|
||||
label="Use seed"
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={!image.metadata.seed}
|
||||
onClick={handleClickUseSeed}
|
||||
/>
|
||||
|
||||
<SDButton
|
||||
label="Upscale"
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={
|
||||
!isESRGANAvailable ||
|
||||
Boolean(intermediateImage) ||
|
||||
!(isConnected && !isProcessing) ||
|
||||
!upscalingLevel
|
||||
}
|
||||
onClick={handleClickUpscale}
|
||||
/>
|
||||
<SDButton
|
||||
label="Fix faces"
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={
|
||||
!isGFPGANAvailable ||
|
||||
Boolean(intermediateImage) ||
|
||||
!(isConnected && !isProcessing) ||
|
||||
!gfpganStrength
|
||||
}
|
||||
onClick={handleClickFixFaces}
|
||||
/>
|
||||
<SDButton
|
||||
label="Details"
|
||||
colorScheme={'gray'}
|
||||
variant={shouldShowImageDetails ? 'solid' : 'outline'}
|
||||
borderWidth={1}
|
||||
flexGrow={1}
|
||||
onClick={handleClickShowImageDetails}
|
||||
/>
|
||||
<DeleteImageModal image={image}>
|
||||
<SDButton
|
||||
label="Delete"
|
||||
colorScheme={'red'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={Boolean(intermediateImage)}
|
||||
/>
|
||||
</DeleteImageModal>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default CurrentImageButtons;
|
67
frontend/src/features/gallery/CurrentImageDisplay.tsx
Normal file
@ -0,0 +1,67 @@
|
||||
import { Center, Flex, Image, Text, useColorModeValue } from '@chakra-ui/react';
|
||||
import { useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import { useState } from 'react';
|
||||
import ImageMetadataViewer from './ImageMetadataViewer';
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
|
||||
// TODO: With CSS Grid I had a hard time centering the image in a grid item. This is needed for that.
|
||||
const height = 'calc(100vh - 238px)';
|
||||
|
||||
/**
|
||||
* Displays the current image if there is one, plus associated actions.
|
||||
*/
|
||||
const CurrentImageDisplay = () => {
|
||||
const { currentImage, intermediateImage } = useAppSelector(
|
||||
(state: RootState) => state.gallery
|
||||
);
|
||||
|
||||
const bgColor = useColorModeValue(
|
||||
'rgba(255, 255, 255, 0.85)',
|
||||
'rgba(0, 0, 0, 0.8)'
|
||||
);
|
||||
|
||||
const [shouldShowImageDetails, setShouldShowImageDetails] =
|
||||
useState<boolean>(false);
|
||||
|
||||
const imageToDisplay = intermediateImage || currentImage;
|
||||
|
||||
return imageToDisplay ? (
|
||||
<Flex direction={'column'} borderWidth={1} rounded={'md'} p={2} gap={2}>
|
||||
<CurrentImageButtons
|
||||
image={imageToDisplay}
|
||||
shouldShowImageDetails={shouldShowImageDetails}
|
||||
setShouldShowImageDetails={setShouldShowImageDetails}
|
||||
/>
|
||||
<Center height={height} position={'relative'}>
|
||||
<Image
|
||||
src={imageToDisplay.url}
|
||||
fit="contain"
|
||||
maxWidth={'100%'}
|
||||
maxHeight={'100%'}
|
||||
/>
|
||||
{shouldShowImageDetails && (
|
||||
<Flex
|
||||
width={'100%'}
|
||||
height={'100%'}
|
||||
position={'absolute'}
|
||||
top={0}
|
||||
left={0}
|
||||
p={3}
|
||||
boxSizing="border-box"
|
||||
backgroundColor={bgColor}
|
||||
overflow="scroll"
|
||||
>
|
||||
<ImageMetadataViewer image={imageToDisplay} />
|
||||
</Flex>
|
||||
)}
|
||||
</Center>
|
||||
</Flex>
|
||||
) : (
|
||||
<Center height={'100%'} position={'relative'}>
|
||||
<Text size={'xl'}>No image selected</Text>
|
||||
</Center>
|
||||
);
|
||||
};
|
||||
|
||||
export default CurrentImageDisplay;
|
121
frontend/src/features/gallery/DeleteImageModal.tsx
Normal file
@ -0,0 +1,121 @@
|
||||
import {
|
||||
Text,
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogContent,
|
||||
AlertDialogOverlay,
|
||||
useDisclosure,
|
||||
Button,
|
||||
Switch,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Flex,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import {
|
||||
ChangeEvent,
|
||||
cloneElement,
|
||||
ReactElement,
|
||||
SyntheticEvent,
|
||||
useRef,
|
||||
} from 'react';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
import { deleteImage } from '../../app/socketio/actions';
|
||||
import { RootState } from '../../app/store';
|
||||
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
|
||||
import { SDImage } from './gallerySlice';
|
||||
|
||||
interface DeleteImageModalProps {
|
||||
/**
|
||||
* Component which, on click, should delete the image/open the modal.
|
||||
*/
|
||||
children: ReactElement;
|
||||
/**
|
||||
* The image to delete.
|
||||
*/
|
||||
image: SDImage;
|
||||
}
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => system.shouldConfirmOnDelete
|
||||
);
|
||||
|
||||
/**
|
||||
* Needs a child, which will act as the button to delete an image.
|
||||
* If system.shouldConfirmOnDelete is true, a confirmation modal is displayed.
|
||||
* If it is false, the image is deleted immediately.
|
||||
* The confirmation modal has a "Don't ask me again" switch to set the boolean.
|
||||
*/
|
||||
const DeleteImageModal = ({ image, children }: DeleteImageModalProps) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const dispatch = useAppDispatch();
|
||||
const shouldConfirmOnDelete = useAppSelector(systemSelector);
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const handleClickDelete = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
shouldConfirmOnDelete ? onOpen() : handleDelete();
|
||||
};
|
||||
|
||||
const handleDelete = () => {
|
||||
dispatch(deleteImage(image));
|
||||
onClose();
|
||||
};
|
||||
|
||||
const handleChangeShouldConfirmOnDelete = (
|
||||
e: ChangeEvent<HTMLInputElement>
|
||||
) => dispatch(setShouldConfirmOnDelete(!e.target.checked));
|
||||
|
||||
return (
|
||||
<>
|
||||
{cloneElement(children, {
|
||||
// TODO: This feels wrong.
|
||||
onClick: handleClickDelete,
|
||||
})}
|
||||
|
||||
<AlertDialog
|
||||
isOpen={isOpen}
|
||||
leastDestructiveRef={cancelRef}
|
||||
onClose={onClose}
|
||||
>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
Delete image
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
<Flex direction={'column'} gap={5}>
|
||||
<Text>
|
||||
Are you sure? You can't undo this action afterwards.
|
||||
</Text>
|
||||
<FormControl>
|
||||
<Flex alignItems={'center'}>
|
||||
<FormLabel mb={0}>Don't ask me again</FormLabel>
|
||||
<Switch
|
||||
checked={!shouldConfirmOnDelete}
|
||||
onChange={handleChangeShouldConfirmOnDelete}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</AlertDialogBody>
|
||||
<AlertDialogFooter>
|
||||
<Button ref={cancelRef} onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button colorScheme="red" onClick={handleDelete} ml={3}>
|
||||
Delete
|
||||
</Button>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialogOverlay>
|
||||
</AlertDialog>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default DeleteImageModal;
|
@ -1,94 +0,0 @@
|
||||
import {
|
||||
IconButtonProps,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Text,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import {
|
||||
cloneElement,
|
||||
ReactElement,
|
||||
SyntheticEvent,
|
||||
} from 'react';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { deleteImage } from '../../app/socketio';
|
||||
import { RootState } from '../../app/store';
|
||||
import SDButton from '../../components/SDButton';
|
||||
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
|
||||
import { SDImage } from './gallerySlice';
|
||||
|
||||
interface Props extends IconButtonProps {
|
||||
image: SDImage;
|
||||
'aria-label': string;
|
||||
children: ReactElement;
|
||||
}
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => system.shouldConfirmOnDelete
|
||||
);
|
||||
|
||||
/*
|
||||
TODO: The modal and button to open it should be two different components,
|
||||
but their state is closely related and I'm not sure how best to accomplish it.
|
||||
*/
|
||||
const DeleteImageModalButton = (props: Omit<Props, 'aria-label'>) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const dispatch = useAppDispatch();
|
||||
const shouldConfirmOnDelete = useAppSelector(systemSelector);
|
||||
|
||||
const handleClickDelete = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
shouldConfirmOnDelete ? onOpen() : handleDelete();
|
||||
};
|
||||
|
||||
const { image, children } = props;
|
||||
|
||||
const handleDelete = () => {
|
||||
dispatch(deleteImage(image));
|
||||
onClose();
|
||||
};
|
||||
|
||||
const handleDeleteAndDontAsk = () => {
|
||||
dispatch(deleteImage(image));
|
||||
dispatch(setShouldConfirmOnDelete(false));
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{cloneElement(children, {
|
||||
onClick: handleClickDelete,
|
||||
})}
|
||||
|
||||
<Modal isOpen={isOpen} onClose={onClose}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>Are you sure you want to delete this image?</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Text>It will be deleted forever!</Text>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter justifyContent={'space-between'}>
|
||||
<SDButton label={'Yes'} colorScheme='red' onClick={handleDelete} />
|
||||
<SDButton
|
||||
label={"Yes, and don't ask me again"}
|
||||
colorScheme='red'
|
||||
onClick={handleDeleteAndDontAsk}
|
||||
/>
|
||||
<SDButton label='Cancel' colorScheme='blue' onClick={onClose} />
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default DeleteImageModalButton;
|
131
frontend/src/features/gallery/HoverableImage.tsx
Normal file
@ -0,0 +1,131 @@
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
Icon,
|
||||
IconButton,
|
||||
Image,
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from '../../app/store';
|
||||
import { SDImage, setCurrentImage } from './gallerySlice';
|
||||
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
|
||||
import DeleteImageModal from './DeleteImageModal';
|
||||
import { memo, SyntheticEvent, useState } from 'react';
|
||||
import { setAllParameters, setSeed } from '../sd/sdSlice';
|
||||
|
||||
interface HoverableImageProps {
|
||||
image: SDImage;
|
||||
isSelected: boolean;
|
||||
}
|
||||
|
||||
const memoEqualityCheck = (
|
||||
prev: HoverableImageProps,
|
||||
next: HoverableImageProps
|
||||
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
|
||||
|
||||
/**
|
||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||
*/
|
||||
const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const checkColor = useColorModeValue('green.600', 'green.300');
|
||||
const bgColor = useColorModeValue('gray.200', 'gray.700');
|
||||
const bgGradient = useColorModeValue(
|
||||
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
|
||||
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
|
||||
);
|
||||
|
||||
const { image, isSelected } = props;
|
||||
const { url, uuid, metadata } = image;
|
||||
|
||||
const handleMouseOver = () => setIsHovered(true);
|
||||
const handleMouseOut = () => setIsHovered(false);
|
||||
|
||||
const handleClickSetAllParameters = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
dispatch(setAllParameters(metadata));
|
||||
};
|
||||
|
||||
const handleClickSetSeed = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
// Non-null assertion: this button is not rendered unless this exists
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
dispatch(setSeed(image.metadata.seed!));
|
||||
};
|
||||
|
||||
const handleClickImage = () => dispatch(setCurrentImage(image));
|
||||
|
||||
return (
|
||||
<Box position={'relative'} key={uuid}>
|
||||
<Image
|
||||
width={120}
|
||||
height={120}
|
||||
objectFit="cover"
|
||||
rounded={'md'}
|
||||
src={url}
|
||||
loading={'lazy'}
|
||||
backgroundColor={bgColor}
|
||||
/>
|
||||
<Flex
|
||||
cursor={'pointer'}
|
||||
position={'absolute'}
|
||||
top={0}
|
||||
left={0}
|
||||
rounded={'md'}
|
||||
width="100%"
|
||||
height="100%"
|
||||
alignItems={'center'}
|
||||
justifyContent={'center'}
|
||||
background={isSelected ? bgGradient : undefined}
|
||||
onClick={handleClickImage}
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
>
|
||||
{isSelected && (
|
||||
<Icon fill={checkColor} width={'50%'} height={'50%'} as={FaCheck} />
|
||||
)}
|
||||
{isHovered && (
|
||||
<Flex
|
||||
direction={'column'}
|
||||
gap={1}
|
||||
position={'absolute'}
|
||||
top={1}
|
||||
right={1}
|
||||
>
|
||||
<DeleteImageModal image={image}>
|
||||
<IconButton
|
||||
colorScheme="red"
|
||||
aria-label="Delete image"
|
||||
icon={<FaTrash />}
|
||||
size="xs"
|
||||
fontSize={15}
|
||||
/>
|
||||
</DeleteImageModal>
|
||||
<IconButton
|
||||
aria-label="Use all parameters"
|
||||
colorScheme={'blue'}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
fontSize={15}
|
||||
onClickCapture={handleClickSetAllParameters}
|
||||
/>
|
||||
{image.metadata.seed && (
|
||||
<IconButton
|
||||
aria-label="Use seed"
|
||||
colorScheme={'blue'}
|
||||
icon={<FaSeedling />}
|
||||
size="xs"
|
||||
fontSize={16}
|
||||
onClickCapture={handleClickSetSeed}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
}, memoEqualityCheck);
|
||||
|
||||
export default HoverableImage;
|
39
frontend/src/features/gallery/ImageGallery.tsx
Normal file
@ -0,0 +1,39 @@
|
||||
import { Center, Flex, Text } from '@chakra-ui/react';
|
||||
import { RootState } from '../../app/store';
|
||||
import { useAppSelector } from '../../app/store';
|
||||
import HoverableImage from './HoverableImage';
|
||||
|
||||
/**
|
||||
* Simple image gallery.
|
||||
*/
|
||||
const ImageGallery = () => {
|
||||
const { images, currentImageUuid } = useAppSelector(
|
||||
(state: RootState) => state.gallery
|
||||
);
|
||||
|
||||
/**
|
||||
* I don't like that this needs to rerender whenever the current image is changed.
|
||||
* What if we have a large number of images? I suppose pagination (planned) will
|
||||
* mitigate this issue.
|
||||
*
|
||||
* TODO: Refactor if performance complaints, or after migrating to new API which supports pagination.
|
||||
*/
|
||||
|
||||
return images.length ? (
|
||||
<Flex gap={2} wrap="wrap" pb={2}>
|
||||
{[...images].reverse().map((image) => {
|
||||
const { uuid } = image;
|
||||
const isSelected = currentImageUuid === uuid;
|
||||
return (
|
||||
<HoverableImage key={uuid} image={image} isSelected={isSelected} />
|
||||
);
|
||||
})}
|
||||
</Flex>
|
||||
) : (
|
||||
<Center height={'100%'} position={'relative'}>
|
||||
<Text size={'xl'}>No images in gallery</Text>
|
||||
</Center>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageGallery;
|
@ -1,124 +1,134 @@
|
||||
import {
|
||||
Center,
|
||||
Flex,
|
||||
IconButton,
|
||||
Link,
|
||||
List,
|
||||
ListItem,
|
||||
Text,
|
||||
Center,
|
||||
Flex,
|
||||
IconButton,
|
||||
Link,
|
||||
List,
|
||||
ListItem,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { PARAMETERS } from '../../app/constants';
|
||||
import { useAppDispatch } from '../../app/hooks';
|
||||
import SDButton from '../../components/SDButton';
|
||||
import { useAppDispatch } from '../../app/store';
|
||||
import SDButton from '../../common/components/SDButton';
|
||||
import { setAllParameters, setParameter } from '../sd/sdSlice';
|
||||
import { SDImage, SDMetadata } from './gallerySlice';
|
||||
|
||||
type Props = {
|
||||
image: SDImage;
|
||||
type ImageMetadataViewerProps = {
|
||||
image: SDImage;
|
||||
};
|
||||
|
||||
const ImageMetadataViewer = ({ image }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
// TODO: I don't know if this is needed.
|
||||
const memoEqualityCheck = (
|
||||
prev: ImageMetadataViewerProps,
|
||||
next: ImageMetadataViewerProps
|
||||
) => prev.image.uuid === next.image.uuid;
|
||||
|
||||
const keys = Object.keys(PARAMETERS);
|
||||
// TODO: Show more interesting information in this component.
|
||||
|
||||
const metadata: Array<{
|
||||
label: string;
|
||||
key: string;
|
||||
value: string | number | boolean;
|
||||
}> = [];
|
||||
/**
|
||||
* Image metadata viewer overlays currently selected image and provides
|
||||
* access to any of its metadata for use in processing.
|
||||
*/
|
||||
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
keys.forEach((key) => {
|
||||
const value = image.metadata[key as keyof SDMetadata];
|
||||
if (value !== undefined) {
|
||||
metadata.push({ label: PARAMETERS[key], key, value });
|
||||
}
|
||||
});
|
||||
/**
|
||||
* Build an array representing each item of metadata and a human-readable
|
||||
* label for it e.g. "cfgScale" > "CFG Scale".
|
||||
*
|
||||
* This array is then used to render each item with a button to use that
|
||||
* parameter in the processing settings.
|
||||
*
|
||||
* TODO: All this logic feels sloppy.
|
||||
*/
|
||||
const keys = Object.keys(PARAMETERS);
|
||||
|
||||
return (
|
||||
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}>
|
||||
<SDButton
|
||||
label='Use all parameters'
|
||||
colorScheme={'gray'}
|
||||
padding={2}
|
||||
isDisabled={metadata.length === 0}
|
||||
onClick={() => dispatch(setAllParameters(image.metadata))}
|
||||
/>
|
||||
<Flex gap={2}>
|
||||
<Text fontWeight={'semibold'}>File:</Text>
|
||||
<Link href={image.url} isExternal>
|
||||
<Text>{image.url}</Text>
|
||||
</Link>
|
||||
</Flex>
|
||||
{metadata.length ? (
|
||||
<>
|
||||
<List>
|
||||
{metadata.map((parameter, i) => {
|
||||
const { label, key, value } = parameter;
|
||||
return (
|
||||
<ListItem key={i} pb={1}>
|
||||
<Flex gap={2}>
|
||||
<IconButton
|
||||
aria-label='Use this parameter'
|
||||
icon={<FaPlus />}
|
||||
size={'xs'}
|
||||
onClick={() =>
|
||||
dispatch(
|
||||
setParameter({
|
||||
key,
|
||||
value,
|
||||
})
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Text fontWeight={'semibold'}>
|
||||
{label}:
|
||||
</Text>
|
||||
const metadata: Array<{
|
||||
label: string;
|
||||
key: string;
|
||||
value: string | number | boolean;
|
||||
}> = [];
|
||||
|
||||
{value === undefined ||
|
||||
value === null ||
|
||||
value === '' ||
|
||||
value === 0 ? (
|
||||
<Text
|
||||
maxHeight={100}
|
||||
fontStyle={'italic'}
|
||||
>
|
||||
None
|
||||
</Text>
|
||||
) : (
|
||||
<Text
|
||||
maxHeight={100}
|
||||
overflowY={'scroll'}
|
||||
>
|
||||
{value.toString()}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
</ListItem>
|
||||
);
|
||||
})}
|
||||
</List>
|
||||
<Flex gap={2}>
|
||||
<Text fontWeight={'semibold'}>Raw:</Text>
|
||||
<Text
|
||||
maxHeight={100}
|
||||
overflowY={'scroll'}
|
||||
wordBreak={'break-all'}
|
||||
>
|
||||
{JSON.stringify(image.metadata)}
|
||||
</Text>
|
||||
</Flex>
|
||||
</>
|
||||
) : (
|
||||
<Center width={'100%'} pt={10}>
|
||||
<Text fontSize={'lg'} fontWeight='semibold'>
|
||||
No metadata available
|
||||
</Text>
|
||||
</Center>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
keys.forEach((key) => {
|
||||
const value = image.metadata[key as keyof SDMetadata];
|
||||
if (value !== undefined) {
|
||||
metadata.push({ label: PARAMETERS[key], key, value });
|
||||
}
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}>
|
||||
<SDButton
|
||||
label="Use all parameters"
|
||||
colorScheme={'gray'}
|
||||
padding={2}
|
||||
isDisabled={metadata.length === 0}
|
||||
onClick={() => dispatch(setAllParameters(image.metadata))}
|
||||
/>
|
||||
<Flex gap={2}>
|
||||
<Text fontWeight={'semibold'}>File:</Text>
|
||||
<Link href={image.url} isExternal>
|
||||
<Text>{image.url}</Text>
|
||||
</Link>
|
||||
</Flex>
|
||||
{metadata.length ? (
|
||||
<>
|
||||
<List>
|
||||
{metadata.map((parameter, i) => {
|
||||
const { label, key, value } = parameter;
|
||||
return (
|
||||
<ListItem key={i} pb={1}>
|
||||
<Flex gap={2}>
|
||||
<IconButton
|
||||
aria-label="Use this parameter"
|
||||
icon={<FaPlus />}
|
||||
size={'xs'}
|
||||
onClick={() =>
|
||||
dispatch(
|
||||
setParameter({
|
||||
key,
|
||||
value,
|
||||
})
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Text fontWeight={'semibold'}>{label}:</Text>
|
||||
|
||||
{value === undefined ||
|
||||
value === null ||
|
||||
value === '' ||
|
||||
value === 0 ? (
|
||||
<Text maxHeight={100} fontStyle={'italic'}>
|
||||
None
|
||||
</Text>
|
||||
) : (
|
||||
<Text maxHeight={100} overflowY={'scroll'}>
|
||||
{value.toString()}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
</ListItem>
|
||||
);
|
||||
})}
|
||||
</List>
|
||||
<Flex gap={2}>
|
||||
<Text fontWeight={'semibold'}>Raw:</Text>
|
||||
<Text maxHeight={100} overflowY={'scroll'} wordBreak={'break-all'}>
|
||||
{JSON.stringify(image.metadata)}
|
||||
</Text>
|
||||
</Flex>
|
||||
</>
|
||||
) : (
|
||||
<Center width={'100%'} pt={10}>
|
||||
<Text fontSize={'lg'} fontWeight="semibold">
|
||||
No metadata available
|
||||
</Text>
|
||||
</Center>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}, memoEqualityCheck);
|
||||
|
||||
export default ImageMetadataViewer;
|
||||
|
@ -1,150 +0,0 @@
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
Icon,
|
||||
IconButton,
|
||||
Image,
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import { RootState } from '../../app/store';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { SDImage, setCurrentImage } from './gallerySlice';
|
||||
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
|
||||
import DeleteImageModalButton from './DeleteImageModalButton';
|
||||
import { memo, SyntheticEvent, useState } from 'react';
|
||||
import { setAllParameters, setSeed } from '../sd/sdSlice';
|
||||
|
||||
interface HoverableImageProps {
|
||||
image: SDImage;
|
||||
isSelected: boolean;
|
||||
}
|
||||
|
||||
const HoverableImage = memo(
|
||||
(props: HoverableImageProps) => {
|
||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const checkColor = useColorModeValue('green.600', 'green.300');
|
||||
const bgColor = useColorModeValue('gray.200', 'gray.700');
|
||||
const bgGradient = useColorModeValue(
|
||||
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
|
||||
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
|
||||
);
|
||||
|
||||
const { image, isSelected } = props;
|
||||
const { url, uuid, metadata } = image;
|
||||
|
||||
const handleMouseOver = () => setIsHovered(true);
|
||||
const handleMouseOut = () => setIsHovered(false);
|
||||
const handleClickSetAllParameters = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
dispatch(setAllParameters(metadata));
|
||||
};
|
||||
const handleClickSetSeed = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
dispatch(setSeed(image.metadata.seed!)); // component not rendered unless this exists
|
||||
};
|
||||
|
||||
return (
|
||||
<Box position={'relative'} key={uuid}>
|
||||
<Image
|
||||
width={120}
|
||||
height={120}
|
||||
objectFit='cover'
|
||||
rounded={'md'}
|
||||
src={url}
|
||||
loading={'lazy'}
|
||||
backgroundColor={bgColor}
|
||||
/>
|
||||
<Flex
|
||||
cursor={'pointer'}
|
||||
position={'absolute'}
|
||||
top={0}
|
||||
left={0}
|
||||
rounded={'md'}
|
||||
width='100%'
|
||||
height='100%'
|
||||
alignItems={'center'}
|
||||
justifyContent={'center'}
|
||||
background={isSelected ? bgGradient : undefined}
|
||||
onClick={() => dispatch(setCurrentImage(image))}
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
>
|
||||
{isSelected && (
|
||||
<Icon
|
||||
fill={checkColor}
|
||||
width={'50%'}
|
||||
height={'50%'}
|
||||
as={FaCheck}
|
||||
/>
|
||||
)}
|
||||
{isHovered && (
|
||||
<Flex
|
||||
direction={'column'}
|
||||
gap={1}
|
||||
position={'absolute'}
|
||||
top={1}
|
||||
right={1}
|
||||
>
|
||||
<DeleteImageModalButton image={image}>
|
||||
<IconButton
|
||||
colorScheme='red'
|
||||
aria-label='Delete image'
|
||||
icon={<FaTrash />}
|
||||
size='xs'
|
||||
fontSize={15}
|
||||
/>
|
||||
</DeleteImageModalButton>
|
||||
<IconButton
|
||||
aria-label='Use all parameters'
|
||||
colorScheme={'blue'}
|
||||
icon={<FaCopy />}
|
||||
size='xs'
|
||||
fontSize={15}
|
||||
onClickCapture={handleClickSetAllParameters}
|
||||
/>
|
||||
{image.metadata.seed && (
|
||||
<IconButton
|
||||
aria-label='Use seed'
|
||||
colorScheme={'blue'}
|
||||
icon={<FaSeedling />}
|
||||
size='xs'
|
||||
fontSize={16}
|
||||
onClickCapture={handleClickSetSeed}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
},
|
||||
(prev, next) =>
|
||||
prev.image.uuid === next.image.uuid &&
|
||||
prev.isSelected === next.isSelected
|
||||
);
|
||||
|
||||
const ImageRoll = () => {
|
||||
const { images, currentImageUuid } = useAppSelector(
|
||||
(state: RootState) => state.gallery
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} wrap='wrap' pb={2}>
|
||||
{[...images].reverse().map((image) => {
|
||||
const { uuid } = image;
|
||||
const isSelected = currentImageUuid === uuid;
|
||||
return (
|
||||
<HoverableImage
|
||||
key={uuid}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageRoll;
|
@ -1,8 +1,7 @@
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { UpscalingLevel } from '../sd/sdSlice';
|
||||
import { backendToFrontendParameters } from '../../app/parameterTranslation';
|
||||
import { clamp } from 'lodash';
|
||||
|
||||
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
|
||||
export interface SDMetadata {
|
||||
@ -50,29 +49,48 @@ export const gallerySlice = createSlice({
|
||||
state.currentImage = action.payload;
|
||||
state.currentImageUuid = action.payload.uuid;
|
||||
},
|
||||
removeImage: (state, action: PayloadAction<SDImage>) => {
|
||||
const { uuid } = action.payload;
|
||||
removeImage: (state, action: PayloadAction<string>) => {
|
||||
const uuid = action.payload;
|
||||
|
||||
const newImages = state.images.filter((image) => image.uuid !== uuid);
|
||||
|
||||
const imageToDeleteIndex = state.images.findIndex(
|
||||
(image) => image.uuid === uuid
|
||||
);
|
||||
if (uuid === state.currentImageUuid) {
|
||||
/**
|
||||
* We are deleting the currently selected image.
|
||||
*
|
||||
* We want the new currentl selected image to be under the cursor in the
|
||||
* gallery, so we need to do some fanagling. The currently selected image
|
||||
* is set by its UUID, not its index in the image list.
|
||||
*
|
||||
* Get the currently selected image's index.
|
||||
*/
|
||||
const imageToDeleteIndex = state.images.findIndex(
|
||||
(image) => image.uuid === uuid
|
||||
);
|
||||
|
||||
const newCurrentImageIndex = Math.min(
|
||||
Math.max(imageToDeleteIndex, 0),
|
||||
newImages.length - 1
|
||||
);
|
||||
/**
|
||||
* New current image needs to be in the same spot, but because the gallery
|
||||
* is sorted in reverse order, the new current image's index will actuall be
|
||||
* one less than the deleted image's index.
|
||||
*
|
||||
* Clamp the new index to ensure it is valid..
|
||||
*/
|
||||
const newCurrentImageIndex = clamp(
|
||||
imageToDeleteIndex - 1,
|
||||
0,
|
||||
newImages.length - 1
|
||||
);
|
||||
|
||||
state.currentImage = newImages.length
|
||||
? newImages[newCurrentImageIndex]
|
||||
: undefined;
|
||||
|
||||
state.currentImageUuid = newImages.length
|
||||
? newImages[newCurrentImageIndex].uuid
|
||||
: '';
|
||||
}
|
||||
|
||||
state.images = newImages;
|
||||
|
||||
state.currentImage = newImages.length
|
||||
? newImages[newCurrentImageIndex]
|
||||
: undefined;
|
||||
|
||||
state.currentImageUuid = newImages.length
|
||||
? newImages[newCurrentImageIndex].uuid
|
||||
: '';
|
||||
},
|
||||
addImage: (state, action: PayloadAction<SDImage>) => {
|
||||
state.images.push(action.payload);
|
||||
@ -86,47 +104,13 @@ export const gallerySlice = createSlice({
|
||||
clearIntermediateImage: (state) => {
|
||||
state.intermediateImage = undefined;
|
||||
},
|
||||
setGalleryImages: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
Array<{
|
||||
path: string;
|
||||
metadata: { [key: string]: string | number | boolean };
|
||||
}>
|
||||
>
|
||||
) => {
|
||||
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
|
||||
const images = action.payload;
|
||||
|
||||
if (images.length === 0) {
|
||||
// there are no images on disk, clear the gallery
|
||||
state.images = [];
|
||||
state.currentImageUuid = '';
|
||||
state.currentImage = undefined;
|
||||
} else {
|
||||
// Filter image urls that are already in the rehydrated state
|
||||
const filteredImages = action.payload.filter(
|
||||
(image) => !state.images.find((i) => i.url === image.path)
|
||||
);
|
||||
|
||||
const preparedImages = filteredImages.map((image): SDImage => {
|
||||
return {
|
||||
uuid: uuidv4(),
|
||||
url: image.path,
|
||||
metadata: backendToFrontendParameters(image.metadata),
|
||||
};
|
||||
});
|
||||
|
||||
const newImages = [...state.images].concat(preparedImages);
|
||||
|
||||
// if previous currentimage no longer exists, set a new one
|
||||
if (!newImages.find((image) => image.uuid === state.currentImageUuid)) {
|
||||
const newCurrentImage = newImages[newImages.length - 1];
|
||||
state.currentImage = newCurrentImage;
|
||||
state.currentImageUuid = newCurrentImage.uuid;
|
||||
}
|
||||
|
||||
setGalleryImages: (state, action: PayloadAction<Array<SDImage>>) => {
|
||||
const newImages = action.payload;
|
||||
if (newImages.length) {
|
||||
const newCurrentImage = newImages[newImages.length - 1];
|
||||
state.images = newImages;
|
||||
state.currentImage = newCurrentImage;
|
||||
state.currentImageUuid = newCurrentImage.uuid;
|
||||
}
|
||||
},
|
||||
},
|
||||
|
@ -1,35 +1,38 @@
|
||||
import { Progress } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { useAppSelector } from '../../app/hooks';
|
||||
import { useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import { SDState } from '../sd/sdSlice';
|
||||
import { SystemState } from '../system/systemSlice';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
realSteps: sd.realSteps,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isProcessing: system.isProcessing,
|
||||
currentStep: system.currentStep,
|
||||
totalSteps: system.totalSteps,
|
||||
currentStatusHasSteps: system.currentStatusHasSteps,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: { resultEqualityCheck: isEqual },
|
||||
}
|
||||
);
|
||||
|
||||
const ProgressBar = () => {
|
||||
const { realSteps } = useAppSelector(sdSelector);
|
||||
const { currentStep } = useAppSelector((state: RootState) => state.system);
|
||||
const progress = Math.round((currentStep * 100) / realSteps);
|
||||
return (
|
||||
<Progress
|
||||
height='10px'
|
||||
value={progress}
|
||||
isIndeterminate={progress < 0 || currentStep === realSteps}
|
||||
/>
|
||||
);
|
||||
const { isProcessing, currentStep, totalSteps, currentStatusHasSteps } =
|
||||
useAppSelector(systemSelector);
|
||||
|
||||
const value = currentStep ? Math.round((currentStep * 100) / totalSteps) : 0;
|
||||
|
||||
return (
|
||||
<Progress
|
||||
height="10px"
|
||||
value={value}
|
||||
isIndeterminate={isProcessing && !currentStatusHasSteps}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default ProgressBar;
|
||||
|
@ -12,39 +12,66 @@ import { isEqual } from 'lodash';
|
||||
|
||||
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
|
||||
import { MdHelp, MdSettings } from 'react-icons/md';
|
||||
import { useAppSelector } from '../../app/hooks';
|
||||
import { useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import SettingsModal from '../system/SettingsModal';
|
||||
import { SystemState } from '../system/systemSlice';
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return { isConnected: system.isConnected };
|
||||
return {
|
||||
isConnected: system.isConnected,
|
||||
isProcessing: system.isProcessing,
|
||||
currentIteration: system.currentIteration,
|
||||
totalIterations: system.totalIterations,
|
||||
currentStatus: system.currentStatus,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: { resultEqualityCheck: isEqual },
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Header, includes color mode toggle, settings button, status message.
|
||||
*/
|
||||
const SiteHeader = () => {
|
||||
const { colorMode, toggleColorMode } = useColorMode();
|
||||
const { isConnected } = useAppSelector(systemSelector);
|
||||
const {
|
||||
isConnected,
|
||||
isProcessing,
|
||||
currentIteration,
|
||||
totalIterations,
|
||||
currentStatus,
|
||||
} = useAppSelector(systemSelector);
|
||||
|
||||
const statusMessageTextColor = isConnected ? 'green.500' : 'red.500';
|
||||
|
||||
const colorModeIcon = colorMode == 'light' ? <FaMoon /> : <FaSun />;
|
||||
|
||||
// Make FaMoon and FaSun icon apparent size consistent
|
||||
const colorModeIconFontSize = colorMode == 'light' ? 18 : 20;
|
||||
|
||||
let statusMessage = currentStatus;
|
||||
|
||||
if (isProcessing) {
|
||||
if (totalIterations > 1) {
|
||||
statusMessage += ` [${currentIteration}/${totalIterations}]`;
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex minWidth='max-content' alignItems='center' gap='1' pl={2} pr={1}>
|
||||
<Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}>
|
||||
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading>
|
||||
|
||||
<Spacer />
|
||||
|
||||
<Text textColor={isConnected ? 'green.500' : 'red.500'}>
|
||||
{isConnected ? `Connected to server` : 'No connection to server'}
|
||||
</Text>
|
||||
<Text textColor={statusMessageTextColor}>{statusMessage}</Text>
|
||||
|
||||
<SettingsModal>
|
||||
<IconButton
|
||||
aria-label='Settings'
|
||||
variant='link'
|
||||
aria-label="Settings"
|
||||
variant="link"
|
||||
fontSize={24}
|
||||
size={'sm'}
|
||||
icon={<MdSettings />}
|
||||
@ -52,14 +79,14 @@ const SiteHeader = () => {
|
||||
</SettingsModal>
|
||||
|
||||
<IconButton
|
||||
aria-label='Link to Github Issues'
|
||||
variant='link'
|
||||
aria-label="Link to Github Issues"
|
||||
variant="link"
|
||||
fontSize={23}
|
||||
size={'sm'}
|
||||
icon={
|
||||
<Link
|
||||
isExternal
|
||||
href='http://github.com/lstein/stable-diffusion/issues'
|
||||
href="http://github.com/lstein/stable-diffusion/issues"
|
||||
>
|
||||
<MdHelp />
|
||||
</Link>
|
||||
@ -67,24 +94,24 @@ const SiteHeader = () => {
|
||||
/>
|
||||
|
||||
<IconButton
|
||||
aria-label='Link to Github Repo'
|
||||
variant='link'
|
||||
aria-label="Link to Github Repo"
|
||||
variant="link"
|
||||
fontSize={20}
|
||||
size={'sm'}
|
||||
icon={
|
||||
<Link isExternal href='http://github.com/lstein/stable-diffusion'>
|
||||
<Link isExternal href="http://github.com/lstein/stable-diffusion">
|
||||
<FaGithub />
|
||||
</Link>
|
||||
}
|
||||
/>
|
||||
|
||||
<IconButton
|
||||
aria-label='Toggle Dark Mode'
|
||||
aria-label="Toggle Dark Mode"
|
||||
onClick={toggleColorMode}
|
||||
variant='link'
|
||||
variant="link"
|
||||
size={'sm'}
|
||||
fontSize={colorMode == 'light' ? 18 : 20}
|
||||
icon={colorMode == 'light' ? <FaMoon /> : <FaSun />}
|
||||
fontSize={colorModeIconFontSize}
|
||||
icon={colorModeIcon}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
|
@ -1,84 +1,87 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
|
||||
import { RootState } from '../../app/store';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
|
||||
import {
|
||||
setUpscalingLevel,
|
||||
setUpscalingStrength,
|
||||
UpscalingLevel,
|
||||
SDState,
|
||||
setUpscalingLevel,
|
||||
setUpscalingStrength,
|
||||
UpscalingLevel,
|
||||
SDState,
|
||||
} from '../sd/sdSlice';
|
||||
|
||||
import SDNumberInput from '../../components/SDNumberInput';
|
||||
import SDSelect from '../../components/SDSelect';
|
||||
|
||||
import { UPSCALING_LEVELS } from '../../app/constants';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { SystemState } from '../system/systemSlice';
|
||||
import { ChangeEvent } from 'react';
|
||||
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||
import SDSelect from '../../common/components/SDSelect';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
upscalingLevel: sd.upscalingLevel,
|
||||
upscalingStrength: sd.upscalingStrength,
|
||||
};
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
upscalingLevel: sd.upscalingLevel,
|
||||
upscalingStrength: sd.upscalingStrength,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isESRGANAvailable: system.isESRGANAvailable,
|
||||
};
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isESRGANAvailable: system.isESRGANAvailable,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Displays upscaling/ESRGAN options (level and strength).
|
||||
*/
|
||||
const ESRGANOptions = () => {
|
||||
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector);
|
||||
const { isESRGANAvailable } = useAppSelector(systemSelector);
|
||||
|
||||
const { isESRGANAvailable } = useAppSelector(systemSelector);
|
||||
const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>
|
||||
dispatch(setUpscalingLevel(Number(e.target.value) as UpscalingLevel));
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const handleChangeStrength = (v: string | number) =>
|
||||
dispatch(setUpscalingStrength(Number(v)));
|
||||
|
||||
return (
|
||||
<Flex direction={'column'} gap={2}>
|
||||
<SDSelect
|
||||
isDisabled={!isESRGANAvailable}
|
||||
label='Scale'
|
||||
value={upscalingLevel}
|
||||
onChange={(e) =>
|
||||
dispatch(
|
||||
setUpscalingLevel(
|
||||
Number(e.target.value) as UpscalingLevel
|
||||
)
|
||||
)
|
||||
}
|
||||
validValues={UPSCALING_LEVELS}
|
||||
/>
|
||||
<SDNumberInput
|
||||
isDisabled={!isESRGANAvailable}
|
||||
label='Strength'
|
||||
step={0.05}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={(v) => dispatch(setUpscalingStrength(Number(v)))}
|
||||
value={upscalingStrength}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
return (
|
||||
<Flex direction={'column'} gap={2}>
|
||||
<SDSelect
|
||||
isDisabled={!isESRGANAvailable}
|
||||
label="Scale"
|
||||
value={upscalingLevel}
|
||||
onChange={handleChangeLevel}
|
||||
validValues={UPSCALING_LEVELS}
|
||||
/>
|
||||
<SDNumberInput
|
||||
isDisabled={!isESRGANAvailable}
|
||||
label="Strength"
|
||||
step={0.05}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={handleChangeStrength}
|
||||
value={upscalingStrength}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default ESRGANOptions;
|
||||
|
@ -1,63 +1,68 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
|
||||
import { RootState } from '../../app/store';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
|
||||
import { SDState, setGfpganStrength } from '../sd/sdSlice';
|
||||
|
||||
import SDNumberInput from '../../components/SDNumberInput';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { SystemState } from '../system/systemSlice';
|
||||
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
gfpganStrength: sd.gfpganStrength,
|
||||
};
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
gfpganStrength: sd.gfpganStrength,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||
};
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Displays face-fixing/GFPGAN options (strength).
|
||||
*/
|
||||
const GFPGANOptions = () => {
|
||||
const { gfpganStrength } = useAppSelector(sdSelector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { gfpganStrength } = useAppSelector(sdSelector);
|
||||
const { isGFPGANAvailable } = useAppSelector(systemSelector);
|
||||
|
||||
const { isGFPGANAvailable } = useAppSelector(systemSelector);
|
||||
const handleChangeStrength = (v: string | number) =>
|
||||
dispatch(setGfpganStrength(Number(v)));
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
return (
|
||||
<Flex direction={'column'} gap={2}>
|
||||
<SDNumberInput
|
||||
isDisabled={!isGFPGANAvailable}
|
||||
label='Strength'
|
||||
step={0.05}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={(v) => dispatch(setGfpganStrength(Number(v)))}
|
||||
value={gfpganStrength}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
return (
|
||||
<Flex direction={'column'} gap={2}>
|
||||
<SDNumberInput
|
||||
isDisabled={!isGFPGANAvailable}
|
||||
label="Strength"
|
||||
step={0.05}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={handleChangeStrength}
|
||||
value={gfpganStrength}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default GFPGANOptions;
|
||||
|
@ -1,54 +1,59 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import SDNumberInput from '../../components/SDNumberInput';
|
||||
import SDSwitch from '../../components/SDSwitch';
|
||||
import InitImage from './InitImage';
|
||||
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||
import SDSwitch from '../../common/components/SDSwitch';
|
||||
import InitAndMaskImage from './InitAndMaskImage';
|
||||
import {
|
||||
SDState,
|
||||
setImg2imgStrength,
|
||||
setShouldFitToWidthHeight,
|
||||
SDState,
|
||||
setImg2imgStrength,
|
||||
setShouldFitToWidthHeight,
|
||||
} from './sdSlice';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
initialImagePath: sd.initialImagePath,
|
||||
img2imgStrength: sd.img2imgStrength,
|
||||
shouldFitToWidthHeight: sd.shouldFitToWidthHeight,
|
||||
};
|
||||
}
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
img2imgStrength: sd.img2imgStrength,
|
||||
shouldFitToWidthHeight: sd.shouldFitToWidthHeight,
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Options for img2img generation (strength, fit, init/mask upload).
|
||||
*/
|
||||
const ImageToImageOptions = () => {
|
||||
const { initialImagePath, img2imgStrength, shouldFitToWidthHeight } =
|
||||
useAppSelector(sdSelector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { img2imgStrength, shouldFitToWidthHeight } =
|
||||
useAppSelector(sdSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
return (
|
||||
<Flex direction={'column'} gap={2}>
|
||||
<SDNumberInput
|
||||
isDisabled={!initialImagePath}
|
||||
label='Strength'
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={(v) => dispatch(setImg2imgStrength(Number(v)))}
|
||||
value={img2imgStrength}
|
||||
/>
|
||||
<SDSwitch
|
||||
isDisabled={!initialImagePath}
|
||||
label='Fit initial image to output size'
|
||||
isChecked={shouldFitToWidthHeight}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldFitToWidthHeight(e.target.checked))
|
||||
}
|
||||
/>
|
||||
<InitImage />
|
||||
</Flex>
|
||||
);
|
||||
const handleChangeStrength = (v: string | number) =>
|
||||
dispatch(setImg2imgStrength(Number(v)));
|
||||
|
||||
const handleChangeFit = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldFitToWidthHeight(e.target.checked));
|
||||
|
||||
return (
|
||||
<Flex direction={'column'} gap={2}>
|
||||
<SDNumberInput
|
||||
label="Strength"
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={handleChangeStrength}
|
||||
value={img2imgStrength}
|
||||
/>
|
||||
<SDSwitch
|
||||
label="Fit initial image to output size"
|
||||
isChecked={shouldFitToWidthHeight}
|
||||
onChange={handleChangeFit}
|
||||
/>
|
||||
<InitAndMaskImage />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageToImageOptions;
|
||||
|
63
frontend/src/features/sd/ImageUploader.tsx
Normal file
@ -0,0 +1,63 @@
|
||||
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
|
||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||
|
||||
type ImageUploaderProps = {
|
||||
/**
|
||||
* Component which, on click, should open the upload interface.
|
||||
*/
|
||||
children: ReactElement;
|
||||
/**
|
||||
* Callback to handle uploading the selected file.
|
||||
*/
|
||||
fileAcceptedCallback: (file: File) => void;
|
||||
/**
|
||||
* Callback to handle a file being rejected.
|
||||
*/
|
||||
fileRejectionCallback: (rejection: FileRejection) => void;
|
||||
};
|
||||
|
||||
/**
|
||||
* File upload using react-dropzone.
|
||||
* Needs a child to be the button to activate the upload interface.
|
||||
*/
|
||||
const ImageUploader = ({
|
||||
children,
|
||||
fileAcceptedCallback,
|
||||
fileRejectionCallback,
|
||||
}: ImageUploaderProps) => {
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
||||
fileRejections.forEach((rejection: FileRejection) => {
|
||||
fileRejectionCallback(rejection);
|
||||
});
|
||||
|
||||
acceptedFiles.forEach((file: File) => {
|
||||
fileAcceptedCallback(file);
|
||||
});
|
||||
},
|
||||
[fileAcceptedCallback, fileRejectionCallback]
|
||||
);
|
||||
|
||||
const { getRootProps, getInputProps, open } = useDropzone({
|
||||
onDrop,
|
||||
accept: {
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||
},
|
||||
});
|
||||
|
||||
const handleClickUploadIcon = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
open();
|
||||
};
|
||||
|
||||
return (
|
||||
<div {...getRootProps()}>
|
||||
<input {...getInputProps({ multiple: false })} />
|
||||
{cloneElement(children, {
|
||||
onClick: handleClickUploadIcon,
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageUploader;
|
57
frontend/src/features/sd/InitAndMaskImage.tsx
Normal file
@ -0,0 +1,57 @@
|
||||
import { Flex, Image } from '@chakra-ui/react';
|
||||
import { useState } from 'react';
|
||||
import { useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import { SDState } from '../../features/sd/sdSlice';
|
||||
import './InitAndMaskImage.css';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import InitAndMaskUploadButtons from './InitAndMaskUploadButtons';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
initialImagePath: sd.initialImagePath,
|
||||
maskPath: sd.maskPath,
|
||||
};
|
||||
},
|
||||
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
||||
);
|
||||
|
||||
/**
|
||||
* Displays init and mask images and buttons to upload/delete them.
|
||||
*/
|
||||
const InitAndMaskImage = () => {
|
||||
const { initialImagePath, maskPath } = useAppSelector(sdSelector);
|
||||
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
|
||||
|
||||
return (
|
||||
<Flex direction={'column'} alignItems={'center'} gap={2}>
|
||||
<InitAndMaskUploadButtons setShouldShowMask={setShouldShowMask} />
|
||||
{initialImagePath && (
|
||||
<Flex position={'relative'} width={'100%'}>
|
||||
<Image
|
||||
fit={'contain'}
|
||||
src={initialImagePath}
|
||||
rounded={'md'}
|
||||
className={'checkerboard'}
|
||||
/>
|
||||
{shouldShowMask && maskPath && (
|
||||
<Image
|
||||
position={'absolute'}
|
||||
top={0}
|
||||
left={0}
|
||||
fit={'contain'}
|
||||
src={maskPath}
|
||||
rounded={'md'}
|
||||
zIndex={1}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default InitAndMaskImage;
|
131
frontend/src/features/sd/InitAndMaskUploadButtons.tsx
Normal file
@ -0,0 +1,131 @@
|
||||
import { Button, Flex, IconButton, useToast } from '@chakra-ui/react';
|
||||
import { SyntheticEvent, useCallback } from 'react';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import {
|
||||
SDState,
|
||||
setInitialImagePath,
|
||||
setMaskPath,
|
||||
} from '../../features/sd/sdSlice';
|
||||
import { uploadInitialImage, uploadMaskImage } from '../../app/socketio/actions';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import ImageUploader from './ImageUploader';
|
||||
import { FileRejection } from 'react-dropzone';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
initialImagePath: sd.initialImagePath,
|
||||
maskPath: sd.maskPath,
|
||||
};
|
||||
},
|
||||
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
||||
);
|
||||
|
||||
type InitAndMaskUploadButtonsProps = {
|
||||
setShouldShowMask: (b: boolean) => void;
|
||||
};
|
||||
|
||||
/**
|
||||
* Init and mask image upload buttons.
|
||||
*/
|
||||
const InitAndMaskUploadButtons = ({
|
||||
setShouldShowMask,
|
||||
}: InitAndMaskUploadButtonsProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { initialImagePath } = useAppSelector(sdSelector);
|
||||
|
||||
// Use a toast to alert user when a file upload is rejected
|
||||
const toast = useToast();
|
||||
|
||||
// Clear the init and mask images
|
||||
const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
dispatch(setInitialImagePath(''));
|
||||
dispatch(setMaskPath(''));
|
||||
};
|
||||
|
||||
// Handle hover to view initial image and mask image
|
||||
const handleMouseOverInitialImageUploadButton = () =>
|
||||
setShouldShowMask(false);
|
||||
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
|
||||
|
||||
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
|
||||
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
|
||||
|
||||
// Callbacks to for handling file upload attempts
|
||||
const initImageFileAcceptedCallback = useCallback(
|
||||
(file: File) => dispatch(uploadInitialImage(file)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const maskImageFileAcceptedCallback = useCallback(
|
||||
(file: File) => dispatch(uploadMaskImage(file)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const fileRejectionCallback = useCallback(
|
||||
(rejection: FileRejection) => {
|
||||
const msg = rejection.errors.reduce(
|
||||
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
|
||||
''
|
||||
);
|
||||
|
||||
toast({
|
||||
title: 'Upload failed',
|
||||
description: msg,
|
||||
status: 'error',
|
||||
isClosable: true,
|
||||
});
|
||||
},
|
||||
[toast]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
|
||||
<ImageUploader
|
||||
fileAcceptedCallback={initImageFileAcceptedCallback}
|
||||
fileRejectionCallback={fileRejectionCallback}
|
||||
>
|
||||
<Button
|
||||
size={'sm'}
|
||||
fontSize={'md'}
|
||||
fontWeight={'normal'}
|
||||
onMouseOver={handleMouseOverInitialImageUploadButton}
|
||||
onMouseOut={handleMouseOutInitialImageUploadButton}
|
||||
>
|
||||
Upload Image
|
||||
</Button>
|
||||
</ImageUploader>
|
||||
|
||||
<ImageUploader
|
||||
fileAcceptedCallback={maskImageFileAcceptedCallback}
|
||||
fileRejectionCallback={fileRejectionCallback}
|
||||
>
|
||||
<Button
|
||||
isDisabled={!initialImagePath}
|
||||
size={'sm'}
|
||||
fontSize={'md'}
|
||||
fontWeight={'normal'}
|
||||
onMouseOver={handleMouseOverMaskUploadButton}
|
||||
onMouseOut={handleMouseOutMaskUploadButton}
|
||||
>
|
||||
Upload Mask
|
||||
</Button>
|
||||
</ImageUploader>
|
||||
|
||||
<IconButton
|
||||
isDisabled={!initialImagePath}
|
||||
size={'sm'}
|
||||
aria-label={'Reset initial image and mask'}
|
||||
onClick={handleClickResetInitialImageAndMask}
|
||||
icon={<FaTrash />}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default InitAndMaskUploadButtons;
|
@ -1,155 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
IconButton,
|
||||
Image,
|
||||
useToast,
|
||||
} from '@chakra-ui/react';
|
||||
import { SyntheticEvent, useCallback, useState } from 'react';
|
||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { RootState } from '../../app/store';
|
||||
import {
|
||||
SDState,
|
||||
setInitialImagePath,
|
||||
setMaskPath,
|
||||
} from '../../features/sd/sdSlice';
|
||||
import MaskUploader from './MaskUploader';
|
||||
import './InitImage.css';
|
||||
import { uploadInitialImage } from '../../app/socketio';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
initialImagePath: sd.initialImagePath,
|
||||
maskPath: sd.maskPath,
|
||||
};
|
||||
},
|
||||
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
||||
);
|
||||
|
||||
const InitImage = () => {
|
||||
const toast = useToast();
|
||||
const dispatch = useAppDispatch();
|
||||
const { initialImagePath, maskPath } = useAppSelector(sdSelector);
|
||||
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
||||
fileRejections.forEach((rejection: FileRejection) => {
|
||||
const msg = rejection.errors.reduce(
|
||||
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
|
||||
''
|
||||
);
|
||||
|
||||
toast({
|
||||
title: 'Upload failed',
|
||||
description: msg,
|
||||
status: 'error',
|
||||
isClosable: true,
|
||||
});
|
||||
});
|
||||
|
||||
acceptedFiles.forEach((file: File) => {
|
||||
dispatch(uploadInitialImage(file));
|
||||
});
|
||||
},
|
||||
[dispatch, toast]
|
||||
);
|
||||
|
||||
const { getRootProps, getInputProps, open } = useDropzone({
|
||||
onDrop,
|
||||
accept: {
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||
},
|
||||
});
|
||||
|
||||
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
|
||||
const handleClickUploadIcon = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
open();
|
||||
};
|
||||
const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
dispatch(setInitialImagePath(''));
|
||||
dispatch(setMaskPath(''));
|
||||
};
|
||||
|
||||
const handleMouseOverInitialImageUploadButton = () =>
|
||||
setShouldShowMask(false);
|
||||
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
|
||||
|
||||
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
|
||||
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
{...getRootProps({
|
||||
onClick: initialImagePath ? (e) => e.stopPropagation() : undefined,
|
||||
})}
|
||||
direction={'column'}
|
||||
alignItems={'center'}
|
||||
gap={2}
|
||||
>
|
||||
<input {...getInputProps({ multiple: false })} />
|
||||
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
|
||||
<Button
|
||||
size={'sm'}
|
||||
fontSize={'md'}
|
||||
fontWeight={'normal'}
|
||||
onClick={handleClickUploadIcon}
|
||||
onMouseOver={handleMouseOverInitialImageUploadButton}
|
||||
onMouseOut={handleMouseOutInitialImageUploadButton}
|
||||
>
|
||||
Upload Image
|
||||
</Button>
|
||||
|
||||
<MaskUploader>
|
||||
<Button
|
||||
size={'sm'}
|
||||
fontSize={'md'}
|
||||
fontWeight={'normal'}
|
||||
onClick={handleClickUploadIcon}
|
||||
onMouseOver={handleMouseOverMaskUploadButton}
|
||||
onMouseOut={handleMouseOutMaskUploadButton}
|
||||
>
|
||||
Upload Mask
|
||||
</Button>
|
||||
</MaskUploader>
|
||||
<IconButton
|
||||
size={'sm'}
|
||||
aria-label={'Reset initial image and mask'}
|
||||
onClick={handleClickResetInitialImageAndMask}
|
||||
icon={<FaTrash />}
|
||||
/>
|
||||
</Flex>
|
||||
{initialImagePath && (
|
||||
<Flex position={'relative'} width={'100%'}>
|
||||
<Image
|
||||
fit={'contain'}
|
||||
src={initialImagePath}
|
||||
rounded={'md'}
|
||||
className={'checkerboard'}
|
||||
/>
|
||||
{shouldShowMask && maskPath && (
|
||||
<Image
|
||||
position={'absolute'}
|
||||
top={0}
|
||||
left={0}
|
||||
fit={'contain'}
|
||||
src={maskPath}
|
||||
rounded={'md'}
|
||||
zIndex={1}
|
||||
className={'checkerboard'}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default InitImage;
|
@ -1,61 +0,0 @@
|
||||
import { useToast } from '@chakra-ui/react';
|
||||
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
|
||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||
import { useAppDispatch } from '../../app/hooks';
|
||||
import { uploadMaskImage } from '../../app/socketio';
|
||||
|
||||
type Props = {
|
||||
children: ReactElement;
|
||||
};
|
||||
|
||||
const MaskUploader = ({ children }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toast = useToast();
|
||||
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
||||
fileRejections.forEach((rejection: FileRejection) => {
|
||||
const msg = rejection.errors.reduce(
|
||||
(acc: string, cur: { message: string }) =>
|
||||
acc + '\n' + cur.message,
|
||||
''
|
||||
);
|
||||
|
||||
toast({
|
||||
title: 'Upload failed',
|
||||
description: msg,
|
||||
status: 'error',
|
||||
isClosable: true,
|
||||
});
|
||||
});
|
||||
|
||||
acceptedFiles.forEach((file: File) => {
|
||||
dispatch(uploadMaskImage(file));
|
||||
});
|
||||
},
|
||||
[dispatch, toast]
|
||||
);
|
||||
|
||||
const { getRootProps, getInputProps, open } = useDropzone({
|
||||
onDrop,
|
||||
accept: {
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||
},
|
||||
});
|
||||
|
||||
const handleClickUploadIcon = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
open();
|
||||
};
|
||||
|
||||
return (
|
||||
<div {...getRootProps()}>
|
||||
<input {...getInputProps({ multiple: false })} />
|
||||
{cloneElement(children, {
|
||||
onClick: handleClickUploadIcon,
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default MaskUploader;
|
@ -1,23 +1,24 @@
|
||||
import {
|
||||
Flex,
|
||||
Box,
|
||||
Text,
|
||||
Accordion,
|
||||
AccordionItem,
|
||||
AccordionButton,
|
||||
AccordionIcon,
|
||||
AccordionPanel,
|
||||
Switch,
|
||||
Flex,
|
||||
Box,
|
||||
Text,
|
||||
Accordion,
|
||||
AccordionItem,
|
||||
AccordionButton,
|
||||
AccordionIcon,
|
||||
AccordionPanel,
|
||||
Switch,
|
||||
ExpandedIndex,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import { RootState } from '../../app/store';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
|
||||
import {
|
||||
setShouldRunGFPGAN,
|
||||
setShouldRunESRGAN,
|
||||
SDState,
|
||||
setShouldUseInitImage,
|
||||
setShouldRunGFPGAN,
|
||||
setShouldRunESRGAN,
|
||||
SDState,
|
||||
setShouldUseInitImage,
|
||||
} from '../sd/sdSlice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
@ -28,184 +29,189 @@ import ESRGANOptions from './ESRGANOptions';
|
||||
import GFPGANOptions from './GFPGANOptions';
|
||||
import OutputOptions from './OutputOptions';
|
||||
import ImageToImageOptions from './ImageToImageOptions';
|
||||
import { ChangeEvent } from 'react';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
initialImagePath: sd.initialImagePath,
|
||||
shouldUseInitImage: sd.shouldUseInitImage,
|
||||
shouldRunESRGAN: sd.shouldRunESRGAN,
|
||||
shouldRunGFPGAN: sd.shouldRunGFPGAN,
|
||||
};
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
initialImagePath: sd.initialImagePath,
|
||||
shouldUseInitImage: sd.shouldUseInitImage,
|
||||
shouldRunESRGAN: sd.shouldRunESRGAN,
|
||||
shouldRunGFPGAN: sd.shouldRunGFPGAN,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||
isESRGANAvailable: system.isESRGANAvailable,
|
||||
openAccordions: system.openAccordions,
|
||||
};
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||
isESRGANAvailable: system.isESRGANAvailable,
|
||||
openAccordions: system.openAccordions,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Main container for generation and processing parameters.
|
||||
*/
|
||||
const OptionsAccordion = () => {
|
||||
const {
|
||||
shouldRunESRGAN,
|
||||
shouldRunGFPGAN,
|
||||
shouldUseInitImage,
|
||||
initialImagePath,
|
||||
} = useAppSelector(sdSelector);
|
||||
const {
|
||||
shouldRunESRGAN,
|
||||
shouldRunGFPGAN,
|
||||
shouldUseInitImage,
|
||||
initialImagePath,
|
||||
} = useAppSelector(sdSelector);
|
||||
|
||||
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
|
||||
useAppSelector(systemSelector);
|
||||
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
|
||||
useAppSelector(systemSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
return (
|
||||
<Accordion
|
||||
defaultIndex={openAccordions}
|
||||
allowMultiple
|
||||
reduceMotion
|
||||
onChange={(openAccordions) =>
|
||||
dispatch(setOpenAccordions(openAccordions))
|
||||
}
|
||||
>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Box flex='1' textAlign='left'>
|
||||
Seed & Variation
|
||||
</Box>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<SeedVariationOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Box flex='1' textAlign='left'>
|
||||
Sampler
|
||||
</Box>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<SamplerOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Flex
|
||||
justifyContent={'space-between'}
|
||||
alignItems={'center'}
|
||||
width={'100%'}
|
||||
mr={2}
|
||||
>
|
||||
<Text>Upscale (ESRGAN)</Text>
|
||||
<Switch
|
||||
isDisabled={!isESRGANAvailable}
|
||||
isChecked={shouldRunESRGAN}
|
||||
onChange={(e) =>
|
||||
dispatch(
|
||||
setShouldRunESRGAN(e.target.checked)
|
||||
)
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<ESRGANOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Flex
|
||||
justifyContent={'space-between'}
|
||||
alignItems={'center'}
|
||||
width={'100%'}
|
||||
mr={2}
|
||||
>
|
||||
<Text>Fix Faces (GFPGAN)</Text>
|
||||
<Switch
|
||||
isDisabled={!isGFPGANAvailable}
|
||||
isChecked={shouldRunGFPGAN}
|
||||
onChange={(e) =>
|
||||
dispatch(
|
||||
setShouldRunGFPGAN(e.target.checked)
|
||||
)
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<GFPGANOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Flex
|
||||
justifyContent={'space-between'}
|
||||
alignItems={'center'}
|
||||
width={'100%'}
|
||||
mr={2}
|
||||
>
|
||||
<Text>Image to Image</Text>
|
||||
<Switch
|
||||
isDisabled={!initialImagePath}
|
||||
isChecked={shouldUseInitImage}
|
||||
onChange={(e) =>
|
||||
dispatch(
|
||||
setShouldUseInitImage(e.target.checked)
|
||||
)
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<ImageToImageOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Box flex='1' textAlign='left'>
|
||||
Output
|
||||
</Box>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<OutputOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
</Accordion>
|
||||
);
|
||||
/**
|
||||
* Stores accordion state in redux so preferred UI setup is retained.
|
||||
*/
|
||||
const handleChangeAccordionState = (openAccordions: ExpandedIndex) =>
|
||||
dispatch(setOpenAccordions(openAccordions));
|
||||
|
||||
const handleChangeShouldRunESRGAN = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldRunESRGAN(e.target.checked));
|
||||
|
||||
const handleChangeShouldRunGFPGAN = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldRunGFPGAN(e.target.checked));
|
||||
|
||||
const handleChangeShouldUseInitImage = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldUseInitImage(e.target.checked));
|
||||
|
||||
return (
|
||||
<Accordion
|
||||
defaultIndex={openAccordions}
|
||||
allowMultiple
|
||||
reduceMotion
|
||||
onChange={handleChangeAccordionState}
|
||||
>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Box flex="1" textAlign="left">
|
||||
Seed & Variation
|
||||
</Box>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<SeedVariationOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Box flex="1" textAlign="left">
|
||||
Sampler
|
||||
</Box>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<SamplerOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Flex
|
||||
justifyContent={'space-between'}
|
||||
alignItems={'center'}
|
||||
width={'100%'}
|
||||
mr={2}
|
||||
>
|
||||
<Text>Upscale (ESRGAN)</Text>
|
||||
<Switch
|
||||
isDisabled={!isESRGANAvailable}
|
||||
isChecked={shouldRunESRGAN}
|
||||
onChange={handleChangeShouldRunESRGAN}
|
||||
/>
|
||||
</Flex>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<ESRGANOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Flex
|
||||
justifyContent={'space-between'}
|
||||
alignItems={'center'}
|
||||
width={'100%'}
|
||||
mr={2}
|
||||
>
|
||||
<Text>Fix Faces (GFPGAN)</Text>
|
||||
<Switch
|
||||
isDisabled={!isGFPGANAvailable}
|
||||
isChecked={shouldRunGFPGAN}
|
||||
onChange={handleChangeShouldRunGFPGAN}
|
||||
/>
|
||||
</Flex>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<GFPGANOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Flex
|
||||
justifyContent={'space-between'}
|
||||
alignItems={'center'}
|
||||
width={'100%'}
|
||||
mr={2}
|
||||
>
|
||||
<Text>Image to Image</Text>
|
||||
<Switch
|
||||
isDisabled={!initialImagePath}
|
||||
isChecked={shouldUseInitImage}
|
||||
onChange={handleChangeShouldUseInitImage}
|
||||
/>
|
||||
</Flex>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<ImageToImageOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
<AccordionItem>
|
||||
<h2>
|
||||
<AccordionButton>
|
||||
<Box flex="1" textAlign="left">
|
||||
Output
|
||||
</Box>
|
||||
<AccordionIcon />
|
||||
</AccordionButton>
|
||||
</h2>
|
||||
<AccordionPanel>
|
||||
<OutputOptions />
|
||||
</AccordionPanel>
|
||||
</AccordionItem>
|
||||
</Accordion>
|
||||
);
|
||||
};
|
||||
|
||||
export default OptionsAccordion;
|
||||
|
@ -1,66 +1,76 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
|
||||
import { RootState } from '../../app/store';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
|
||||
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice';
|
||||
|
||||
import SDSelect from '../../components/SDSelect';
|
||||
|
||||
import { HEIGHTS, WIDTHS } from '../../app/constants';
|
||||
import SDSwitch from '../../components/SDSwitch';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { ChangeEvent } from 'react';
|
||||
import SDSelect from '../../common/components/SDSelect';
|
||||
import SDSwitch from '../../common/components/SDSwitch';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
height: sd.height,
|
||||
width: sd.width,
|
||||
seamless: sd.seamless,
|
||||
};
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
height: sd.height,
|
||||
width: sd.width,
|
||||
seamless: sd.seamless,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Image output options. Includes width, height, seamless tiling.
|
||||
*/
|
||||
const OutputOptions = () => {
|
||||
const { height, width, seamless } = useAppSelector(sdSelector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { height, width, seamless } = useAppSelector(sdSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const handleChangeWidth = (e: ChangeEvent<HTMLSelectElement>) =>
|
||||
dispatch(setWidth(Number(e.target.value)));
|
||||
|
||||
return (
|
||||
<Flex gap={2} direction={'column'}>
|
||||
<Flex gap={2}>
|
||||
<SDSelect
|
||||
label='Width'
|
||||
value={width}
|
||||
flexGrow={1}
|
||||
onChange={(e) => dispatch(setWidth(Number(e.target.value)))}
|
||||
validValues={WIDTHS}
|
||||
/>
|
||||
<SDSelect
|
||||
label='Height'
|
||||
value={height}
|
||||
flexGrow={1}
|
||||
onChange={(e) =>
|
||||
dispatch(setHeight(Number(e.target.value)))
|
||||
}
|
||||
validValues={HEIGHTS}
|
||||
/>
|
||||
</Flex>
|
||||
<SDSwitch
|
||||
label='Seamless tiling'
|
||||
fontSize={'md'}
|
||||
isChecked={seamless}
|
||||
onChange={(e) => dispatch(setSeamless(e.target.checked))}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
const handleChangeHeight = (e: ChangeEvent<HTMLSelectElement>) =>
|
||||
dispatch(setHeight(Number(e.target.value)));
|
||||
|
||||
const handleChangeSeamless = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setSeamless(e.target.checked));
|
||||
|
||||
return (
|
||||
<Flex gap={2} direction={'column'}>
|
||||
<Flex gap={2}>
|
||||
<SDSelect
|
||||
label="Width"
|
||||
value={width}
|
||||
flexGrow={1}
|
||||
onChange={handleChangeWidth}
|
||||
validValues={WIDTHS}
|
||||
/>
|
||||
<SDSelect
|
||||
label="Height"
|
||||
value={height}
|
||||
flexGrow={1}
|
||||
onChange={handleChangeHeight}
|
||||
validValues={HEIGHTS}
|
||||
/>
|
||||
</Flex>
|
||||
<SDSwitch
|
||||
label="Seamless tiling"
|
||||
fontSize={'md'}
|
||||
isChecked={seamless}
|
||||
onChange={handleChangeSeamless}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default OutputOptions;
|
||||
|
@ -1,58 +1,68 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { cancelProcessing, generateImage } from '../../app/socketio';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
import { cancelProcessing, generateImage } from '../../app/socketio/actions';
|
||||
import { RootState } from '../../app/store';
|
||||
import SDButton from '../../components/SDButton';
|
||||
import SDButton from '../../common/components/SDButton';
|
||||
import useCheckParameters from '../../common/hooks/useCheckParameters';
|
||||
import { SystemState } from '../system/systemSlice';
|
||||
import useCheckParameters from '../system/useCheckParameters';
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isProcessing: system.isProcessing,
|
||||
isConnected: system.isConnected,
|
||||
};
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isProcessing: system.isProcessing,
|
||||
isConnected: system.isConnected,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Buttons to start and cancel image generation.
|
||||
*/
|
||||
const ProcessButtons = () => {
|
||||
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
||||
const isReady = useCheckParameters();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const handleClickGenerate = () => dispatch(generateImage());
|
||||
|
||||
const isReady = useCheckParameters();
|
||||
const handleClickCancel = () => dispatch(cancelProcessing());
|
||||
|
||||
return (
|
||||
<Flex gap={2} direction={'column'} alignItems={'space-between'} height={'100%'}>
|
||||
<SDButton
|
||||
label='Generate'
|
||||
type='submit'
|
||||
colorScheme='green'
|
||||
flexGrow={1}
|
||||
isDisabled={!isReady}
|
||||
fontSize={'md'}
|
||||
size={'md'}
|
||||
onClick={() => dispatch(generateImage())}
|
||||
/>
|
||||
<SDButton
|
||||
label='Cancel'
|
||||
colorScheme='red'
|
||||
flexGrow={1}
|
||||
fontSize={'md'}
|
||||
size={'md'}
|
||||
isDisabled={!isConnected || !isProcessing}
|
||||
onClick={() => dispatch(cancelProcessing())}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
return (
|
||||
<Flex
|
||||
gap={2}
|
||||
direction={'column'}
|
||||
alignItems={'space-between'}
|
||||
height={'100%'}
|
||||
>
|
||||
<SDButton
|
||||
label="Generate"
|
||||
type="submit"
|
||||
colorScheme="green"
|
||||
flexGrow={1}
|
||||
isDisabled={!isReady}
|
||||
fontSize={'md'}
|
||||
size={'md'}
|
||||
onClick={handleClickGenerate}
|
||||
/>
|
||||
<SDButton
|
||||
label="Cancel"
|
||||
colorScheme="red"
|
||||
flexGrow={1}
|
||||
fontSize={'md'}
|
||||
size={'md'}
|
||||
isDisabled={!isConnected || !isProcessing}
|
||||
onClick={handleClickCancel}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default ProcessButtons;
|
||||
|
@ -1,21 +1,40 @@
|
||||
import { Textarea } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import {
|
||||
ChangeEvent,
|
||||
KeyboardEvent,
|
||||
} from 'react';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
import { generateImage } from '../../app/socketio/actions';
|
||||
import { RootState } from '../../app/store';
|
||||
import { setPrompt } from '../sd/sdSlice';
|
||||
|
||||
/**
|
||||
* Prompt input text area.
|
||||
*/
|
||||
const PromptInput = () => {
|
||||
const { prompt } = useAppSelector((state: RootState) => state.sd);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) =>
|
||||
dispatch(setPrompt(e.target.value));
|
||||
|
||||
const handleKeyDown = (e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === 'Enter' && e.shiftKey === false) {
|
||||
e.preventDefault();
|
||||
dispatch(generateImage())
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Textarea
|
||||
id='prompt'
|
||||
name='prompt'
|
||||
resize='none'
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
resize="none"
|
||||
size={'lg'}
|
||||
height={'100%'}
|
||||
isInvalid={!prompt.length}
|
||||
onChange={(e) => dispatch(setPrompt(e.target.value))}
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
value={prompt}
|
||||
placeholder="I'm dreaming of..."
|
||||
/>
|
||||
|
@ -1,51 +0,0 @@
|
||||
import {
|
||||
Slider,
|
||||
SliderTrack,
|
||||
SliderFilledTrack,
|
||||
SliderThumb,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Text,
|
||||
Flex,
|
||||
SliderProps,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
interface Props extends SliderProps {
|
||||
label: string;
|
||||
value: number;
|
||||
fontSize?: number | string;
|
||||
}
|
||||
|
||||
const SDSlider = ({
|
||||
label,
|
||||
value,
|
||||
fontSize = 'sm',
|
||||
onChange,
|
||||
...rest
|
||||
}: Props) => {
|
||||
return (
|
||||
<FormControl>
|
||||
<Flex gap={2}>
|
||||
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
||||
<Text fontSize={fontSize} whiteSpace='nowrap'>
|
||||
{label}
|
||||
</Text>
|
||||
</FormLabel>
|
||||
<Slider
|
||||
aria-label={label}
|
||||
focusThumbOnChange={true}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
{...rest}
|
||||
>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
</SliderTrack>
|
||||
<SliderThumb />
|
||||
</Slider>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default SDSlider;
|
@ -1,62 +1,74 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
|
||||
import { RootState } from '../../app/store';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
|
||||
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice';
|
||||
|
||||
import SDNumberInput from '../../components/SDNumberInput';
|
||||
import SDSelect from '../../components/SDSelect';
|
||||
|
||||
import { SAMPLERS } from '../../app/constants';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { ChangeEvent } from 'react';
|
||||
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||
import SDSelect from '../../common/components/SDSelect';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
steps: sd.steps,
|
||||
cfgScale: sd.cfgScale,
|
||||
sampler: sd.sampler,
|
||||
};
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
steps: sd.steps,
|
||||
cfgScale: sd.cfgScale,
|
||||
sampler: sd.sampler,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Sampler options. Includes steps, CFG scale, sampler.
|
||||
*/
|
||||
const SamplerOptions = () => {
|
||||
const { steps, cfgScale, sampler } = useAppSelector(sdSelector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { steps, cfgScale, sampler } = useAppSelector(sdSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const handleChangeSteps = (v: string | number) =>
|
||||
dispatch(setSteps(Number(v)));
|
||||
|
||||
return (
|
||||
<Flex gap={2} direction={'column'}>
|
||||
<SDNumberInput
|
||||
label='Steps'
|
||||
min={1}
|
||||
step={1}
|
||||
precision={0}
|
||||
onChange={(v) => dispatch(setSteps(Number(v)))}
|
||||
value={steps}
|
||||
/>
|
||||
<SDNumberInput
|
||||
label='CFG scale'
|
||||
step={0.5}
|
||||
onChange={(v) => dispatch(setCfgScale(Number(v)))}
|
||||
value={cfgScale}
|
||||
/>
|
||||
<SDSelect
|
||||
label='Sampler'
|
||||
value={sampler}
|
||||
onChange={(e) => dispatch(setSampler(e.target.value))}
|
||||
validValues={SAMPLERS}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
const handleChangeCfgScale = (v: string | number) =>
|
||||
dispatch(setCfgScale(Number(v)));
|
||||
|
||||
const handleChangeSampler = (e: ChangeEvent<HTMLSelectElement>) =>
|
||||
dispatch(setSampler(e.target.value));
|
||||
|
||||
return (
|
||||
<Flex gap={2} direction={'column'}>
|
||||
<SDNumberInput
|
||||
label="Steps"
|
||||
min={1}
|
||||
step={1}
|
||||
precision={0}
|
||||
onChange={handleChangeSteps}
|
||||
value={steps}
|
||||
/>
|
||||
<SDNumberInput
|
||||
label="CFG scale"
|
||||
step={0.5}
|
||||
onChange={handleChangeCfgScale}
|
||||
value={cfgScale}
|
||||
/>
|
||||
<SDSelect
|
||||
label="Sampler"
|
||||
value={sampler}
|
||||
onChange={handleChangeSampler}
|
||||
validValues={SAMPLERS}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default SamplerOptions;
|
||||
|
@ -1,144 +1,159 @@
|
||||
import {
|
||||
Flex,
|
||||
Input,
|
||||
HStack,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Text,
|
||||
Button,
|
||||
Flex,
|
||||
Input,
|
||||
HStack,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Text,
|
||||
Button,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import SDNumberInput from '../../components/SDNumberInput';
|
||||
import SDSwitch from '../../components/SDSwitch';
|
||||
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||
import SDSwitch from '../../common/components/SDSwitch';
|
||||
import randomInt from '../../common/util/randomInt';
|
||||
import { validateSeedWeights } from '../../common/util/seedWeightPairs';
|
||||
import {
|
||||
randomizeSeed,
|
||||
SDState,
|
||||
setIterations,
|
||||
setSeed,
|
||||
setSeedWeights,
|
||||
setShouldGenerateVariations,
|
||||
setShouldRandomizeSeed,
|
||||
setVariantAmount,
|
||||
SDState,
|
||||
setIterations,
|
||||
setSeed,
|
||||
setSeedWeights,
|
||||
setShouldGenerateVariations,
|
||||
setShouldRandomizeSeed,
|
||||
setVariationAmount,
|
||||
} from './sdSlice';
|
||||
import { validateSeedWeights } from './util/seedWeightPairs';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
variantAmount: sd.variantAmount,
|
||||
seedWeights: sd.seedWeights,
|
||||
shouldGenerateVariations: sd.shouldGenerateVariations,
|
||||
shouldRandomizeSeed: sd.shouldRandomizeSeed,
|
||||
seed: sd.seed,
|
||||
iterations: sd.iterations,
|
||||
};
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
variationAmount: sd.variationAmount,
|
||||
seedWeights: sd.seedWeights,
|
||||
shouldGenerateVariations: sd.shouldGenerateVariations,
|
||||
shouldRandomizeSeed: sd.shouldRandomizeSeed,
|
||||
seed: sd.seed,
|
||||
iterations: sd.iterations,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Seed & variation options. Includes iteration, seed, seed randomization, variation options.
|
||||
*/
|
||||
const SeedVariationOptions = () => {
|
||||
const {
|
||||
shouldGenerateVariations,
|
||||
variantAmount,
|
||||
seedWeights,
|
||||
shouldRandomizeSeed,
|
||||
seed,
|
||||
iterations,
|
||||
} = useAppSelector(sdSelector);
|
||||
const {
|
||||
shouldGenerateVariations,
|
||||
variationAmount,
|
||||
seedWeights,
|
||||
shouldRandomizeSeed,
|
||||
seed,
|
||||
iterations,
|
||||
} = useAppSelector(sdSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
return (
|
||||
<Flex gap={2} direction={'column'}>
|
||||
<SDNumberInput
|
||||
label='Images to generate'
|
||||
step={1}
|
||||
min={1}
|
||||
precision={0}
|
||||
onChange={(v) => dispatch(setIterations(Number(v)))}
|
||||
value={iterations}
|
||||
/>
|
||||
<SDSwitch
|
||||
label='Randomize seed on generation'
|
||||
isChecked={shouldRandomizeSeed}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldRandomizeSeed(e.target.checked))
|
||||
}
|
||||
/>
|
||||
<Flex gap={2}>
|
||||
<SDNumberInput
|
||||
label='Seed'
|
||||
step={1}
|
||||
precision={0}
|
||||
flexGrow={1}
|
||||
min={NUMPY_RAND_MIN}
|
||||
max={NUMPY_RAND_MAX}
|
||||
isDisabled={shouldRandomizeSeed}
|
||||
isInvalid={seed < 0 && shouldGenerateVariations}
|
||||
onChange={(v) => dispatch(setSeed(Number(v)))}
|
||||
value={seed}
|
||||
/>
|
||||
<Button
|
||||
size={'sm'}
|
||||
isDisabled={shouldRandomizeSeed}
|
||||
onClick={() => dispatch(randomizeSeed())}
|
||||
>
|
||||
<Text pl={2} pr={2}>
|
||||
Shuffle
|
||||
</Text>
|
||||
</Button>
|
||||
</Flex>
|
||||
<SDSwitch
|
||||
label='Generate variations'
|
||||
isChecked={shouldGenerateVariations}
|
||||
width={'auto'}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldGenerateVariations(e.target.checked))
|
||||
}
|
||||
/>
|
||||
<SDNumberInput
|
||||
label='Variation amount'
|
||||
value={variantAmount}
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
isDisabled={!shouldGenerateVariations}
|
||||
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
|
||||
/>
|
||||
<FormControl
|
||||
isInvalid={
|
||||
shouldGenerateVariations &&
|
||||
!(validateSeedWeights(seedWeights) || seedWeights === '')
|
||||
}
|
||||
flexGrow={1}
|
||||
isDisabled={!shouldGenerateVariations}
|
||||
>
|
||||
<HStack>
|
||||
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
||||
<Text whiteSpace='nowrap'>
|
||||
Seed Weights
|
||||
</Text>
|
||||
</FormLabel>
|
||||
<Input
|
||||
size={'sm'}
|
||||
value={seedWeights}
|
||||
onChange={(e) =>
|
||||
dispatch(setSeedWeights(e.target.value))
|
||||
}
|
||||
/>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
const handleChangeIterations = (v: string | number) =>
|
||||
dispatch(setIterations(Number(v)));
|
||||
|
||||
const handleChangeShouldRandomizeSeed = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldRandomizeSeed(e.target.checked));
|
||||
|
||||
const handleChangeSeed = (v: string | number) => dispatch(setSeed(Number(v)));
|
||||
|
||||
const handleClickRandomizeSeed = () =>
|
||||
dispatch(setSeed(randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)));
|
||||
|
||||
const handleChangeShouldGenerateVariations = (
|
||||
e: ChangeEvent<HTMLInputElement>
|
||||
) => dispatch(setShouldGenerateVariations(e.target.checked));
|
||||
|
||||
const handleChangevariationAmount = (v: string | number) =>
|
||||
dispatch(setVariationAmount(Number(v)));
|
||||
|
||||
const handleChangeSeedWeights = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setSeedWeights(e.target.value));
|
||||
|
||||
return (
|
||||
<Flex gap={2} direction={'column'}>
|
||||
<SDNumberInput
|
||||
label="Images to generate"
|
||||
step={1}
|
||||
min={1}
|
||||
precision={0}
|
||||
onChange={handleChangeIterations}
|
||||
value={iterations}
|
||||
/>
|
||||
<SDSwitch
|
||||
label="Randomize seed on generation"
|
||||
isChecked={shouldRandomizeSeed}
|
||||
onChange={handleChangeShouldRandomizeSeed}
|
||||
/>
|
||||
<Flex gap={2}>
|
||||
<SDNumberInput
|
||||
label="Seed"
|
||||
step={1}
|
||||
precision={0}
|
||||
flexGrow={1}
|
||||
min={NUMPY_RAND_MIN}
|
||||
max={NUMPY_RAND_MAX}
|
||||
isDisabled={shouldRandomizeSeed}
|
||||
isInvalid={seed < 0 && shouldGenerateVariations}
|
||||
onChange={handleChangeSeed}
|
||||
value={seed}
|
||||
/>
|
||||
<Button
|
||||
size={'sm'}
|
||||
isDisabled={shouldRandomizeSeed}
|
||||
onClick={handleClickRandomizeSeed}
|
||||
>
|
||||
<Text pl={2} pr={2}>
|
||||
Shuffle
|
||||
</Text>
|
||||
</Button>
|
||||
</Flex>
|
||||
<SDSwitch
|
||||
label="Generate variations"
|
||||
isChecked={shouldGenerateVariations}
|
||||
width={'auto'}
|
||||
onChange={handleChangeShouldGenerateVariations}
|
||||
/>
|
||||
<SDNumberInput
|
||||
label="Variation amount"
|
||||
value={variationAmount}
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={handleChangevariationAmount}
|
||||
/>
|
||||
<FormControl
|
||||
isInvalid={
|
||||
shouldGenerateVariations &&
|
||||
!(validateSeedWeights(seedWeights) || seedWeights === '')
|
||||
}
|
||||
flexGrow={1}
|
||||
>
|
||||
<HStack>
|
||||
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
||||
<Text whiteSpace="nowrap">Seed Weights</Text>
|
||||
</FormLabel>
|
||||
<Input
|
||||
size={'sm'}
|
||||
value={seedWeights}
|
||||
onChange={handleChangeSeedWeights}
|
||||
/>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default SeedVariationOptions;
|
||||
|
@ -1,92 +0,0 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
HStack,
|
||||
Input,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { RootState } from '../../app/store';
|
||||
import SDNumberInput from '../../components/SDNumberInput';
|
||||
import SDSwitch from '../../components/SDSwitch';
|
||||
import {
|
||||
SDState,
|
||||
setSeedWeights,
|
||||
setShouldGenerateVariations,
|
||||
setVariantAmount,
|
||||
} from './sdSlice';
|
||||
import { validateSeedWeights } from './util/seedWeightPairs';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
variantAmount: sd.variantAmount,
|
||||
seedWeights: sd.seedWeights,
|
||||
shouldGenerateVariations: sd.shouldGenerateVariations,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const Variant = () => {
|
||||
const { shouldGenerateVariations, variantAmount, seedWeights } =
|
||||
useAppSelector(sdSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems={'center'} pl={1}>
|
||||
<SDSwitch
|
||||
label='Generate variations'
|
||||
isChecked={shouldGenerateVariations}
|
||||
width={'auto'}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldGenerateVariations(e.target.checked))
|
||||
}
|
||||
/>
|
||||
<SDNumberInput
|
||||
label='Amount'
|
||||
value={variantAmount}
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
width={240}
|
||||
isDisabled={!shouldGenerateVariations}
|
||||
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
|
||||
/>
|
||||
<FormControl
|
||||
isInvalid={
|
||||
shouldGenerateVariations &&
|
||||
!(validateSeedWeights(seedWeights) || seedWeights === '')
|
||||
}
|
||||
flexGrow={1}
|
||||
isDisabled={!shouldGenerateVariations}
|
||||
>
|
||||
<HStack>
|
||||
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
||||
<Text fontSize={'sm'} whiteSpace='nowrap'>
|
||||
Seed Weights
|
||||
</Text>
|
||||
</FormLabel>
|
||||
<Input
|
||||
size={'sm'}
|
||||
value={seedWeights}
|
||||
onChange={(e) =>
|
||||
dispatch(setSeedWeights(e.target.value))
|
||||
}
|
||||
/>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default Variant;
|
@ -1,24 +1,13 @@
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { SDMetadata } from '../gallery/gallerySlice';
|
||||
import randomInt from './util/randomInt';
|
||||
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
|
||||
|
||||
const calculateRealSteps = (
|
||||
steps: number,
|
||||
strength: number,
|
||||
hasInitImage: boolean
|
||||
): number => {
|
||||
return hasInitImage ? Math.floor(strength * steps) : steps;
|
||||
};
|
||||
|
||||
export type UpscalingLevel = 0 | 2 | 3 | 4;
|
||||
export type UpscalingLevel = 2 | 4;
|
||||
|
||||
export interface SDState {
|
||||
prompt: string;
|
||||
iterations: number;
|
||||
steps: number;
|
||||
realSteps: number;
|
||||
cfgScale: number;
|
||||
height: number;
|
||||
width: number;
|
||||
@ -34,7 +23,7 @@ export interface SDState {
|
||||
seamless: boolean;
|
||||
shouldFitToWidthHeight: boolean;
|
||||
shouldGenerateVariations: boolean;
|
||||
variantAmount: number;
|
||||
variationAmount: number;
|
||||
seedWeights: string;
|
||||
shouldRunESRGAN: boolean;
|
||||
shouldRunGFPGAN: boolean;
|
||||
@ -45,7 +34,6 @@ const initialSDState: SDState = {
|
||||
prompt: '',
|
||||
iterations: 1,
|
||||
steps: 50,
|
||||
realSteps: 50,
|
||||
cfgScale: 7.5,
|
||||
height: 512,
|
||||
width: 512,
|
||||
@ -58,7 +46,7 @@ const initialSDState: SDState = {
|
||||
maskPath: '',
|
||||
shouldFitToWidthHeight: true,
|
||||
shouldGenerateVariations: false,
|
||||
variantAmount: 0.1,
|
||||
variationAmount: 0.1,
|
||||
seedWeights: '',
|
||||
shouldRunESRGAN: false,
|
||||
upscalingLevel: 4,
|
||||
@ -81,14 +69,7 @@ export const sdSlice = createSlice({
|
||||
state.iterations = action.payload;
|
||||
},
|
||||
setSteps: (state, action: PayloadAction<number>) => {
|
||||
const { img2imgStrength, initialImagePath } = state;
|
||||
const steps = action.payload;
|
||||
state.steps = steps;
|
||||
state.realSteps = calculateRealSteps(
|
||||
steps,
|
||||
img2imgStrength,
|
||||
Boolean(initialImagePath)
|
||||
);
|
||||
state.steps = action.payload;
|
||||
},
|
||||
setCfgScale: (state, action: PayloadAction<number>) => {
|
||||
state.cfgScale = action.payload;
|
||||
@ -107,14 +88,7 @@ export const sdSlice = createSlice({
|
||||
state.shouldRandomizeSeed = false;
|
||||
},
|
||||
setImg2imgStrength: (state, action: PayloadAction<number>) => {
|
||||
const img2imgStrength = action.payload;
|
||||
const { steps, initialImagePath } = state;
|
||||
state.img2imgStrength = img2imgStrength;
|
||||
state.realSteps = calculateRealSteps(
|
||||
steps,
|
||||
img2imgStrength,
|
||||
Boolean(initialImagePath)
|
||||
);
|
||||
state.img2imgStrength = action.payload;
|
||||
},
|
||||
setGfpganStrength: (state, action: PayloadAction<number>) => {
|
||||
state.gfpganStrength = action.payload;
|
||||
@ -129,15 +103,9 @@ export const sdSlice = createSlice({
|
||||
state.shouldUseInitImage = action.payload;
|
||||
},
|
||||
setInitialImagePath: (state, action: PayloadAction<string>) => {
|
||||
const initialImagePath = action.payload;
|
||||
const { steps, img2imgStrength } = state;
|
||||
state.shouldUseInitImage = initialImagePath ? true : false;
|
||||
state.initialImagePath = initialImagePath;
|
||||
state.realSteps = calculateRealSteps(
|
||||
steps,
|
||||
img2imgStrength,
|
||||
Boolean(initialImagePath)
|
||||
);
|
||||
const newInitialImagePath = action.payload;
|
||||
state.shouldUseInitImage = newInitialImagePath ? true : false;
|
||||
state.initialImagePath = newInitialImagePath;
|
||||
},
|
||||
setMaskPath: (state, action: PayloadAction<string>) => {
|
||||
state.maskPath = action.payload;
|
||||
@ -151,13 +119,11 @@ export const sdSlice = createSlice({
|
||||
resetSeed: (state) => {
|
||||
state.seed = -1;
|
||||
},
|
||||
randomizeSeed: (state) => {
|
||||
state.seed = randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX);
|
||||
},
|
||||
setParameter: (
|
||||
state,
|
||||
action: PayloadAction<{ key: string; value: string | number | boolean }>
|
||||
) => {
|
||||
// TODO: This probably needs to be refactored.
|
||||
const { key, value } = action.payload;
|
||||
const temp = { ...state, [key]: value };
|
||||
if (key === 'seed') {
|
||||
@ -171,13 +137,14 @@ export const sdSlice = createSlice({
|
||||
setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldGenerateVariations = action.payload;
|
||||
},
|
||||
setVariantAmount: (state, action: PayloadAction<number>) => {
|
||||
state.variantAmount = action.payload;
|
||||
setVariationAmount: (state, action: PayloadAction<number>) => {
|
||||
state.variationAmount = action.payload;
|
||||
},
|
||||
setSeedWeights: (state, action: PayloadAction<string>) => {
|
||||
state.seedWeights = action.payload;
|
||||
},
|
||||
setAllParameters: (state, action: PayloadAction<SDMetadata>) => {
|
||||
// TODO: This probably needs to be refactored.
|
||||
const {
|
||||
prompt,
|
||||
steps,
|
||||
@ -267,13 +234,12 @@ export const {
|
||||
setInitialImagePath,
|
||||
setMaskPath,
|
||||
resetSeed,
|
||||
randomizeSeed,
|
||||
resetSDState,
|
||||
setShouldFitToWidthHeight,
|
||||
setParameter,
|
||||
setShouldGenerateVariations,
|
||||
setSeedWeights,
|
||||
setVariantAmount,
|
||||
setVariationAmount,
|
||||
setAllParameters,
|
||||
setShouldRunGFPGAN,
|
||||
setShouldRunESRGAN,
|
||||
|
@ -1,11 +1,11 @@
|
||||
import {
|
||||
IconButton,
|
||||
useColorModeValue,
|
||||
Flex,
|
||||
Text,
|
||||
Tooltip,
|
||||
IconButton,
|
||||
useColorModeValue,
|
||||
Flex,
|
||||
Text,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
import { RootState } from '../../app/store';
|
||||
import { setShouldShowLogViewer, SystemState } from './systemSlice';
|
||||
import { useLayoutEffect, useRef, useState } from 'react';
|
||||
@ -14,112 +14,138 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
|
||||
const logSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => system.log,
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: (a, b) => a.length === b.length,
|
||||
},
|
||||
}
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => system.log,
|
||||
{
|
||||
memoizeOptions: {
|
||||
// We don't need a deep equality check for this selector.
|
||||
resultEqualityCheck: (a, b) => a.length === b.length,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return { shouldShowLogViewer: system.shouldShowLogViewer };
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return { shouldShowLogViewer: system.shouldShowLogViewer };
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Basic log viewer, floats on bottom of page.
|
||||
*/
|
||||
const LogViewer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const bg = useColorModeValue('gray.50', 'gray.900');
|
||||
const borderColor = useColorModeValue('gray.500', 'gray.500');
|
||||
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
|
||||
const dispatch = useAppDispatch();
|
||||
const log = useAppSelector(logSelector);
|
||||
const { shouldShowLogViewer } = useAppSelector(systemSelector);
|
||||
|
||||
const log = useAppSelector(logSelector);
|
||||
const { shouldShowLogViewer } = useAppSelector(systemSelector);
|
||||
// Set colors based on dark/light mode
|
||||
const bg = useColorModeValue('gray.50', 'gray.900');
|
||||
const borderColor = useColorModeValue('gray.500', 'gray.500');
|
||||
const logTextColors = useColorModeValue(
|
||||
{
|
||||
info: undefined,
|
||||
warning: 'yellow.500',
|
||||
error: 'red.500',
|
||||
},
|
||||
{
|
||||
info: undefined,
|
||||
warning: 'yellow.300',
|
||||
error: 'red.300',
|
||||
}
|
||||
);
|
||||
|
||||
const viewerRef = useRef<HTMLDivElement>(null);
|
||||
// Rudimentary autoscroll
|
||||
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
|
||||
const viewerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (viewerRef.current !== null && shouldAutoscroll) {
|
||||
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
|
||||
}
|
||||
});
|
||||
/**
|
||||
* If autoscroll is on, scroll to the bottom when:
|
||||
* - log updates
|
||||
* - viewer is toggled
|
||||
*
|
||||
* Also scroll to the bottom whenever autoscroll is turned on.
|
||||
*/
|
||||
useLayoutEffect(() => {
|
||||
if (viewerRef.current !== null && shouldAutoscroll) {
|
||||
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
|
||||
}
|
||||
}, [shouldAutoscroll, log, shouldShowLogViewer]);
|
||||
|
||||
return (
|
||||
<>
|
||||
{shouldShowLogViewer && (
|
||||
<Flex
|
||||
position={'fixed'}
|
||||
left={0}
|
||||
bottom={0}
|
||||
height='200px'
|
||||
width='100vw'
|
||||
overflow='auto'
|
||||
direction='column'
|
||||
fontFamily='monospace'
|
||||
fontSize='sm'
|
||||
pl={12}
|
||||
pr={2}
|
||||
pb={2}
|
||||
borderTopWidth='4px'
|
||||
borderColor={borderColor}
|
||||
background={bg}
|
||||
ref={viewerRef}
|
||||
>
|
||||
{log.map((entry, i) => (
|
||||
<Flex gap={2} key={i}>
|
||||
<Text fontSize='sm' fontWeight={'semibold'}>
|
||||
{entry.timestamp}:
|
||||
</Text>
|
||||
<Text fontSize='sm' wordBreak={'break-all'}>
|
||||
{entry.message}
|
||||
</Text>
|
||||
</Flex>
|
||||
))}
|
||||
</Flex>
|
||||
)}
|
||||
{shouldShowLogViewer && (
|
||||
<Tooltip
|
||||
label={
|
||||
shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'
|
||||
}
|
||||
>
|
||||
<IconButton
|
||||
size='sm'
|
||||
position={'fixed'}
|
||||
left={2}
|
||||
bottom={12}
|
||||
aria-label='Toggle autoscroll'
|
||||
variant={'solid'}
|
||||
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
|
||||
icon={<FaAngleDoubleDown />}
|
||||
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
|
||||
<IconButton
|
||||
size='sm'
|
||||
position={'fixed'}
|
||||
left={2}
|
||||
bottom={2}
|
||||
variant={'solid'}
|
||||
aria-label='Toggle Log Viewer'
|
||||
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
|
||||
onClick={() =>
|
||||
dispatch(setShouldShowLogViewer(!shouldShowLogViewer))
|
||||
}
|
||||
/>
|
||||
</Tooltip>
|
||||
</>
|
||||
);
|
||||
const handleClickLogViewerToggle = () => {
|
||||
dispatch(setShouldShowLogViewer(!shouldShowLogViewer));
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{shouldShowLogViewer && (
|
||||
<Flex
|
||||
position={'fixed'}
|
||||
left={0}
|
||||
bottom={0}
|
||||
height="200px" // TODO: Make the log viewer resizeable.
|
||||
width="100vw"
|
||||
overflow="auto"
|
||||
direction="column"
|
||||
fontFamily="monospace"
|
||||
fontSize="sm"
|
||||
pl={12}
|
||||
pr={2}
|
||||
pb={2}
|
||||
borderTopWidth="4px"
|
||||
borderColor={borderColor}
|
||||
background={bg}
|
||||
ref={viewerRef}
|
||||
>
|
||||
{log.map((entry, i) => {
|
||||
const { timestamp, message, level } = entry;
|
||||
return (
|
||||
<Flex gap={2} key={i} textColor={logTextColors[level]}>
|
||||
<Text fontSize="sm" fontWeight={'semibold'}>
|
||||
{timestamp}:
|
||||
</Text>
|
||||
<Text fontSize="sm" wordBreak={'break-all'}>
|
||||
{message}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
})}
|
||||
</Flex>
|
||||
)}
|
||||
{shouldShowLogViewer && (
|
||||
<Tooltip label={shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
position={'fixed'}
|
||||
left={2}
|
||||
bottom={12}
|
||||
aria-label="Toggle autoscroll"
|
||||
variant={'solid'}
|
||||
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
|
||||
icon={<FaAngleDoubleDown />}
|
||||
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
position={'fixed'}
|
||||
left={2}
|
||||
bottom={2}
|
||||
variant={'solid'}
|
||||
aria-label="Toggle Log Viewer"
|
||||
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
|
||||
onClick={handleClickLogViewerToggle}
|
||||
/>
|
||||
</Tooltip>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default LogViewer;
|
||||
|
@ -1,170 +1,164 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Heading,
|
||||
HStack,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Switch,
|
||||
Text,
|
||||
useDisclosure,
|
||||
Button,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Heading,
|
||||
HStack,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Switch,
|
||||
Text,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||
import {
|
||||
setShouldConfirmOnDelete,
|
||||
setShouldDisplayInProgress,
|
||||
SystemState,
|
||||
setShouldConfirmOnDelete,
|
||||
setShouldDisplayInProgress,
|
||||
SystemState,
|
||||
} from './systemSlice';
|
||||
import { RootState } from '../../app/store';
|
||||
import SDButton from '../../components/SDButton';
|
||||
import { persistor } from '../../main';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { cloneElement, ReactElement } from 'react';
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
|
||||
return { shouldDisplayInProgress, shouldConfirmOnDelete };
|
||||
},
|
||||
{
|
||||
memoizeOptions: { resultEqualityCheck: isEqual },
|
||||
}
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
|
||||
return { shouldDisplayInProgress, shouldConfirmOnDelete };
|
||||
},
|
||||
{
|
||||
memoizeOptions: { resultEqualityCheck: isEqual },
|
||||
}
|
||||
);
|
||||
|
||||
type Props = {
|
||||
children: ReactElement;
|
||||
type SettingsModalProps = {
|
||||
/* The button to open the Settings Modal */
|
||||
children: ReactElement;
|
||||
};
|
||||
|
||||
const SettingsModal = ({ children }: Props) => {
|
||||
const {
|
||||
isOpen: isSettingsModalOpen,
|
||||
onOpen: onSettingsModalOpen,
|
||||
onClose: onSettingsModalClose,
|
||||
} = useDisclosure();
|
||||
/**
|
||||
* Modal for app settings. Also provides Reset functionality in which the
|
||||
* app's localstorage is wiped via redux-persist.
|
||||
*
|
||||
* Secondary post-reset modal is included here.
|
||||
*/
|
||||
const SettingsModal = ({ children }: SettingsModalProps) => {
|
||||
const {
|
||||
isOpen: isSettingsModalOpen,
|
||||
onOpen: onSettingsModalOpen,
|
||||
onClose: onSettingsModalClose,
|
||||
} = useDisclosure();
|
||||
|
||||
const {
|
||||
isOpen: isRefreshModalOpen,
|
||||
onOpen: onRefreshModalOpen,
|
||||
onClose: onRefreshModalClose,
|
||||
} = useDisclosure();
|
||||
const {
|
||||
isOpen: isRefreshModalOpen,
|
||||
onOpen: onRefreshModalOpen,
|
||||
onClose: onRefreshModalClose,
|
||||
} = useDisclosure();
|
||||
|
||||
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
|
||||
useAppSelector(systemSelector);
|
||||
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
|
||||
useAppSelector(systemSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleClickResetWebUI = () => {
|
||||
persistor.purge().then(() => {
|
||||
onSettingsModalClose();
|
||||
onRefreshModalOpen();
|
||||
});
|
||||
};
|
||||
/**
|
||||
* Resets localstorage, then opens a secondary modal informing user to
|
||||
* refresh their browser.
|
||||
* */
|
||||
const handleClickResetWebUI = () => {
|
||||
persistor.purge().then(() => {
|
||||
onSettingsModalClose();
|
||||
onRefreshModalOpen();
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{cloneElement(children, {
|
||||
onClick: onSettingsModalOpen,
|
||||
})}
|
||||
return (
|
||||
<>
|
||||
{cloneElement(children, {
|
||||
onClick: onSettingsModalOpen,
|
||||
})}
|
||||
|
||||
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>Settings</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Flex gap={5} direction='column'>
|
||||
<FormControl>
|
||||
<HStack>
|
||||
<FormLabel marginBottom={1}>
|
||||
Display in-progress images (slower)
|
||||
</FormLabel>
|
||||
<Switch
|
||||
isChecked={shouldDisplayInProgress}
|
||||
onChange={(e) =>
|
||||
dispatch(
|
||||
setShouldDisplayInProgress(
|
||||
e.target.checked
|
||||
)
|
||||
)
|
||||
}
|
||||
/>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<HStack>
|
||||
<FormLabel marginBottom={1}>
|
||||
Confirm on delete
|
||||
</FormLabel>
|
||||
<Switch
|
||||
isChecked={shouldConfirmOnDelete}
|
||||
onChange={(e) =>
|
||||
dispatch(
|
||||
setShouldConfirmOnDelete(
|
||||
e.target.checked
|
||||
)
|
||||
)
|
||||
}
|
||||
/>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>Settings</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Flex gap={5} direction="column">
|
||||
<FormControl>
|
||||
<HStack>
|
||||
<FormLabel marginBottom={1}>
|
||||
Display in-progress images (slower)
|
||||
</FormLabel>
|
||||
<Switch
|
||||
isChecked={shouldDisplayInProgress}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldDisplayInProgress(e.target.checked))
|
||||
}
|
||||
/>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<HStack>
|
||||
<FormLabel marginBottom={1}>Confirm on delete</FormLabel>
|
||||
<Switch
|
||||
isChecked={shouldConfirmOnDelete}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldConfirmOnDelete(e.target.checked))
|
||||
}
|
||||
/>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
|
||||
<Heading size={'md'}>Reset Web UI</Heading>
|
||||
<Text>
|
||||
Resetting the web UI only resets the browser's
|
||||
local cache of your images and remembered
|
||||
settings. It does not delete any images from
|
||||
disk.
|
||||
</Text>
|
||||
<Text>
|
||||
If images aren't showing up in the gallery or
|
||||
something else isn't working, please try
|
||||
resetting before submitting an issue on GitHub.
|
||||
</Text>
|
||||
<SDButton
|
||||
label='Reset Web UI'
|
||||
colorScheme='red'
|
||||
onClick={handleClickResetWebUI}
|
||||
/>
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
<Heading size={'md'}>Reset Web UI</Heading>
|
||||
<Text>
|
||||
Resetting the web UI only resets the browser's local cache of
|
||||
your images and remembered settings. It does not delete any
|
||||
images from disk.
|
||||
</Text>
|
||||
<Text>
|
||||
If images aren't showing up in the gallery or something else
|
||||
isn't working, please try resetting before submitting an issue
|
||||
on GitHub.
|
||||
</Text>
|
||||
<Button colorScheme="red" onClick={handleClickResetWebUI}>
|
||||
Reset Web UI
|
||||
</Button>
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<SDButton
|
||||
label='Close'
|
||||
onClick={onSettingsModalClose}
|
||||
/>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
<ModalFooter>
|
||||
<Button onClick={onSettingsModalClose}>Close</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
|
||||
<Modal
|
||||
closeOnOverlayClick={false}
|
||||
isOpen={isRefreshModalOpen}
|
||||
onClose={onRefreshModalClose}
|
||||
isCentered
|
||||
>
|
||||
<ModalOverlay bg='blackAlpha.300' backdropFilter='blur(40px)' />
|
||||
<ModalContent>
|
||||
<ModalBody pb={6} pt={6}>
|
||||
<Flex justifyContent={'center'}>
|
||||
<Text fontSize={'lg'}>
|
||||
Web UI has been reset. Refresh the page to
|
||||
reload.
|
||||
</Text>
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
<Modal
|
||||
closeOnOverlayClick={false}
|
||||
isOpen={isRefreshModalOpen}
|
||||
onClose={onRefreshModalClose}
|
||||
isCentered
|
||||
>
|
||||
<ModalOverlay bg="blackAlpha.300" backdropFilter="blur(40px)" />
|
||||
<ModalContent>
|
||||
<ModalBody pb={6} pt={6}>
|
||||
<Flex justifyContent={'center'}>
|
||||
<Text fontSize={'lg'}>
|
||||
Web UI has been reset. Refresh the page to reload.
|
||||
</Text>
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default SettingsModal;
|
||||
|
@ -1,10 +1,12 @@
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import dateFormat from 'dateformat';
|
||||
import { ExpandedIndex } from '@chakra-ui/react';
|
||||
|
||||
export type LogLevel = 'info' | 'warning' | 'error';
|
||||
|
||||
export interface LogEntry {
|
||||
timestamp: string;
|
||||
level: LogLevel;
|
||||
message: string;
|
||||
}
|
||||
|
||||
@ -12,10 +14,18 @@ export interface Log {
|
||||
[index: number]: LogEntry;
|
||||
}
|
||||
|
||||
export interface SystemState {
|
||||
shouldDisplayInProgress: boolean;
|
||||
export interface SystemStatus {
|
||||
isProcessing: boolean;
|
||||
currentStep: number;
|
||||
totalSteps: number;
|
||||
currentIteration: number;
|
||||
totalIterations: number;
|
||||
currentStatus: string;
|
||||
currentStatusHasSteps: boolean;
|
||||
}
|
||||
|
||||
export interface SystemState extends SystemStatus {
|
||||
shouldDisplayInProgress: boolean;
|
||||
log: Array<LogEntry>;
|
||||
shouldShowLogViewer: boolean;
|
||||
isGFPGANAvailable: boolean;
|
||||
@ -24,12 +34,17 @@ export interface SystemState {
|
||||
socketId: string;
|
||||
shouldConfirmOnDelete: boolean;
|
||||
openAccordions: ExpandedIndex;
|
||||
currentStep: number;
|
||||
totalSteps: number;
|
||||
currentIteration: number;
|
||||
totalIterations: number;
|
||||
currentStatus: string;
|
||||
currentStatusHasSteps: boolean;
|
||||
}
|
||||
|
||||
const initialSystemState = {
|
||||
isConnected: false,
|
||||
isProcessing: false,
|
||||
currentStep: 0,
|
||||
log: [],
|
||||
shouldShowLogViewer: false,
|
||||
shouldDisplayInProgress: false,
|
||||
@ -38,6 +53,12 @@ const initialSystemState = {
|
||||
socketId: '',
|
||||
shouldConfirmOnDelete: true,
|
||||
openAccordions: [0],
|
||||
currentStep: 0,
|
||||
totalSteps: 0,
|
||||
currentIteration: 0,
|
||||
totalIterations: 0,
|
||||
currentStatus: '',
|
||||
currentStatusHasSteps: false,
|
||||
};
|
||||
|
||||
const initialState: SystemState = initialSystemState;
|
||||
@ -51,18 +72,35 @@ export const systemSlice = createSlice({
|
||||
},
|
||||
setIsProcessing: (state, action: PayloadAction<boolean>) => {
|
||||
state.isProcessing = action.payload;
|
||||
if (action.payload === false) {
|
||||
state.currentStep = 0;
|
||||
}
|
||||
},
|
||||
setCurrentStep: (state, action: PayloadAction<number>) => {
|
||||
state.currentStep = action.payload;
|
||||
setCurrentStatus: (state, action: PayloadAction<string>) => {
|
||||
state.currentStatus = action.payload;
|
||||
},
|
||||
addLogEntry: (state, action: PayloadAction<string>) => {
|
||||
setSystemStatus: (state, action: PayloadAction<SystemStatus>) => {
|
||||
const currentStatus =
|
||||
!action.payload.isProcessing && state.isConnected
|
||||
? 'Connected'
|
||||
: action.payload.currentStatus;
|
||||
|
||||
return { ...state, ...action.payload, currentStatus };
|
||||
},
|
||||
addLogEntry: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
timestamp: string;
|
||||
message: string;
|
||||
level?: LogLevel;
|
||||
}>
|
||||
) => {
|
||||
const { timestamp, message, level } = action.payload;
|
||||
const logLevel = level || 'info';
|
||||
|
||||
const entry: LogEntry = {
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: action.payload,
|
||||
timestamp,
|
||||
message,
|
||||
level: logLevel,
|
||||
};
|
||||
|
||||
state.log.push(entry);
|
||||
},
|
||||
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
|
||||
@ -86,13 +124,14 @@ export const systemSlice = createSlice({
|
||||
export const {
|
||||
setShouldDisplayInProgress,
|
||||
setIsProcessing,
|
||||
setCurrentStep,
|
||||
addLogEntry,
|
||||
setShouldShowLogViewer,
|
||||
setIsConnected,
|
||||
setSocketId,
|
||||
setShouldConfirmOnDelete,
|
||||
setOpenAccordions,
|
||||
setSystemStatus,
|
||||
setCurrentStatus,
|
||||
} = systemSlice.actions;
|
||||
|
||||
export default systemSlice.reducer;
|
||||
|
@ -1,108 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash';
|
||||
import { useMemo } from 'react';
|
||||
import { useAppSelector } from '../../app/hooks';
|
||||
import { RootState } from '../../app/store';
|
||||
import { SDState } from '../sd/sdSlice';
|
||||
import { validateSeedWeights } from '../sd/util/seedWeightPairs';
|
||||
import { SystemState } from './systemSlice';
|
||||
|
||||
const sdSelector = createSelector(
|
||||
(state: RootState) => state.sd,
|
||||
(sd: SDState) => {
|
||||
return {
|
||||
prompt: sd.prompt,
|
||||
shouldGenerateVariations: sd.shouldGenerateVariations,
|
||||
seedWeights: sd.seedWeights,
|
||||
maskPath: sd.maskPath,
|
||||
initialImagePath: sd.initialImagePath,
|
||||
seed: sd.seed,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const systemSelector = createSelector(
|
||||
(state: RootState) => state.system,
|
||||
(system: SystemState) => {
|
||||
return {
|
||||
isProcessing: system.isProcessing,
|
||||
isConnected: system.isConnected,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
/*
|
||||
Checks relevant pieces of state to confirm generation will not deterministically fail.
|
||||
|
||||
This is used to prevent the 'Generate' button from being clicked.
|
||||
|
||||
Other parameter values may cause failure but we rely on input validation for those.
|
||||
*/
|
||||
const useCheckParameters = () => {
|
||||
const {
|
||||
prompt,
|
||||
shouldGenerateVariations,
|
||||
seedWeights,
|
||||
maskPath,
|
||||
initialImagePath,
|
||||
seed,
|
||||
} = useAppSelector(sdSelector);
|
||||
|
||||
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
||||
|
||||
return useMemo(() => {
|
||||
// Cannot generate without a prompt
|
||||
if (!prompt) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Cannot generate with a mask without img2img
|
||||
if (maskPath && !initialImagePath) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: job queue
|
||||
// Cannot generate if already processing an image
|
||||
if (isProcessing) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Cannot generate if not connected
|
||||
if (!isConnected) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Cannot generate variations without valid seed weights
|
||||
if (
|
||||
shouldGenerateVariations &&
|
||||
(!(validateSeedWeights(seedWeights) || seedWeights === '') ||
|
||||
seed === -1)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// All good
|
||||
return true;
|
||||
}, [
|
||||
prompt,
|
||||
maskPath,
|
||||
initialImagePath,
|
||||
isProcessing,
|
||||
isConnected,
|
||||
shouldGenerateVariations,
|
||||
seedWeights,
|
||||
seed,
|
||||
]);
|
||||
};
|
||||
|
||||
export default useCheckParameters;
|
@ -8,9 +8,9 @@ import { persistStore } from 'redux-persist';
|
||||
|
||||
export const persistor = persistStore(store);
|
||||
|
||||
import App from './App';
|
||||
import { theme } from './app/theme';
|
||||
import Loading from './Loading';
|
||||
import App from './app/App';
|
||||
|
||||
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
|
||||
<React.StrictMode>
|
||||
|
@ -2,7 +2,10 @@
|
||||
|
||||
The Args class parses both the command line (shell) arguments, as well as the
|
||||
command string passed at the dream> prompt. It serves as the definitive repository
|
||||
of all the arguments used by Generate and their default values.
|
||||
of all the arguments used by Generate and their default values, and implements the
|
||||
preliminary metadata standards discussed here:
|
||||
|
||||
https://github.com/lstein/stable-diffusion/issues/266
|
||||
|
||||
To use:
|
||||
opt = Args()
|
||||
@ -52,15 +55,38 @@ you wish to apply logic as to which one to use. For example:
|
||||
To add new attributes, edit the _create_arg_parser() and
|
||||
_create_dream_cmd_parser() methods.
|
||||
|
||||
We also export the function build_metadata
|
||||
**Generating and retrieving sd-metadata**
|
||||
|
||||
To generate a dict representing RFC266 metadata:
|
||||
|
||||
metadata = metadata_dumps(opt,<seeds,model_hash,postprocesser>)
|
||||
|
||||
This will generate an RFC266 dictionary that can then be turned into a JSON
|
||||
and written to the PNG file. The optional seeds, weights, model_hash and
|
||||
postprocesser arguments are not available to the opt object and so must be
|
||||
provided externally. See how dream.py does it.
|
||||
|
||||
Note that this function was originally called format_metadata() and a wrapper
|
||||
is provided that issues a deprecation notice.
|
||||
|
||||
To retrieve a (series of) opt objects corresponding to the metadata, do this:
|
||||
|
||||
opt_list = metadata_loads(metadata)
|
||||
|
||||
The metadata should be pulled out of the PNG image. pngwriter has a method
|
||||
retrieve_metadata that will do this.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
import shlex
|
||||
import json
|
||||
import hashlib
|
||||
import os
|
||||
import copy
|
||||
import base64
|
||||
from ldm.dream.conditioning import split_weighted_subprompts
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
@ -105,6 +131,7 @@ class Args(object):
|
||||
try:
|
||||
elements = shlex.split(command)
|
||||
except ValueError:
|
||||
import sys, traceback
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return
|
||||
switches = ['']
|
||||
@ -141,24 +168,26 @@ class Args(object):
|
||||
a = vars(self)
|
||||
a.update(kwargs)
|
||||
switches = list()
|
||||
switches.append(f'"{a["prompt"]}')
|
||||
switches.append(f'"{a["prompt"]}"')
|
||||
switches.append(f'-s {a["steps"]}')
|
||||
switches.append(f'-S {a["seed"]}')
|
||||
switches.append(f'-W {a["width"]}')
|
||||
switches.append(f'-H {a["height"]}')
|
||||
switches.append(f'-C {a["cfg_scale"]}')
|
||||
switches.append(f'-A {a["sampler_name"]}')
|
||||
switches.append(f'-S {a["seed"]}')
|
||||
if a['grid']:
|
||||
switches.append('--grid')
|
||||
if a['iterations'] and a['iterations']>0:
|
||||
switches.append(f'-n {a["iterations"]}')
|
||||
if a['seamless']:
|
||||
switches.append('--seamless')
|
||||
if a['init_img'] and len(a['init_img'])>0:
|
||||
switches.append(f'-I {a["init_img"]}')
|
||||
if a['init_mask'] and len(a['init_mask'])>0:
|
||||
switches.append(f'-M {a["init_mask"]}')
|
||||
if a['init_color'] and len(a['init_color'])>0:
|
||||
switches.append(f'--init_color {a["init_color"]}')
|
||||
if a['fit']:
|
||||
switches.append(f'--fit')
|
||||
if a['strength'] and a['strength']>0:
|
||||
if a['init_img'] and a['strength'] and a['strength']>0:
|
||||
switches.append(f'-f {a["strength"]}')
|
||||
if a['gfpgan_strength']:
|
||||
switches.append(f'-G {a["gfpgan_strength"]}')
|
||||
@ -189,10 +218,10 @@ class Args(object):
|
||||
pass
|
||||
|
||||
if cmd_switches and arg_switches and name=='__dict__':
|
||||
a = arg_switches.__dict__
|
||||
a.update(cmd_switches.__dict__)
|
||||
return a
|
||||
|
||||
return self._merge_dict(
|
||||
arg_switches.__dict__,
|
||||
cmd_switches.__dict__,
|
||||
)
|
||||
try:
|
||||
return object.__getattribute__(self,name)
|
||||
except AttributeError:
|
||||
@ -216,13 +245,8 @@ class Args(object):
|
||||
# the arg value. For example, the --grid and --individual options are a little
|
||||
# funny because of their push/pull relationship. This is how to handle it.
|
||||
if name=='grid':
|
||||
return value_arg or value_cmd # arg supersedes cmd
|
||||
if name=='individual':
|
||||
return value_cmd or value_arg # cmd supersedes arg
|
||||
if value_cmd is not None:
|
||||
return value_cmd
|
||||
else:
|
||||
return value_arg
|
||||
return not cmd_switches.individual and value_arg # arg supersedes cmd
|
||||
return value_cmd if value_cmd is not None else value_arg
|
||||
|
||||
def __setattr__(self,name,value):
|
||||
if name.startswith('_'):
|
||||
@ -230,6 +254,14 @@ class Args(object):
|
||||
else:
|
||||
self._cmd_switches.__dict__[name] = value
|
||||
|
||||
def _merge_dict(self,dict1,dict2):
|
||||
new_dict = {}
|
||||
for k in set(list(dict1.keys())+list(dict2.keys())):
|
||||
value1 = dict1.get(k,None)
|
||||
value2 = dict2.get(k,None)
|
||||
new_dict[k] = value2 if value2 is not None else value1
|
||||
return new_dict
|
||||
|
||||
def _create_arg_parser(self):
|
||||
'''
|
||||
This defines all the arguments used on the command line when you launch
|
||||
@ -268,6 +300,17 @@ class Args(object):
|
||||
default='stable-diffusion-1.4',
|
||||
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--sampler',
|
||||
'-A',
|
||||
'-m',
|
||||
dest='sampler_name',
|
||||
type=str,
|
||||
choices=SAMPLER_CHOICES,
|
||||
metavar='SAMPLER_NAME',
|
||||
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||
default='k_lms',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'-F',
|
||||
'--full_precision',
|
||||
@ -294,11 +337,6 @@ class Args(object):
|
||||
action='store_true',
|
||||
help='Place images in subdirectories named after the prompt.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--seamless',
|
||||
action='store_true',
|
||||
help='Change the model to seamless tiling (circular) mode',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--grid',
|
||||
'-g',
|
||||
@ -393,14 +431,12 @@ class Args(object):
|
||||
'--width',
|
||||
type=int,
|
||||
help='Image width, multiple of 64',
|
||||
default=512
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-H',
|
||||
'--height',
|
||||
type=int,
|
||||
help='Image height, multiple of 64',
|
||||
default=512,
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-C',
|
||||
@ -416,8 +452,8 @@ class Args(object):
|
||||
help='generate a grid'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--individual',
|
||||
'-i',
|
||||
'--individual',
|
||||
action='store_true',
|
||||
help='override command-line --grid setting and generate individual images'
|
||||
)
|
||||
@ -436,7 +472,6 @@ class Args(object):
|
||||
choices=SAMPLER_CHOICES,
|
||||
metavar='SAMPLER_NAME',
|
||||
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||
default='k_lms',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-t',
|
||||
@ -448,7 +483,6 @@ class Args(object):
|
||||
'--outdir',
|
||||
'-o',
|
||||
type=str,
|
||||
default='outputs/img-samples',
|
||||
help='Directory to save generated images and a log of prompts and seeds',
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
@ -463,6 +497,11 @@ class Args(object):
|
||||
type=str,
|
||||
help='Path to input mask for inpainting mode (supersedes width and height)',
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'--init_color',
|
||||
type=str,
|
||||
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-T',
|
||||
'-fit',
|
||||
@ -477,6 +516,12 @@ class Args(object):
|
||||
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||
default=0.75,
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-ft',
|
||||
'--facetool',
|
||||
type=str,
|
||||
help='Select the face restoration AI to use: gfpgan, codeformer',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-G',
|
||||
'--gfpgan_strength',
|
||||
@ -484,6 +529,13 @@ class Args(object):
|
||||
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
|
||||
default=0,
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-cf',
|
||||
'--codeformer_fidelity',
|
||||
type=float,
|
||||
help='Takes values between 0 and 1. 0 produces high quality but low accuracy. 1 produces high accuracy but low quality.',
|
||||
default=0.75
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-U',
|
||||
'--upscale',
|
||||
@ -535,18 +587,31 @@ class Args(object):
|
||||
)
|
||||
return parser
|
||||
|
||||
# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
|
||||
# it does not write all the required top-level metadata, writes too much image
|
||||
# data, and doesn't support grids yet. But you gotta start somewhere, no?
|
||||
def format_metadata(opt,
|
||||
seeds=[],
|
||||
weights=None,
|
||||
model_hash=None,
|
||||
postprocessing=None):
|
||||
def format_metadata(**kwargs):
|
||||
print(f'format_metadata() is deprecated. Please use metadata_dumps()')
|
||||
return metadata_dumps(kwargs)
|
||||
|
||||
def metadata_dumps(opt,
|
||||
seeds=[],
|
||||
model_hash=None,
|
||||
postprocessing=None):
|
||||
'''
|
||||
Given an Args object, returns a partial implementation of
|
||||
the stable diffusion metadata standard
|
||||
Given an Args object, returns a dict containing the keys and
|
||||
structure of the proposed stable diffusion metadata standard
|
||||
https://github.com/lstein/stable-diffusion/discussions/392
|
||||
This is intended to be turned into JSON and stored in the
|
||||
"sd
|
||||
'''
|
||||
|
||||
# top-level metadata minus `image` or `images`
|
||||
metadata = {
|
||||
'model' : 'stable diffusion',
|
||||
'model_id' : opt.model,
|
||||
'model_hash' : model_hash,
|
||||
'app_id' : APP_ID,
|
||||
'app_version' : APP_VERSION,
|
||||
}
|
||||
|
||||
# add some RFC266 fields that are generated internally, and not as
|
||||
# user args
|
||||
image_dict = opt.to_dict(
|
||||
@ -587,24 +652,67 @@ def format_metadata(opt,
|
||||
if opt.init_img:
|
||||
rfc_dict['type'] = 'img2img'
|
||||
rfc_dict['strength_steps'] = rfc_dict.pop('strength')
|
||||
rfc_dict['orig_hash'] = sha256(image_dict['init_img'])
|
||||
rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
|
||||
rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
|
||||
else:
|
||||
rfc_dict['type'] = 'txt2img'
|
||||
|
||||
images = []
|
||||
for seed in seeds:
|
||||
rfc_dict['seed'] = seed
|
||||
images.append(copy.copy(rfc_dict))
|
||||
if len(seeds)==0 and opt.seed:
|
||||
seeds=[seed]
|
||||
|
||||
return {
|
||||
'model' : 'stable diffusion',
|
||||
'model_id' : opt.model,
|
||||
'model_hash' : model_hash,
|
||||
'app_id' : APP_ID,
|
||||
'app_version' : APP_VERSION,
|
||||
'images' : images,
|
||||
}
|
||||
if opt.grid:
|
||||
images = []
|
||||
for seed in seeds:
|
||||
rfc_dict['seed'] = seed
|
||||
images.append(copy.copy(rfc_dict))
|
||||
metadata['images'] = images
|
||||
else:
|
||||
# there should only ever be a single seed if we did not generate a grid
|
||||
assert len(seeds) == 1, 'Expected a single seed'
|
||||
rfc_dict['seed'] = seeds[0]
|
||||
metadata['image'] = rfc_dict
|
||||
|
||||
return metadata
|
||||
|
||||
def metadata_loads(metadata):
|
||||
'''
|
||||
Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
|
||||
and returns a series of opt objects for each of the images described in the dictionary.
|
||||
'''
|
||||
results = []
|
||||
try:
|
||||
images = metadata['sd-metadata']['images']
|
||||
for image in images:
|
||||
# repack the prompt and variations
|
||||
image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
|
||||
image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
|
||||
# fix a bit of semantic drift here
|
||||
image['sampler_name']=image.pop('sampler')
|
||||
opt = Args()
|
||||
opt._cmd_switches = Namespace(**image)
|
||||
results.append(opt)
|
||||
except KeyError as e:
|
||||
import sys, traceback
|
||||
print('>> badly-formatted metadata',file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return results
|
||||
|
||||
# image can either be a file path on disk or a base64-encoded
|
||||
# representation of the file's contents
|
||||
def calculate_init_img_hash(image_string):
|
||||
prefix = 'data:image/png;base64,'
|
||||
hash = None
|
||||
if image_string.startswith(prefix):
|
||||
imagebase64 = image_string[len(prefix):]
|
||||
imagedata = base64.b64decode(imagebase64)
|
||||
with open('outputs/test.png','wb') as file:
|
||||
file.write(imagedata)
|
||||
sha = hashlib.sha256()
|
||||
sha.update(imagedata)
|
||||
hash = sha.hexdigest()
|
||||
else:
|
||||
hash = sha256(image_string)
|
||||
return hash
|
||||
|
||||
# Bah. This should be moved somewhere else...
|
||||
def sha256(path):
|
||||
|
@ -13,7 +13,20 @@ import re
|
||||
import torch
|
||||
|
||||
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
|
||||
uc = model.get_learned_conditioning([''])
|
||||
# Extract Unconditioned Words From Prompt
|
||||
unconditioned_words = ''
|
||||
unconditional_regex = r'\[(.*?)\]'
|
||||
unconditionals = re.findall(unconditional_regex, prompt)
|
||||
|
||||
if len(unconditionals) > 0:
|
||||
unconditioned_words = ' '.join(unconditionals)
|
||||
|
||||
# Remove Unconditioned Words From Prompt
|
||||
unconditional_regex_compile = re.compile(unconditional_regex)
|
||||
clean_prompt = unconditional_regex_compile.sub(' ', prompt)
|
||||
prompt = re.sub(' +', ' ', clean_prompt)
|
||||
|
||||
uc = model.get_learned_conditioning([unconditioned_words])
|
||||
|
||||
# get weighted sub-prompts
|
||||
weighted_subprompts = split_weighted_subprompts(
|
||||
@ -25,15 +38,16 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
|
||||
c = torch.zeros_like(uc)
|
||||
# normalize each "sub prompt" and add it
|
||||
for subprompt, weight in weighted_subprompts:
|
||||
log_tokenization(subprompt, model, log_tokens)
|
||||
log_tokenization(subprompt, model, log_tokens, weight)
|
||||
c = torch.add(
|
||||
c,
|
||||
model.get_learned_conditioning([subprompt]),
|
||||
alpha=weight,
|
||||
)
|
||||
else: # just standard 1 prompt
|
||||
log_tokenization(prompt, model, log_tokens)
|
||||
log_tokenization(prompt, model, log_tokens, 1)
|
||||
c = model.get_learned_conditioning([prompt])
|
||||
uc = model.get_learned_conditioning([unconditioned_words])
|
||||
return (uc, c)
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
@ -72,7 +86,7 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, log=False):
|
||||
def log_tokenization(text, model, log=False, weight=1):
|
||||
if not log:
|
||||
return
|
||||
tokens = model.cond_stage_model.tokenizer._tokenize(text)
|
||||
@ -89,8 +103,8 @@ def log_tokenization(text, model, log=False):
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
||||
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
||||
|
@ -22,7 +22,8 @@ class Completer:
|
||||
def complete(self, text, state):
|
||||
buffer = readline.get_line_buffer()
|
||||
|
||||
if text.startswith(('-I', '--init_img','-M','--init_mask')):
|
||||
if text.startswith(('-I', '--init_img','-M','--init_mask',
|
||||
'--init_color')):
|
||||
return self._path_completions(text, state, ('.png','.jpg','.jpeg'))
|
||||
|
||||
if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
|
||||
@ -57,6 +58,8 @@ class Completer:
|
||||
path = text.replace('--init_mask=', '', 1).lstrip()
|
||||
elif text.startswith('-M'):
|
||||
path = text.replace('-M', '', 1).lstrip()
|
||||
elif text.startswith('--init_color='):
|
||||
path = text.replace('--init_color=', '', 1).lstrip()
|
||||
else:
|
||||
path = text
|
||||
|
||||
@ -100,6 +103,7 @@ if readline_available:
|
||||
'--individual','-i',
|
||||
'--init_img','-I',
|
||||
'--init_mask','-M',
|
||||
'--init_color',
|
||||
'--strength','-f',
|
||||
'--variants','-v',
|
||||
'--outdir','-o',
|
||||
|
@ -4,7 +4,7 @@ import copy
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from ldm.dream.args import Args, format_metadata
|
||||
from ldm.dream.args import Args, metadata_dumps
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
from threading import Event
|
||||
@ -76,7 +76,7 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
with open("./static/dream_web/index.html", "rb") as content:
|
||||
with open("./static/legacy_web/index.html", "rb") as content:
|
||||
self.wfile.write(content.read())
|
||||
elif self.path == "/config.js":
|
||||
# unfortunately this import can't be at the top level, since that would cause a circular import
|
||||
@ -94,7 +94,7 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
output = []
|
||||
|
||||
log_file = os.path.join(self.outdir, "dream_web_log.txt")
|
||||
log_file = os.path.join(self.outdir, "legacy_web_log.txt")
|
||||
if os.path.exists(log_file):
|
||||
with open(log_file, "r") as log:
|
||||
for line in log:
|
||||
@ -114,7 +114,7 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
else:
|
||||
path_dir = os.path.dirname(self.path)
|
||||
out_dir = os.path.realpath(self.outdir.rstrip('/'))
|
||||
if self.path.startswith('/static/dream_web/'):
|
||||
if self.path.startswith('/static/legacy_web/'):
|
||||
path = '.' + self.path
|
||||
elif out_dir.replace('\\', '/').endswith(path_dir):
|
||||
file = os.path.basename(self.path)
|
||||
@ -145,7 +145,6 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
|
||||
|
||||
self.canceled.clear()
|
||||
print(f">> Request to generate with prompt: {opt.prompt}")
|
||||
# In order to handle upscaled images, the PngWriter needs to maintain state
|
||||
# across images generated by each call to prompt2img(), so we define it in
|
||||
# the outer scope of image_done()
|
||||
@ -176,10 +175,9 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
path = pngwriter.save_image_and_prompt_to_png(
|
||||
image,
|
||||
dream_prompt = formatted_prompt,
|
||||
metadata = format_metadata(iter_opt,
|
||||
seeds = [seed],
|
||||
weights = self.model.weights,
|
||||
model_hash = self.model.model_hash
|
||||
metadata = metadata_dumps(iter_opt,
|
||||
seeds = [seed],
|
||||
model_hash = self.model.model_hash
|
||||
),
|
||||
name = name,
|
||||
)
|
||||
@ -188,7 +186,7 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
config['seed'] = seed
|
||||
# Append post_data to log, but only once!
|
||||
if not upscaled:
|
||||
with open(os.path.join(self.outdir, "dream_web_log.txt"), "a") as log:
|
||||
with open(os.path.join(self.outdir, "legacy_web_log.txt"), "a") as log:
|
||||
log.write(f"{path}: {json.dumps(config)}\n")
|
||||
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
@ -228,7 +226,8 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
nonlocal step_index
|
||||
if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
|
||||
image = self.model.sample_to_image(sample)
|
||||
name = f'{prefix}.{opt.seed}.{step_index}.png'
|
||||
step_index_padded = str(step_index).rjust(len(str(opt.steps)), '0')
|
||||
name = f'{prefix}.{opt.seed}.{step_index_padded}.png'
|
||||
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
|
||||
path = step_writer.save_image_and_prompt_to_png(image, dream_prompt=metadata, name=name)
|
||||
step_index += 1
|
||||
|
@ -15,6 +15,8 @@ import traceback
|
||||
import transformers
|
||||
import io
|
||||
import hashlib
|
||||
import cv2
|
||||
import skimage
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image, ImageOps
|
||||
@ -220,11 +222,14 @@ class Generate:
|
||||
init_mask = None,
|
||||
fit = False,
|
||||
strength = None,
|
||||
init_color = None,
|
||||
# these are specific to embiggen (which also relies on img2img args)
|
||||
embiggen = None,
|
||||
embiggen_tiles = None,
|
||||
# these are specific to GFPGAN/ESRGAN
|
||||
facetool = None,
|
||||
gfpgan_strength = 0,
|
||||
codeformer_fidelity = None,
|
||||
save_original = False,
|
||||
upscale = None,
|
||||
# Set this True to handle KeyboardInterrupt internally
|
||||
@ -362,10 +367,17 @@ class Generate:
|
||||
embiggen_tiles = embiggen_tiles,
|
||||
)
|
||||
|
||||
if init_color:
|
||||
self.correct_colors(image_list = results,
|
||||
reference_image_path = init_color,
|
||||
image_callback = image_callback)
|
||||
|
||||
if upscale is not None or gfpgan_strength > 0:
|
||||
self.upscale_and_reconstruct(results,
|
||||
upscale = upscale,
|
||||
facetool = facetool,
|
||||
strength = gfpgan_strength,
|
||||
codeformer_fidelity = codeformer_fidelity,
|
||||
save_original = save_original,
|
||||
image_callback = image_callback)
|
||||
|
||||
@ -475,17 +487,44 @@ class Generate:
|
||||
|
||||
return self.model
|
||||
|
||||
def correct_colors(self,
|
||||
image_list,
|
||||
reference_image_path,
|
||||
image_callback = None):
|
||||
reference_image = Image.open(reference_image_path)
|
||||
correction_target = cv2.cvtColor(np.asarray(reference_image),
|
||||
cv2.COLOR_RGB2LAB)
|
||||
for r in image_list:
|
||||
image, seed = r
|
||||
image = cv2.cvtColor(np.asarray(image),
|
||||
cv2.COLOR_RGB2LAB)
|
||||
image = skimage.exposure.match_histograms(image,
|
||||
correction_target,
|
||||
channel_axis=2)
|
||||
image = Image.fromarray(
|
||||
cv2.cvtColor(image, cv2.COLOR_LAB2RGB).astype("uint8")
|
||||
)
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed)
|
||||
else:
|
||||
r[0] = image
|
||||
|
||||
def upscale_and_reconstruct(self,
|
||||
image_list,
|
||||
facetool = 'gfpgan',
|
||||
upscale = None,
|
||||
strength = 0.0,
|
||||
codeformer_fidelity = 0.75,
|
||||
save_original = False,
|
||||
image_callback = None):
|
||||
try:
|
||||
if upscale is not None:
|
||||
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
|
||||
if strength > 0:
|
||||
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
||||
if facetool == 'codeformer':
|
||||
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
|
||||
else:
|
||||
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||
@ -504,9 +543,12 @@ class Generate:
|
||||
seed,
|
||||
)
|
||||
if strength > 0:
|
||||
image = run_gfpgan(
|
||||
image, strength, seed, 1
|
||||
)
|
||||
if facetool == 'codeformer':
|
||||
image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
|
||||
else:
|
||||
image = run_gfpgan(
|
||||
image, strength, seed, 1
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
|
||||
|
@ -90,7 +90,7 @@ class LinearAttention(nn.Module):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
@ -167,101 +167,85 @@ class CrossAttention(nn.Module):
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.einsum_op = self.einsum_op_cuda
|
||||
else:
|
||||
self.mem_total = psutil.virtual_memory().total / (1024**3)
|
||||
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
|
||||
|
||||
def einsum_op_compvis(self, q, k, v, r1):
|
||||
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
r1 = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v, r1):
|
||||
def einsum_op_compvis(self, q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
|
||||
def einsum_op_slice_0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
return r
|
||||
|
||||
def einsum_op_slice_1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
return self.einsum_op_slice_1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v, r1):
|
||||
if self.mem_total >= 8 and q.shape[1] <= 4096:
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
slice_size = 1
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = min(q.shape[0], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def einsum_op_cuda(self, q, k, v, r1):
|
||||
return self.einsum_op_slice_0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
def einsum_op(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
if q.device.type == 'mps':
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = min(q.shape[1], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
device_type = 'mps' if x.device.type == 'mps' else 'cuda'
|
||||
k = self.to_k(context) * self.scale
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
r1 = self.einsum_op(q, k, v, r1)
|
||||
del q, k, v
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = self.einsum_op(q, k, v)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
|
@ -3,6 +3,7 @@ import gc
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.functional import silu
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
@ -122,14 +118,14 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = self.norm1(x)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
h = h + self.temb_proj(silu(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
@ -368,7 +364,7 @@ class Model(nn.Module):
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = silu(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
@ -402,7 +398,7 @@ class Model(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -499,7 +495,7 @@ class Encoder(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -611,7 +607,7 @@ class Decoder(nn.Module):
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
@ -649,7 +645,7 @@ class SimpleDecoder(nn.Module):
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
@ -697,7 +693,7 @@ class UpsampleDecoder(nn.Module):
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -873,7 +869,7 @@ class FirstStagePostProcessor(nn.Module):
|
||||
z_fs = self.encode_with_pretrained(x)
|
||||
z = self.proj_norm(z_fs)
|
||||
z = self.proj(z)
|
||||
z = nonlinearity(z)
|
||||
z = silu(z)
|
||||
|
||||
for submodel, downmodel in zip(self.model,self.downsampler):
|
||||
z = submodel(z,temb=None)
|
||||
|
@ -252,12 +252,6 @@ def normalization(channels):
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
@ -82,7 +82,9 @@ class EmbeddingManager(nn.Module):
|
||||
get_embedding_for_clip_token,
|
||||
embedder.transformer.text_model.embeddings,
|
||||
)
|
||||
token_dim = 1280
|
||||
# per bug report #572
|
||||
#token_dim = 1280
|
||||
token_dim = 768
|
||||
else: # using LDM's BERT encoder
|
||||
self.is_clip = False
|
||||
get_token_for_string = partial(
|
||||
|