Merge branch 'development' of github.com:psychedelicious/stable-diffusion into psychedelicious-development

This commit is contained in:
Lincoln Stein 2022-09-19 12:27:42 -04:00
commit 58c63fe339
114 changed files with 6192 additions and 3831 deletions

View File

@ -1,64 +0,0 @@
name: Cache Model
on:
workflow_dispatch
jobs:
build:
strategy:
matrix:
os: [ macos-12 ]
name: Create Caches using ${{ matrix.os }}
runs-on: ${{ matrix.os }}
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Cache model
id: cache-sd-v1-4
uses: actions/cache@v3
env:
cache-name: cache-sd-v1-4
with:
path: models/ldm/stable-diffusion-v1/model.ckpt
key: ${{ env.cache-name }}
restore-keys: |
${{ env.cache-name }}
- name: Download Stable Diffusion v1.4 model
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
continue-on-error: true
run: |
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
mkdir -p models/ldm/stable-diffusion-v1
fi
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
fi
# Uncomment this when we no longer make changes to environment-mac.yaml
# - name: Cache environment
# id: cache-conda-env-ldm
# uses: actions/cache@v3
# env:
# cache-name: cache-conda-env-ldm
# with:
# path: ~/.conda/envs/ldm
# key: ${{ env.cache-name }}
# restore-keys: |
# ${{ env.cache-name }}
- name: Install dependencies
# if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
run: |
conda env create -f environment-mac.yaml
- name: Cache hugginface and torch models
id: cache-hugginface-torch
uses: actions/cache@v3
env:
cache-name: cache-hugginface-torch
with:
path: ~/.cache
key: ${{ env.cache-name }}
restore-keys: |
${{ env.cache-name }}
- name: Download Huggingface and Torch models
if: ${{ steps.cache-hugginface-torch.outputs.cache-hit != 'true' }}
continue-on-error: true
run: |
export PYTHON_BIN=/usr/local/miniconda/envs/ldm/bin/python
$PYTHON_BIN scripts/preload_models.py

70
.github/workflows/create-caches.yml vendored Normal file
View File

@ -0,0 +1,70 @@
name: Create Caches
on:
workflow_dispatch
jobs:
build:
strategy:
matrix:
os: [ ubuntu-latest, macos-12 ]
name: Create Caches on ${{ matrix.os }} conda
runs-on: ${{ matrix.os }}
steps:
- name: Set platform variables
id: vars
run: |
if [ "$RUNNER_OS" = "macOS" ]; then
echo "::set-output name=ENV_FILE::environment-mac.yaml"
echo "::set-output name=PYTHON_BIN::/usr/local/miniconda/envs/ldm/bin/python"
elif [ "$RUNNER_OS" = "Linux" ]; then
echo "::set-output name=ENV_FILE::environment.yaml"
echo "::set-output name=PYTHON_BIN::/usr/share/miniconda/envs/ldm/bin/python"
fi
- name: Checkout sources
uses: actions/checkout@v3
- name: Use Cached Stable Diffusion v1.4 Model
id: cache-sd-v1-4
uses: actions/cache@v3
env:
cache-name: cache-sd-v1-4
with:
path: models/ldm/stable-diffusion-v1/model.ckpt
key: ${{ env.cache-name }}
restore-keys: |
${{ env.cache-name }}
- name: Download Stable Diffusion v1.4 Model
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
run: |
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
mkdir -p models/ldm/stable-diffusion-v1
fi
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
fi
- name: Use Cached Dependencies
id: cache-conda-env-ldm
uses: actions/cache@v3
env:
cache-name: cache-conda-env-ldm
with:
path: ~/.conda/envs/ldm
key: ${{ env.cache-name }}
restore-keys: |
${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(steps.vars.outputs.ENV_FILE) }}
- name: Install Dependencies
if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
run: |
conda env create -f ${{ steps.vars.outputs.ENV_FILE }}
- name: Use Cached Huggingface and Torch models
id: cache-huggingface-torch
uses: actions/cache@v3
env:
cache-name: cache-huggingface-torch
with:
path: ~/.cache
key: ${{ env.cache-name }}
restore-keys: |
${{ env.cache-name }}-${{ hashFiles('scripts/preload_models.py') }}
- name: Download Huggingface and Torch models
if: ${{ steps.cache-huggingface-torch.outputs.cache-hit != 'true' }}
run: |
${{ steps.vars.outputs.PYTHON_BIN }} scripts/preload_models.py

View File

@ -1,80 +0,0 @@
name: Build
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
strategy:
matrix:
os: [ macos-12 ]
name: Build on ${{ matrix.os }} miniconda
runs-on: ${{ matrix.os }}
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Cache model
id: cache-sd-v1-4
uses: actions/cache@v3
env:
cache-name: cache-sd-v1-4
with:
path: models/ldm/stable-diffusion-v1/model.ckpt
key: ${{ env.cache-name }}
restore-keys: |
${{ env.cache-name }}
- name: Download Stable Diffusion v1.4 model
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
continue-on-error: true
run: |
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
mkdir -p models/ldm/stable-diffusion-v1
fi
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
fi
# Uncomment this when we no longer make changes to environment-mac.yaml
# - name: Cache environment
# id: cache-conda-env-ldm
# uses: actions/cache@v3
# env:
# cache-name: cache-conda-env-ldm
# with:
# path: ~/.conda/envs/ldm
# key: ${{ env.cache-name }}
# restore-keys: |
# ${{ env.cache-name }}
- name: Install dependencies
# if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
run: |
conda env create -f environment-mac.yaml
- name: Cache hugginface and torch models
id: cache-hugginface-torch
uses: actions/cache@v3
env:
cache-name: cache-hugginface-torch
with:
path: ~/.cache
key: ${{ env.cache-name }}
restore-keys: |
${{ env.cache-name }}
- name: Download Huggingface and Torch models
if: ${{ steps.cache-hugginface-torch.outputs.cache-hit != 'true' }}
continue-on-error: true
run: |
export PYTHON_BIN=/usr/local/miniconda/envs/ldm/bin/python
$PYTHON_BIN scripts/preload_models.py
- name: Run the tests
run: |
# Note, can't "activate" via automation, and activation is just env vars and path
export PYTHON_BIN=/usr/local/miniconda/envs/ldm/bin/python
export PYTORCH_ENABLE_MPS_FALLBACK=1
$PYTHON_BIN scripts/preload_models.py
mkdir -p outputs/img-samples
time $PYTHON_BIN scripts/dream.py --from_file tests/prompts.txt </dev/null 2> outputs/img-samples/err.log > outputs/img-samples/out.log
- name: Archive results
uses: actions/upload-artifact@v3
with:
name: results
path: outputs/img-samples

97
.github/workflows/test-dream-conda.yml vendored Normal file
View File

@ -0,0 +1,97 @@
name: Test Dream with Conda
on:
push:
branches:
- 'main'
- 'development'
jobs:
os_matrix:
strategy:
matrix:
os: [ ubuntu-latest, macos-12 ]
name: Test dream.py on ${{ matrix.os }} with conda
runs-on: ${{ matrix.os }}
steps:
- run: |
echo The PR was merged
- name: Set platform variables
id: vars
run: |
# Note, can't "activate" via github action; specifying the env's python has the same effect
if [ "$RUNNER_OS" = "macOS" ]; then
echo "::set-output name=ENV_FILE::environment-mac.yaml"
echo "::set-output name=PYTHON_BIN::/usr/local/miniconda/envs/ldm/bin/python"
elif [ "$RUNNER_OS" = "Linux" ]; then
echo "::set-output name=ENV_FILE::environment.yaml"
echo "::set-output name=PYTHON_BIN::/usr/share/miniconda/envs/ldm/bin/python"
fi
- name: Checkout sources
uses: actions/checkout@v3
- name: Use Cached Stable Diffusion v1.4 Model
id: cache-sd-v1-4
uses: actions/cache@v3
env:
cache-name: cache-sd-v1-4
with:
path: models/ldm/stable-diffusion-v1/model.ckpt
key: ${{ env.cache-name }}
restore-keys: |
${{ env.cache-name }}
- name: Download Stable Diffusion v1.4 Model
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
run: |
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
mkdir -p models/ldm/stable-diffusion-v1
fi
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
fi
- name: Use Cached Dependencies
id: cache-conda-env-ldm
uses: actions/cache@v3
env:
cache-name: cache-conda-env-ldm
with:
path: ~/.conda/envs/ldm
key: ${{ env.cache-name }}
restore-keys: |
${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(steps.vars.outputs.ENV_FILE) }}
- name: Install Dependencies
if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
run: |
conda env create -f ${{ steps.vars.outputs.ENV_FILE }}
- name: Use Cached Huggingface and Torch models
  # The step id and cache name are spelled "huggingface" (not "hugginface")
  # so that this step looks up the same cache key that create-caches.yml
  # writes (cache-huggingface-torch). With the misspelled name the two
  # workflows used different keys and never shared this cache.
  id: cache-huggingface-torch
  uses: actions/cache@v3
  env:
    cache-name: cache-huggingface-torch
  with:
    path: ~/.cache
    key: ${{ env.cache-name }}
    # NOTE(review): restore-keys are matched as key prefixes, and this entry
    # is longer than `key`, so it presumably never matches a cache saved
    # under `key` -- confirm whether the hash was meant to be part of `key`.
    restore-keys: |
      ${{ env.cache-name }}-${{ hashFiles('scripts/preload_models.py') }}
- name: Download Huggingface and Torch models
  # Runs only when the cache step above did not report an exact key hit.
  if: ${{ steps.cache-huggingface-torch.outputs.cache-hit != 'true' }}
  run: |
    ${{ steps.vars.outputs.PYTHON_BIN }} scripts/preload_models.py
# - name: Run tmate
# uses: mxschmitt/action-tmate@v3
# timeout-minutes: 30
- name: Run the tests
  run: |
    # Note, can't "activate" via github action; specifying the env's python has the same effect
    if [ "$(uname)" = "Darwin" ]; then
      export PYTORCH_ENABLE_MPS_FALLBACK=1
    fi
    # Create the results directory before running anything, so it exists for
    # the "Archive results" step even when a test command fails.
    mkdir -p outputs/img-samples
    # Choose the prompt file from the branch that was pushed. This workflow
    # triggers on 'main' and 'development'; the previous comparison against
    # 'refs/heads/master' could never match, so pushes to main ran no tests.
    if [[ "${{ github.ref }}" == 'refs/heads/main' ]]; then
      time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/preflight_prompts.txt --full_precision
    elif [[ "${{ github.ref }}" == 'refs/heads/development' ]]; then
      time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/dev_prompts.txt --full_precision
    fi
- name: Archive results
uses: actions/upload-artifact@v3
with:
name: results
path: outputs/img-samples

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
# ignore default image save location and model symbolic link
outputs/
models/ldm/stable-diffusion-v1/model.ckpt
ldm/restoration/codeformer/weights
# ignore a directory which serves as a place for initial images
inputs/

131
README.md
View File

@ -1,16 +1,41 @@
<h1 align='center'><b>Stable Diffusion Dream Script</b></h1>
<div align="center">
# Stable Diffusion Dream Script
![project logo](docs/assets/logo.png)
<p align='center'>
<img src="docs/assets/logo.png"/>
<a href="https://discord.gg/ZmtBAhwWhy"><img src="docs/assets/join-us-on-discord-image.png"/></a>
</p>
<p align="center">
<img src="https://img.shields.io/github/last-commit/lstein/stable-diffusion?logo=Python&logoColor=green&style=for-the-badge" alt="last-commit"/>
<img src="https://img.shields.io/github/stars/lstein/stable-diffusion?logo=GitHub&style=for-the-badge" alt="stars"/>
<br>
<img src="https://img.shields.io/github/issues/lstein/stable-diffusion?logo=GitHub&style=for-the-badge" alt="issues"/>
<img src="https://img.shields.io/github/issues-pr/lstein/stable-diffusion?logo=GitHub&style=for-the-badge" alt="pull-requests"/>
</p>
# **Stable Diffusion Dream Script**
[![discord badge]][discord link]
[![latest release badge]][latest release link] [![github stars badge]][github stars link] [![github forks badge]][github forks link]
[![CI checks on main badge]][CI checks on main link] [![CI checks on dev badge]][CI checks on dev link] [![latest commit to dev badge]][latest commit to dev link]
[![github open issues badge]][github open issues link] [![github open prs badge]][github open prs link]
[CI checks on dev badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/development?label=CI%20status%20on%20dev&cache=900&icon=github
[CI checks on dev link]: https://github.com/lstein/stable-diffusion/actions?query=branch%3Adevelopment
[CI checks on main badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/main?label=CI%20status%20on%20main&cache=900&icon=github
[CI checks on main link]: https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml
[discord badge]: https://flat.badgen.net/discord/members/htRgbc7e?icon=discord
[discord link]: https://discord.com/invite/htRgbc7e
[github forks badge]: https://flat.badgen.net/github/forks/lstein/stable-diffusion?icon=github
[github forks link]: https://useful-forks.github.io/?repo=lstein%2Fstable-diffusion
[github open issues badge]: https://flat.badgen.net/github/open-issues/lstein/stable-diffusion?icon=github
[github open issues link]: https://github.com/lstein/stable-diffusion/issues?q=is%3Aissue+is%3Aopen
[github open prs badge]: https://flat.badgen.net/github/open-prs/lstein/stable-diffusion?icon=github
[github open prs link]: https://github.com/lstein/stable-diffusion/pulls?q=is%3Apr+is%3Aopen
[github stars badge]: https://flat.badgen.net/github/stars/lstein/stable-diffusion?icon=github
[github stars link]: https://github.com/lstein/stable-diffusion/stargazers
[latest commit to dev badge]: https://flat.badgen.net/github/last-commit/lstein/stable-diffusion/development?icon=github&color=yellow&label=last%20dev%20commit&cache=900
[latest commit to dev link]: https://github.com/lstein/stable-diffusion/commits/development
[latest release badge]: https://flat.badgen.net/github/release/lstein/stable-diffusion/development?icon=github
[latest release link]: https://github.com/lstein/stable-diffusion/releases
</div>
This is a fork of [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion), the open
source text-to-image generator. It provides a streamlined process with various new features and
@ -21,7 +46,7 @@ _Note: This fork is rapidly evolving. Please use the
[Issues](https://github.com/lstein/stable-diffusion/issues) tab to report bugs and make feature
requests. Be sure to use the provided templates. They will help diagnose issues faster._
**Table of Contents**
## Table of Contents
1. [Installation](#installation)
2. [Hardware Requirements](#hardware-requirements)
@ -33,38 +58,38 @@ requests. Be sure to use the provided templates. They will help aid diagnose iss
8. [Support](#support)
9. [Further Reading](#further-reading)
## Installation
### Installation
This fork is supported across multiple platforms. You can find individual installation instructions
below.
- ### [Linux](docs/installation/INSTALL_LINUX.md)
- #### [Linux](docs/installation/INSTALL_LINUX.md)
- ### [Windows](docs/installation/INSTALL_WINDOWS.md)
- #### [Windows](docs/installation/INSTALL_WINDOWS.md)
- ### [Macintosh](docs/installation/INSTALL_MAC.md)
- #### [Macintosh](docs/installation/INSTALL_MAC.md)
## Hardware Requirements
### Hardware Requirements
**System**
#### System
You will need one of the following:
- An NVIDIA-based graphics card with 4 GB or more VRAM memory.
- An Apple computer with an M1 chip.
**Memory**
#### Memory
- At least 12 GB Main Memory RAM.
**Disk**
#### Disk
- At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies.
**Note**
If you are have a Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
full-precision mode as shown below.
> Note
>
> If you have an Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
> full-precision mode as shown below.
Similarly, specify full-precision mode on Apple M1 hardware.
@ -74,43 +99,31 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
(ldm) ~/stable-diffusion$ python scripts/dream.py --full_precision
```
## Features
### Features
### Major Features
#### Major Features
- #### [Interactive Command Line Interface](docs/features/CLI.md)
- [Interactive Command Line Interface](docs/features/CLI.md)
- [Image To Image](docs/features/IMG2IMG.md)
- [Inpainting Support](docs/features/INPAINTING.md)
- [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
- [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
- [Google Colab](docs/features/OTHER.md#google-colab)
- [Web Server](docs/features/WEB.md)
- [Reading Prompts From File](docs/features/PROMPTS.md#reading-prompts-from-a-file)
- [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
- [Weighted Prompts](docs/features/PROMPTS.md#weighted-prompts)
- [Negative/Unconditioned Prompts](docs/features/PROMPTS.md#negative-and-unconditioned-prompts)
- [Variations](docs/features/VARIATIONS.md)
- [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
- [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)
- #### [Image To Image](docs/features/IMG2IMG.md)
#### Other Features
- #### [Inpainting Support](docs/features/INPAINTING.md)
- [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
- [Preload Models](docs/features/OTHER.md#preload-models)
- #### [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
- #### [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
- #### [Google Colab](docs/features/OTHER.md#google-colab)
- #### [Web Server](docs/features/WEB.md)
- #### [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file)
- #### [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
- #### [Weighted Prompts](docs/features/OTHER.md#weighted-prompts)
- #### [Variations](docs/features/VARIATIONS.md)
- #### [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
- #### [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)
### Other Features
- #### [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
- #### [Preload Models](docs/features/OTHER.md#preload-models)
## Latest Changes
### Latest Changes
- v1.14 (11 September 2022)
@ -142,12 +155,12 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
For older changelogs, please visit the **[CHANGELOG](docs/features/CHANGELOG.md)**.
## Troubleshooting
### Troubleshooting
Please check out our **[Q&A](docs/help/TROUBLESHOOT.md)** to get solutions for common installation
problems and other issues.
## Contributing
### Contributing
Anyone who wishes to contribute to this project, whether documentation, features, bug fixes, code
cleanup, testing, or code reviews, is very much encouraged to do so. If you are unfamiliar with how
@ -159,13 +172,13 @@ important thing is to **make your pull request against the "development" branch*
"main". This will help keep public breakage to a minimum and will allow you to propose more radical
changes.
## Contributors
### Contributors
This fork is a combined effort of various people from across the world.
[Check out the list of all these amazing people](docs/other/CONTRIBUTORS.md). We thank them for
their time, hard work and effort.
## Support
### Support
For support, please use this repository's GitHub Issues tracking service. Feel free to send me an
email if you use and like the script.
@ -173,7 +186,7 @@ email if you use and like the script.
Original portions of the software are Copyright (c) 2020
[Lincoln D. Stein](https://github.com/lstein)
## Further Reading
### Further Reading
Please see the original README for more information on this software and underlying algorithm,
located in the file [README-CompViz.md](docs/other/README-CompViz.md).

View File

@ -40,6 +40,8 @@ def parameters_to_command(params):
switches.append(f'-I {params["init_img"]}')
if 'init_mask' in params and len(params['init_mask']) > 0:
switches.append(f'-M {params["init_mask"]}')
if 'init_color' in params and len(params['init_color']) > 0:
switches.append(f'--init_color {params["init_color"]}')
if 'strength' in params and 'init_img' in params:
switches.append(f'-f {params["strength"]}')
if 'fit' in params and params["fit"] == True:
@ -129,6 +131,11 @@ def create_cmd_parser():
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
parser.add_argument(
'--init_color',
type=str,
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
)
parser.add_argument(
'-T',
'-fit',

View File

@ -7,6 +7,8 @@ import eventlet
import glob
import shlex
import argparse
import math
import shutil
from flask_socketio import SocketIO
from flask import Flask, send_from_directory, url_for, jsonify
@ -15,6 +17,7 @@ from PIL import Image
from pytorch_lightning import logging
from threading import Event
from uuid import uuid4
from send2trash import send2trash
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
from ldm.gfpgan.gfpgan_tools import run_gfpgan
@ -118,15 +121,15 @@ result_path = os.path.join(output_dir, 'img-samples/')
intermediate_path = os.path.join(result_path, 'intermediates/')
# path for user-uploaded init images and masks
init_path = os.path.join(result_path, 'init-images/')
mask_path = os.path.join(result_path, 'mask-images/')
init_image_path = os.path.join(result_path, 'init-images/')
mask_image_path = os.path.join(result_path, 'mask-images/')
# txt log
log_path = os.path.join(result_path, 'dream_log.txt')
# make all output paths
[os.makedirs(path, exist_ok=True)
for path in [result_path, intermediate_path, init_path, mask_path]]
for path in [result_path, intermediate_path, init_image_path, mask_image_path]]
"""
@ -154,7 +157,8 @@ def handle_request_all_images():
else:
metadata = all_metadata['sd-metadata']
image_array.append({'path': path, 'metadata': metadata})
return make_response("OK", data=image_array)
socketio.emit('galleryImages', {'images': image_array})
eventlet.sleep(0)
@socketio.on('generateImage')
@ -165,16 +169,32 @@ def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan
esrgan_parameters,
gfpgan_parameters
)
return make_response("OK")
@socketio.on('runESRGAN')
def handle_run_esrgan_event(original_image, esrgan_parameters):
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}')
progress = {
'currentStep': 1,
'totalSteps': 1,
'currentIteration': 1,
'totalIterations': 1,
'currentStatus': 'Preparing',
'isProcessing': True,
'currentStatusHasSteps': False
}
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
progress['currentStatus'] = 'Upscaling'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = real_esrgan_upscale(
image=image,
upsampler_scale=esrgan_parameters['upscale'][0],
@ -182,24 +202,54 @@ def handle_run_esrgan_event(original_image, esrgan_parameters):
seed=seed
)
progress['currentStatus'] = 'Saving image'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
esrgan_parameters['seed'] = seed
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan')
command = parameters_to_command(esrgan_parameters)
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
progress['currentStatus'] = 'Finished'
progress['currentStep'] = 0
progress['totalSteps'] = 0
progress['currentIteration'] = 0
progress['totalIterations'] = 0
progress['isProcessing'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'esrgan', 'uuid': original_image['uuid'],'metadata': esrgan_parameters})
'esrganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': esrgan_parameters})
@socketio.on('runGFPGAN')
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}')
progress = {
'currentStep': 1,
'totalSteps': 1,
'currentIteration': 1,
'totalIterations': 1,
'currentStatus': 'Preparing',
'isProcessing': True,
'currentStatusHasSteps': False
}
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
progress['currentStatus'] = 'Fixing faces'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = run_gfpgan(
image=image,
strength=gfpgan_parameters['gfpgan_strength'],
@ -207,29 +257,42 @@ def handle_run_gfpgan_event(original_image, gfpgan_parameters):
upsampler_scale=1
)
progress['currentStatus'] = 'Saving image'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
gfpgan_parameters['seed'] = seed
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan')
command = parameters_to_command(gfpgan_parameters)
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
progress['currentStatus'] = 'Finished'
progress['currentStep'] = 0
progress['totalSteps'] = 0
progress['currentIteration'] = 0
progress['totalIterations'] = 0
progress['isProcessing'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'gfpgan', 'uuid': original_image['uuid'],'metadata': gfpgan_parameters})
'gfpganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': gfpgan_parameters})
@socketio.on('cancel')
def handle_cancel():
print(f'>> Cancel processing requested')
canceled.set()
return make_response("OK")
socketio.emit('processingCanceled')
# TODO: I think this needs a safety mechanism.
@socketio.on('deleteImage')
def handle_delete_image(path):
def handle_delete_image(path, uuid):
print(f'>> Delete requested "{path}"')
Path(path).unlink()
return make_response("OK")
send2trash(path)
socketio.emit('imageDeleted', {'url': path, 'uuid': uuid})
# TODO: I think this needs a safety mechanism.
@ -239,11 +302,11 @@ def handle_upload_initial_image(bytes, name):
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(init_path, name)
file_path = os.path.join(init_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
# Write the uploaded bytes through a context manager so the handle is
# flushed and closed before the client is told where the file lives.
with open(file_path, "wb") as newFile:
    newFile.write(bytes)
return make_response("OK", data=file_path)
socketio.emit('initialImageUploaded', {'url': file_path, 'uuid': ''})
# TODO: I think this needs a safety mechanism.
@ -253,11 +316,11 @@ def handle_upload_mask_image(bytes, name):
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(mask_path, name)
file_path = os.path.join(mask_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
# Write the uploaded bytes through a context manager so the handle is
# flushed and closed before the client is told where the file lives.
with open(file_path, "wb") as newFile:
    newFile.write(bytes)
return make_response("OK", data=file_path)
socketio.emit('maskImageUploaded', {'url': file_path, 'uuid': ''})
@ -272,6 +335,13 @@ ADDITIONAL FUNCTIONS
"""
def make_unique_init_image_filename(name):
    """Return `name` with a random hex UUID inserted before its extension.

    The result has the form '<stem>.<uuid4 hex><extension>', where stem and
    extension come from os.path.splitext(name).
    """
    stem, extension = os.path.splitext(name)
    return f'{stem}.{uuid4().hex}{extension}'
def write_log_message(message, log_path=log_path):
"""Logs the filename and parameters used to generate or process that image to log file"""
message = f'{message}\n'
@ -279,15 +349,6 @@ def write_log_message(message, log_path=log_path):
file.writelines(message)
def make_response(status, message=None, data=None):
response = {'status': status}
if message is not None:
response['message'] = message
if data is not None:
response['data'] = data
return response
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False):
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed'
@ -309,16 +370,69 @@ def save_image(image, parameters, output_dir, step_index=None, postprocessing=Fa
return path
def calculate_real_steps(steps, strength, has_init_image):
    """Return the step count used for progress reporting.

    Without an init image the requested `steps` is returned unchanged.
    With an init image the result is floor(strength * steps).
    """
    if not has_init_image:
        return steps
    return math.floor(strength * steps)
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
canceled.clear()
step_index = 1
"""
If a result image is used as an init image, and then deleted, we will want to be
able to use it as an init image in the future. Need to copy it.
If the init/mask image doesn't exist in the init_image_path/mask_image_path,
make a unique filename for it and copy it there.
"""
if ('init_img' in generation_parameters):
filename = os.path.basename(generation_parameters['init_img'])
if not os.path.exists(os.path.join(init_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters['init_img'], new_path)
generation_parameters['init_img'] = new_path
# Same handling for the mask image: if no file with this name exists in
# mask_image_path, copy the mask there under a unique filename and point
# the parameters at the copy. The source must be the mask itself and the
# destination the mask directory; copying 'init_img' into init_image_path
# here replaced the mask with the init image, and raised KeyError when a
# mask was supplied without an init image.
if ('init_mask' in generation_parameters):
    filename = os.path.basename(generation_parameters['init_mask'])
    if not os.path.exists(os.path.join(mask_image_path, filename)):
        unique_filename = make_unique_init_image_filename(filename)
        new_path = os.path.join(mask_image_path, unique_filename)
        shutil.copy(generation_parameters['init_mask'], new_path)
        generation_parameters['init_mask'] = new_path
totalSteps = calculate_real_steps(
steps=generation_parameters['steps'],
strength=generation_parameters['strength'] if 'strength' in generation_parameters else None,
has_init_image='init_img' in generation_parameters
)
progress = {
'currentStep': 1,
'totalSteps': totalSteps,
'currentIteration': 1,
'totalIterations': generation_parameters['iterations'],
'currentStatus': 'Preparing',
'isProcessing': True,
'currentStatusHasSteps': False
}
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
def image_progress(sample, step):
if canceled.is_set():
raise CanceledException
nonlocal step_index
nonlocal generation_parameters
nonlocal progress
progress['currentStep'] = step + 1
progress['currentStatus'] = 'Generating'
progress['currentStatusHasSteps'] = True
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1:
image = model.sample_to_image(sample)
path = save_image(image, generation_parameters, intermediate_path, step_index)
@ -326,18 +440,30 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
step_index += 1
socketio.emit('intermediateResult', {
'url': os.path.relpath(path), 'metadata': generation_parameters})
socketio.emit('progress', {'step': step + 1})
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
def image_done(image, seed):
nonlocal generation_parameters
nonlocal esrgan_parameters
nonlocal gfpgan_parameters
nonlocal progress
step_index = 1
progress['currentStatus'] = 'Generation complete'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
all_parameters = generation_parameters
postprocessing = False
if esrgan_parameters:
progress['currentStatus'] = 'Upscaling'
progress['currentStatusHasSteps'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = real_esrgan_upscale(
image=image,
strength=esrgan_parameters['strength'],
@ -348,6 +474,11 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']]
if gfpgan_parameters:
progress['currentStatus'] = 'Fixing faces'
progress['currentStatusHasSteps'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = run_gfpgan(
image=image,
strength=gfpgan_parameters['strength'],
@ -358,6 +489,9 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength']
all_parameters['seed'] = seed
progress['currentStatus'] = 'Saving image'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing)
command = parameters_to_command(all_parameters)
@ -365,8 +499,24 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
print(f'Image generated: "{path}"')
write_log_message(f'[Generated] "{path}": {command}')
if (progress['totalIterations'] > progress['currentIteration']):
progress['currentStep'] = 1
progress['currentIteration'] +=1
progress['currentStatus'] = 'Iteration finished'
progress['currentStatusHasSteps'] = False
else:
progress['currentStep'] = 0
progress['totalSteps'] = 0
progress['currentIteration'] = 0
progress['totalIterations'] = 0
progress['currentStatus'] = 'Finished'
progress['isProcessing'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'generation', 'metadata': all_parameters})
'generationResult', {'url': os.path.relpath(path), 'metadata': all_parameters})
eventlet.sleep(0)
try:
@ -381,7 +531,7 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
except CanceledException:
pass
except Exception as e:
socketio.emit('error', (str(e)))
socketio.emit('error', {'message': (str(e))})
print("\n")
traceback.print_exc()
print("\n")

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 451 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 453 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 463 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 435 KiB

BIN
docs/assets/step1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 503 KiB

BIN
docs/assets/step2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

BIN
docs/assets/step4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

BIN
docs/assets/step5.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 KiB

BIN
docs/assets/step6.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 395 KiB

BIN
docs/assets/step7.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1014 KiB

View File

@ -154,13 +154,19 @@ vary greatly depending on what is in the image. We also ask to --fit the image i
than 640x480. Otherwise the image size will be identical to the provided photo and you may run out
of memory if it is large.
Repeated chaining of img2img on an image can result in significant color shifts
in the output, especially if run with lower strength. Color correction can be
run against a reference image to fix this issue. Use the original input image to the
chain as the reference image for each step in the chain.
In addition to the command-line options recognized by txt2img, img2img accepts additional options:
| Argument | Shortcut | Default | Description |
| ------------------ | --------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
| --init_img <path> | -I<path> | None | Path to the initialization image |
| --fit | -F | False | Scale the image to fit into the specified -H and -W dimensions |
| --strength <float> | -s<float> | 0.75 | How hard to try to match the prompt to the initial image. Ranges from 0.0-0.99, with higher values replacing the initial image completely. |
| --init_img <path> | -I<path> | None | Path to the initialization image |
| --init_color <path> | | None | Path to reference image for color correction |
| --fit | -F | False | Scale the image to fit into the specified -H and -W dimensions |
| --strength <float> | -s<float> | 0.75 | How hard to try to match the prompt to the initial image. Ranges from 0.0-0.99, with higher values replacing the initial image completely. |
### This is an example of inpainting

View File

@ -37,5 +37,44 @@ We are hoping to get rid of the need for this workaround in an upcoming release.
5. Open the Layers toolbar (^L) and select "Floating Selection"
6. Set opacity to 0%
7. Export as PNG
8. In the export dialogue, Make sure the "Save colour values from
transparent pixels" checkbox is selected.
## Recipe for Adobe Photoshop
1. Open image in Photoshop
<p align='left'>
<img src="../assets/step1.png"/>
</p>
2. Use any of the selection tools (Marquee, Lasso, or Wand) to select the area you desire to inpaint.
<p align='left'>
<img src="../assets/step2.png"/>
</p>
3. Because we'll be applying a mask over the area we want to preserve, you should now select the inverse by using the Shift + Ctrl + I shortcut, or right clicking and using the "Select Inverse" option.
4. You'll now create a mask by selecting the image layer, and Masking the selection. Make sure that you don't delete any of the underlying image, or your inpainting results will be dramatically impacted.
<p align='left'>
<img src="../assets/step4.png"/>
</p>
5. Make sure to hide any background layers that are present. You should see the mask applied to your image layer, and the image on your canvas should display the checkered background.
<p align='left'>
<img src="../assets/step5.png"/>
</p>
<p align='left'>
<img src="../assets/step6.png"/>
</p>
6. Save the image as a transparent PNG by using the "Save a Copy" option in the File menu, or using the Alt + Ctrl + S keyboard shortcut.
7. After following the inpainting instructions above (either through the CLI or the Web UI), marvel at your newfound ability to selectively dream. Lookin' good!
<p align='left'>
<img src="../assets/step7.png"/>
</p>
8. In the export dialogue, make sure the "Save colour values from transparent pixels" checkbox is
selected.

View File

@ -28,32 +28,6 @@ dream> "pond garden with lotus by claude monet" --seamless -s100 -n4
---
## **Reading Prompts from a File**
You can automate `dream.py` by providing a text file with the prompts you want to run, one line per
prompt. The text file must be composed with a text editor (e.g. Notepad) and not a word processor.
Each line should look like what you would type at the dream> prompt:
```bash
a beautiful sunny day in the park, children playing -n4 -C10
stormy weather on a mountain top, goats grazing -s100
innovative packaging for a squid's dinner -S137038382
```
Then pass this file's name to `dream.py` when you invoke it:
```bash
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --from_file "path/to/prompts.txt"
```
You may read a series of prompts from standard input by providing a filename of `-`:
```bash
(ldm) ~/stable-diffusion$ echo "a beautiful day" | python3 scripts/dream.py --from_file -
```
---
## **Shortcuts: Reusing Seeds**
Since it is so common to reuse seeds while refining a prompt, there is now a shortcut as of version
@ -79,22 +53,6 @@ outputs/img-samples/000040.3498014304.png: "a cute child playing hopscotch" -G1.
---
## **Weighted Prompts**
You may weight different sections of the prompt to tell the sampler to attach different levels of
priority to them, by adding `:(number)` to the end of the section you wish to up- or downweight. For
example consider this prompt:
```bash
tabby cat:0.25 white duck:0.75 hybrid
```
This will tell the sampler to invest 25% of its effort on the tabby cat aspect of the image and 75%
on the white duck aspect (surprisingly, this example actually works). The prompt weights can use any
combination of integers and floating point numbers, and they do not need to add up to 1.
---
## **Simplified API**
For programmers who wish to incorporate stable-diffusion into other products, this repository

96
docs/features/PROMPTS.md Normal file
View File

@ -0,0 +1,96 @@
# Prompting Features
## **Reading Prompts from a File**
You can automate `dream.py` by providing a text file with the prompts you want to run, one line per
prompt. The text file must be composed with a text editor (e.g. Notepad) and not a word processor.
Each line should look like what you would type at the dream> prompt:
```bash
a beautiful sunny day in the park, children playing -n4 -C10
stormy weather on a mountain top, goats grazing -s100
innovative packaging for a squid's dinner -S137038382
```
Then pass this file's name to `dream.py` when you invoke it:
```bash
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --from_file "path/to/prompts.txt"
```
You may read a series of prompts from standard input by providing a filename of `-`:
```bash
(ldm) ~/stable-diffusion$ echo "a beautiful day" | python3 scripts/dream.py --from_file -
```
---
## **Weighted Prompts**
You may weight different sections of the prompt to tell the sampler to attach different levels of
priority to them, by adding `:(number)` to the end of the section you wish to up- or downweight. For
example consider this prompt:
```bash
tabby cat:0.25 white duck:0.75 hybrid
```
This will tell the sampler to invest 25% of its effort on the tabby cat aspect of the image and 75%
on the white duck aspect (surprisingly, this example actually works). The prompt weights can use any
combination of integers and floating point numbers, and they do not need to add up to 1.
---
## **Negative and Unconditioned Prompts**
Stable Diffusion's model will try to ignore any words placed between a pair of square brackets when generating images.
```bash
this is a test prompt [not really] to make you understand [cool] how this works.
```
In the above statement, the words `not really cool` will be ignored by Stable Diffusion.
Here's a prompt that depicts what it does.
original prompt:
```bash
"A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180
```
![step1](../assets/variation_walkthru/step1.png)
That image has a woman, so if we want the horse without a rider, we can influence the image not to have a woman by putting [woman] in the prompt, like this:
```bash
"A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve [woman]" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180
```
![step2](../assets/variation_walkthru/step2.png)
That's nice - but say we also don't want the image to be quite so blue. We can add "blue" to the list of negative prompts, so it's now [woman blue]:
```bash
"A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve [woman blue]" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180
```
![step3](../assets/variation_walkthru/step3.png)
Getting close - but there's no sense in having a saddle when our horse doesn't have a rider, so we'll add one more negative prompt: [woman blue saddle].
```bash
"A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve [woman blue saddle]" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180
```
![step4](../assets/variation_walkthru/step4.png)
Notes about this feature:
* The only requirement for words to be ignored is that they are in between a pair of square brackets.
* You can provide multiple words within the same bracket.
* You can provide multiple brackets with multiple words in different places of your prompt. That works just fine.
* To improve typical anatomy problems, you can add negative prompts like [bad anatomy, extra legs, extra arms, extra fingers, poorly drawn hands, poorly drawn feet, disfigured, out of frame, tiling, bad art, deformed, mutated].

View File

@ -97,3 +97,39 @@ the base images.
If you wish to stop during the image generation but want to upscale or face restore a particular
generated image, pass it again with the same prompt and generated seed along with the `-U` and `-G`
prompt arguments to perform those actions.
## CodeFormer Support
This repo also allows you to perform face restoration using
[CodeFormer](https://github.com/sczhou/CodeFormer).
To set up CodeFormer, you need to download its models, just as with GFPGAN. You can do
this either by running `preload_models.py` or by manually downloading the
[model file](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth) and
saving it to `ldm/restoration/codeformer/weights` folder.
You can use `-ft` prompt argument to swap between CodeFormer and the default GFPGAN. The above
mentioned `-G` prompt argument will allow you to control the strength of the restoration effect.
### **Usage:**
The following command will perform face restoration with CodeFormer instead of the default gfpgan.
`<prompt> -G 0.8 -ft codeformer`
**Other Options:**
- `-cf` - cf or CodeFormer Fidelity takes values between `0` and `1`. 0 produces high quality
results but low accuracy and 1 produces lower quality results but higher accuracy to your original
face.
The following command will perform face restoration with CodeFormer. CodeFormer will output a result
that closely matches the input face.
`<prompt> -G 1.0 -ft codeformer -cf 0.9`
The following command will perform face restoration with CodeFormer. CodeFormer will output a result
that is the best restoration possible. This may deviate slightly from the original face. This is an
excellent option to use in situations when there is very little facial data to work with.
`<prompt> -G 1.0 -ft codeformer -cf 0.1`

View File

@ -102,6 +102,7 @@ generate more variations around the almost-but-not-quite image. We do the
latter, using both the `-V` (combining) and `-v` (variation strength) options.
Note that we use `-n6` to generate 6 variations:
```bash
dream> "prompt" -S3357757885 -V3647897225,0.1,1614299449,0.1 -v0.05 -n6
Outputs:
./outputs/Xena/000004.3279757577.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.1,1614299449:0.1,3279757577:0.05 -S3357757885

View File

@ -7,10 +7,7 @@ title: macOS
- macOS 12.3 Monterey or later
- Python
- Patience
- Apple Silicon\*
\*I haven't tested any of this on Intel Macs but I have read that one person got
it to work, so Apple Silicon might not be requried.
- Apple Silicon or Intel Mac
Things have moved really fast and so these instructions change often and are
often out-of-date. One of the problems is that there are so many different ways
@ -59,9 +56,13 @@ First get the weights checkpoint download started - it's big:
# install python 3, git, cmake, protobuf:
brew install cmake protobuf rust
# install miniconda (M1 arm64 version):
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o Miniconda3-latest-MacOSX-arm64.sh
/bin/bash Miniconda3-latest-MacOSX-arm64.sh
# install miniconda for M1 arm64:
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o Miniconda3-latest-MacOSX-arm64.sh
/bin/bash Miniconda3-latest-MacOSX-arm64.sh
# OR install miniconda for Intel:
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o Miniconda3-latest-MacOSX-x86_64.sh
/bin/bash Miniconda3-latest-MacOSX-x86_64.sh
# EITHER WAY,
@ -82,15 +83,22 @@ brew install cmake protobuf rust
ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" models/ldm/stable-diffusion-v1/model.ckpt
# install packages
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
conda activate ldm
# install packages for arm64
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
conda activate ldm
# OR install packages for x86_64
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-x86_64 conda env create -f environment-mac.yaml
conda activate ldm
# only need to do this once
python scripts/preload_models.py
# run SD!
python scripts/dream.py --full_precision # half-precision requires autocast and won't work
# or run the web interface!
python scripts/dream.py --web
```
The original scripts should work as well.
@ -181,7 +189,12 @@ There are several causes of these errors.
- Third, if it says you're missing taming you need to rebuild your virtual
environment.
`conda env remove -n ldm conda env create -f environment-mac.yaml`
```bash
conda deactivate
conda env remove -n ldm
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
```
Fourth, if you have activated the ldm virtual environment and tried rebuilding
it, maybe the problem could be that I have something installed that you don't

View File

@ -2,15 +2,16 @@
title: Contributors
---
The list of all the amazing people who have contributed to the various features that you get to experience in this fork.
The list of all the amazing people who have contributed to the various features that you get to
experience in this fork.
We thank them for all of their time and hard work.
## __Original Author:__
## **Original Author:**
- [Lincoln D. Stein](mailto:lincoln.stein@gmail.com)
## __Contributions by:__
## **Contributions by:**
- [Sean McLellan](https://github.com/Oceanswave)
- [Kevin Gibbons](https://github.com/bakkot)
@ -52,8 +53,9 @@ We thank them for all of their time and hard work.
- [Doggettx](https://github.com/doggettx)
- [Matthias Wild](https://github.com/mauwii)
- [Kyle Schouviller](https://github.com/kyle0654)
- [rabidcopy](https://github.com/rabidcopy)
## __Original CompVis Authors:__
## **Original CompVis Authors:**
- [Robin Rombach](https://github.com/rromb)
- [Patrick von Platen](https://github.com/patrickvonplaten)
@ -65,4 +67,5 @@ We thank them for all of their time and hard work.
---
_If you have contributed and don't see your name on the list of contributors, please let one of the collaborators know about the omission, or feel free to make a pull request._
_If you have contributed and don't see your name on the list of contributors, please let one of the
collaborators know about the omission, or feel free to make a pull request._

View File

@ -48,6 +48,7 @@ dependencies:
- opencv-python==4.6.0
- protobuf==3.20.1
- realesrgan==0.2.5.0
- send2trash==1.8.0
- test-tube==0.7.5
- transformers==4.21.2
- torch-fidelity==0.3.0

View File

@ -20,7 +20,8 @@ dependencies:
- realesrgan==0.2.5.0
- test-tube>=0.7.5
- streamlit==1.12.0
- pillow==6.2.0
- send2trash==1.8.0
- pillow==9.2.0
- einops==0.3.0
- torch-fidelity==0.3.0
- transformers==4.19.2

View File

@ -1,85 +1,37 @@
# Stable Diffusion Web UI
Demo at https://peaceful-otter-7a427f.netlify.app/ (not connected to back end)
## Run
much of this readme is just notes for myself during dev work
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
numpy rand: 0 to 4294967295
## Environment
## Test and Build
Install [node](https://nodejs.org/en/download/) (includes npm) and optionally
[yarn](https://yarnpkg.com/getting-started/install).
from `frontend/`:
From `frontend/` run `npm install` / `yarn install` to install the frontend packages.
- `yarn dev` runs `tsc-watch`, which runs `vite build` on successful `tsc` transpilation
## Dev
from `.`:
1. From `frontend/`, run `npm run dev` / `yarn dev` to start the dev server.
2. Note the address it starts up on (probably `http://localhost:5173/`).
3. Edit `backend/server.py`'s `additional_allowed_origins` to include this address, e.g.
`additional_allowed_origins = ['http://localhost:5173']`.
4. Leaving the dev server running, open a new terminal and go to the project root.
5. Run `python backend/server.py`.
6. Navigate to the dev server address e.g. `http://localhost:5173/`.
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
To build for dev: `npm run build-dev` / `yarn build-dev`
## API
`backend/server.py` serves the UI and provides a [socket.io](https://github.com/socketio/socket.io) API via [flask-socketio](https://github.com/miguelgrinberg/flask-socketio).
### Server Listeners
The server listens for these socket.io events:
`cancel`
- Cancels in-progress image generation
- Returns ack only
`generateImage`
- Accepts object of image parameters
- Generates an image
- Returns ack only (image generation function sends progress and result via separate events)
`deleteImage`
- Accepts file path to image
- Deletes image
- Returns ack only
`deleteAllImages` WIP
- Deletes all images in `outputs/`
- Returns ack only
`requestAllImages`
- Returns array of all images in `outputs/`
`requestCapabilities` WIP
- Returns capabilities of server (torch device, GFPGAN and ESRGAN availability, ???)
`sendImage` WIP
- Accepts a File and attributes
- Saves image
- Used to save init images which are not generated images
### Server Emitters
`progress`
- Emitted during each step in generation
- Sends a number from 0 to 1 representing percentage of steps completed
`result` WIP
- Emitted when an image generation has completed
- Sends a object:
```
{
url: relative_file_path,
metadata: image_metadata_object
}
```
To build for production: `npm run build` / `yarn build`
## TODO
- Search repo for "TODO"
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on `framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on this. Need to check in on this issue periodically.
- Search repo for "TODO"
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on
`framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on
the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on
this. Need to check in on this issue periodically.
- Mobile friendly layout
- Proper image gallery/viewer/manager
- Help tooltips and such

File diff suppressed because one or more lines are too long

694
frontend/dist/assets/index.de730902.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -4,7 +4,7 @@
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.cc5cde43.js"></script>
<script type="module" crossorigin src="/assets/index.de730902.js"></script>
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
</head>
<body>

View File

@ -4,8 +4,7 @@
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "tsc-watch --onSuccess 'yarn run vite build -m development'",
"hmr": "vite dev",
"dev": "vite dev",
"build": "tsc && vite build",
"build-dev": "tsc && vite build -m development",
"preview": "vite preview"

View File

@ -1,60 +0,0 @@
import { Grid, GridItem } from '@chakra-ui/react';
import CurrentImage from './features/gallery/CurrentImage';
import LogViewer from './features/system/LogViewer';
import PromptInput from './features/sd/PromptInput';
import ProgressBar from './features/header/ProgressBar';
import { useEffect } from 'react';
import { useAppDispatch } from './app/hooks';
import { requestAllImages } from './app/socketio';
import ProcessButtons from './features/sd/ProcessButtons';
import ImageRoll from './features/gallery/ImageRoll';
import SiteHeader from './features/header/SiteHeader';
import OptionsAccordion from './features/sd/OptionsAccordion';
/**
 * Root component of the web UI.
 *
 * Dispatches `requestAllImages()` from an effect and renders a full-viewport
 * grid with seven named areas (header, progressBar, menu, prompt,
 * processButtons, currentImage, imageRoll), followed by the log viewer.
 */
const App = () => {
const dispatch = useAppDispatch();
// Request the gallery contents. The effect depends only on `dispatch`, so it
// re-runs only if that reference changes.
useEffect(() => {
dispatch(requestAllImages());
}, [dispatch]);
// Layout: 4 columns x 4 rows. `menu` and `imageRoll` span the two bottom rows
// and scroll vertically; `currentImage` spans the two middle columns.
return (
<>
<Grid
width='100vw'
height='100vh'
templateAreas={`
"header header header header"
"progressBar progressBar progressBar progressBar"
"menu prompt processButtons imageRoll"
"menu currentImage currentImage imageRoll"`}
gridTemplateRows={'36px 10px 100px auto'}
gridTemplateColumns={'350px auto 100px 388px'}
gap={2}
>
<GridItem area={'header'} pt={1}>
<SiteHeader />
</GridItem>
<GridItem area={'progressBar'}>
<ProgressBar />
</GridItem>
<GridItem pl='2' area={'menu'} overflowY='scroll'>
<OptionsAccordion />
</GridItem>
<GridItem area={'prompt'}>
<PromptInput />
</GridItem>
<GridItem area={'processButtons'}>
<ProcessButtons />
</GridItem>
<GridItem area={'currentImage'}>
<CurrentImage />
</GridItem>
<GridItem pr='2' area={'imageRoll'} overflowY='scroll'>
<ImageRoll />
</GridItem>
</Grid>
<LogViewer />
</>
);
};
export default App;

68
frontend/src/app/App.tsx Normal file
View File

@ -0,0 +1,68 @@
import { Grid, GridItem } from '@chakra-ui/react';
import { useEffect, useState } from 'react';
import CurrentImageDisplay from '../features/gallery/CurrentImageDisplay';
import ImageGallery from '../features/gallery/ImageGallery';
import ProgressBar from '../features/header/ProgressBar';
import SiteHeader from '../features/header/SiteHeader';
import OptionsAccordion from '../features/sd/OptionsAccordion';
import ProcessButtons from '../features/sd/ProcessButtons';
import PromptInput from '../features/sd/PromptInput';
import LogViewer from '../features/system/LogViewer';
import Loading from '../Loading';
import { useAppDispatch } from './store';
import { requestAllImages } from './socketio/actions';
/**
 * Root component of the web UI.
 *
 * Renders `<Loading />` until the effect below has run, then renders the
 * full-viewport grid (header, progressBar, menu, prompt, processButtons,
 * currentImage, imageRoll) followed by the log viewer.
 */
const App = () => {
const dispatch = useAppDispatch();
// Starts false so the first render shows the loading placeholder.
const [isReady, setIsReady] = useState<boolean>(false);
// Load images from the gallery once
// NOTE(review): `isReady` is set right after the dispatch call, not in
// response to any reply, so it does not indicate that images have arrived.
useEffect(() => {
dispatch(requestAllImages());
setIsReady(true);
}, [dispatch]);
// Layout: 4 columns x 4 rows. `menu` and `imageRoll` span the two bottom rows
// and scroll vertically; `currentImage` spans the two middle columns.
return isReady ? (
<>
<Grid
width="100vw"
height="100vh"
templateAreas={`
"header header header header"
"progressBar progressBar progressBar progressBar"
"menu prompt processButtons imageRoll"
"menu currentImage currentImage imageRoll"`}
gridTemplateRows={'36px 10px 100px auto'}
gridTemplateColumns={'350px auto 100px 388px'}
gap={2}
>
<GridItem area={'header'} pt={1}>
<SiteHeader />
</GridItem>
<GridItem area={'progressBar'}>
<ProgressBar />
</GridItem>
<GridItem pl="2" area={'menu'} overflowY="scroll">
<OptionsAccordion />
</GridItem>
<GridItem area={'prompt'}>
<PromptInput />
</GridItem>
<GridItem area={'processButtons'}>
<ProcessButtons />
</GridItem>
<GridItem area={'currentImage'}>
<CurrentImageDisplay />
</GridItem>
<GridItem pr="2" area={'imageRoll'} overflowY="scroll">
<ImageGallery />
</GridItem>
</Grid>
<LogViewer />
</>
) : (
<Loading />
);
};
export default App;

View File

@ -2,52 +2,52 @@
// Valid samplers
export const SAMPLERS: Array<string> = [
'ddim',
'plms',
'k_lms',
'k_dpm_2',
'k_dpm_2_a',
'k_euler',
'k_euler_a',
'k_heun',
'ddim',
'plms',
'k_lms',
'k_dpm_2',
'k_dpm_2_a',
'k_euler',
'k_euler_a',
'k_heun',
];
// Valid image widths
export const WIDTHS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
];
// Valid image heights
export const HEIGHTS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
];
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
];
// Internal to human-readable parameters
export const PARAMETERS: { [key: string]: string } = {
prompt: 'Prompt',
iterations: 'Iterations',
steps: 'Steps',
cfgScale: 'CFG Scale',
height: 'Height',
width: 'Width',
sampler: 'Sampler',
seed: 'Seed',
img2imgStrength: 'img2img Strength',
gfpganStrength: 'GFPGAN Strength',
upscalingLevel: 'Upscaling Level',
upscalingStrength: 'Upscaling Strength',
initialImagePath: 'Initial Image',
maskPath: 'Initial Image Mask',
shouldFitToWidthHeight: 'Fit Initial Image',
seamless: 'Seamless Tiling',
prompt: 'Prompt',
iterations: 'Iterations',
steps: 'Steps',
cfgScale: 'CFG Scale',
height: 'Height',
width: 'Width',
sampler: 'Sampler',
seed: 'Seed',
img2imgStrength: 'img2img Strength',
gfpganStrength: 'GFPGAN Strength',
upscalingLevel: 'Upscaling Level',
upscalingStrength: 'Upscaling Strength',
initialImagePath: 'Initial Image',
maskPath: 'Initial Image Mask',
shouldFitToWidthHeight: 'Fit Initial Image',
seamless: 'Seamless Tiling',
};
export const NUMPY_RAND_MIN = 0;

View File

@ -1,7 +0,0 @@
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import type { RootState, AppDispatch } from './store';
// Typed aliases of the react-redux hooks. Use these throughout your app
// instead of plain `useDispatch` and `useSelector` so that call sites are
// typed with the store's `AppDispatch` and `RootState` (both imported from
// './store').
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -1,393 +0,0 @@
import { createAction, Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import {
addImage,
clearIntermediateImage,
removeImage,
SDImage,
SDMetadata,
setGalleryImages,
setIntermediateImage,
} from '../features/gallery/gallerySlice';
import {
addLogEntry,
setCurrentStep,
setIsConnected,
setIsProcessing,
} from '../features/system/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { setInitialImagePath, setMaskPath } from '../features/sd/sdSlice';
import {
backendToFrontendParameters,
frontendToBackendParameters,
} from './parameterTranslation';
/**
 * Shape of the acknowledgement object received by the callbacks passed to
 * `socketio.emit(...)` in this file (e.g. for 'deleteImage',
 * 'requestAllImages' and 'cancel').
 */
export interface SocketIOResponse {
// Outcome of the request; the handlers in this file act only when it is 'OK'.
status: 'OK' | 'ERROR';
// Optional text accompanying the status (presumably an error description;
// it is not read anywhere in the visible part of this file).
message?: string;
// Optional payload. For 'requestAllImages' it is handed to `setGalleryImages`
// and its `.length` is logged, so there it is expected to be array-like.
data?: any;
}
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const { dispatch, getState } = store;
if (!areListenersSet) {
// CONNECT
socketio.on('connect', () => {
try {
dispatch(setIsConnected(true));
} catch (e) {
console.error(e);
}
});
// DISCONNECT
socketio.on('disconnect', () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(addLogEntry(`Disconnected from server`));
} catch (e) {
console.error(e);
}
});
// PROCESSING RESULT
socketio.on(
'result',
(data: {
url: string;
type: 'generation' | 'esrgan' | 'gfpgan';
uuid?: string;
metadata: { [key: string]: any };
}) => {
try {
const newUuid = uuidv4();
const { type, url, uuid, metadata } = data;
switch (type) {
case 'generation': {
const translatedMetadata =
backendToFrontendParameters(metadata);
dispatch(
addImage({
uuid: newUuid,
url,
metadata: translatedMetadata,
})
);
dispatch(
addLogEntry(`Image generated: ${url}`)
);
break;
}
case 'esrgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunESRGAN = true;
newMetadata.upscalingLevel =
metadata.upscale[0];
newMetadata.upscalingStrength =
metadata.upscale[1];
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`ESRGAN upscaled: ${url}`)
);
break;
}
case 'gfpgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunGFPGAN = true;
newMetadata.gfpganStrength =
metadata.gfpgan_strength;
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`GFPGAN fixed faces: ${url}`)
);
break;
}
}
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
}
);
// PROGRESS UPDATE
socketio.on('progress', (data: { step: number }) => {
try {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(data.step));
} catch (e) {
console.error(e);
}
});
// INTERMEDIATE IMAGE
socketio.on(
'intermediateResult',
(data: { url: string; metadata: SDMetadata }) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry(`Intermediate image generated: ${url}`)
);
} catch (e) {
console.error(e);
}
}
);
// ERROR FROM BACKEND
socketio.on('error', (message) => {
try {
dispatch(addLogEntry(`Server error: ${message}`));
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
});
areListenersSet = true;
}
// HANDLE ACTIONS
switch (action.type) {
// GENERATE IMAGE
case 'socketio/generateImage': {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const {
generationParameters,
esrganParameters,
gfpganParameters,
} = frontendToBackendParameters(
getState().sd,
getState().system
);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry(
`Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`
)
);
break;
}
// RUN ESRGAN (UPSCALING)
case 'socketio/runESRGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { upscalingLevel, upscalingStrength } = getState().sd;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry(
`ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`
)
);
break;
}
// RUN GFPGAN (FIX FACES)
case 'socketio/runGFPGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { gfpganStrength } = getState().sd;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry(
`GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`
)
);
break;
}
// DELETE IMAGE
case 'socketio/deleteImage': {
const imageToDelete = action.payload;
const { url } = imageToDelete;
socketio.emit(
'deleteImage',
url,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(removeImage(imageToDelete));
dispatch(addLogEntry(`Image deleted: ${url}`));
}
}
);
break;
}
// GET ALL IMAGES FOR GALLERY
case 'socketio/requestAllImages': {
socketio.emit(
'requestAllImages',
(response: SocketIOResponse) => {
dispatch(setGalleryImages(response.data));
dispatch(
addLogEntry(`Loaded ${response.data.length} images`)
);
}
);
break;
}
// CANCEL PROCESSING
case 'socketio/cancelProcessing': {
socketio.emit('cancel', (response: SocketIOResponse) => {
const { intermediateImage } = getState().gallery;
if (response.status === 'OK') {
dispatch(setIsProcessing(false));
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry(
`Intermediate image saved: ${intermediateImage.url}`
)
);
dispatch(clearIntermediateImage());
}
dispatch(addLogEntry(`Processing canceled`));
}
});
break;
}
// UPLOAD INITIAL IMAGE
case 'socketio/uploadInitialImage': {
const file = action.payload;
socketio.emit(
'uploadInitialImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setInitialImagePath(response.data));
dispatch(
addLogEntry(
`Initial image uploaded: ${response.data}`
)
);
}
}
);
break;
}
// UPLOAD MASK IMAGE
case 'socketio/uploadMaskImage': {
const file = action.payload;
socketio.emit(
'uploadMaskImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setMaskPath(response.data));
dispatch(
addLogEntry(
`Mask image uploaded: ${response.data}`
)
);
}
}
);
break;
}
}
next(action);
};
return middleware;
};
// Actions to be used by app
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');

View File

@ -0,0 +1,24 @@
import { createAction } from '@reduxjs/toolkit';
import { SDImage } from '../../features/gallery/gallerySlice';
/**
* We can't use redux-toolkit's createSlice() to make these actions,
* because they have no associated reducer. They only exist to dispatch
* requests to the server via socketio. These actions will be handled
* by the middleware.
*/
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');

View File

@ -0,0 +1,101 @@
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { Socket } from 'socket.io-client';
import { frontendToBackendParameters } from '../../common/util/parameterTranslation';
import { SDImage } from '../../features/gallery/gallerySlice';
import {
addLogEntry,
setIsProcessing,
} from '../../features/system/systemSlice';
/**
 * Builds the functions that send requests to the server.
 *
 * Each returned function wraps a single `socketio.emit()` call; the three
 * processing requests additionally flag the app as processing and write
 * a log entry describing what was requested.
 */
const makeSocketIOEmitters = (
  store: MiddlewareAPI<Dispatch<AnyAction>, any>,
  socketio: Socket
) => {
  // Emitters both update redux state and read parameters out of it.
  const { dispatch, getState } = store;

  // Appends a timestamped message to the system log.
  const logRequest = (message: string) => {
    dispatch(
      addLogEntry({
        timestamp: dateFormat(new Date(), 'isoDateTime'),
        message,
      })
    );
  };

  // Request a new generation using the current `sd` and `system` state.
  const emitGenerateImage = () => {
    dispatch(setIsProcessing(true));

    const currentState = getState();
    const { generationParameters, esrganParameters, gfpganParameters } =
      frontendToBackendParameters(currentState.sd, currentState.system);

    socketio.emit(
      'generateImage',
      generationParameters,
      esrganParameters,
      gfpganParameters
    );

    const requestedParameters = {
      ...generationParameters,
      ...esrganParameters,
      ...gfpganParameters,
    };
    logRequest(
      `Image generation requested: ${JSON.stringify(requestedParameters)}`
    );
  };

  // Request ESRGAN upscaling of an existing image.
  const emitRunESRGAN = (imageToProcess: SDImage) => {
    dispatch(setIsProcessing(true));

    const sdState = getState().sd;
    const esrganParameters = {
      upscale: [sdState.upscalingLevel, sdState.upscalingStrength],
    };

    socketio.emit('runESRGAN', imageToProcess, esrganParameters);

    const requestSummary = {
      file: imageToProcess.url,
      ...esrganParameters,
    };
    logRequest(`ESRGAN upscale requested: ${JSON.stringify(requestSummary)}`);
  };

  // Request GFPGAN face fixing of an existing image.
  const emitRunGFPGAN = (imageToProcess: SDImage) => {
    dispatch(setIsProcessing(true));

    const gfpganParameters = {
      gfpgan_strength: getState().sd.gfpganStrength,
    };

    socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);

    const requestSummary = {
      file: imageToProcess.url,
      ...gfpganParameters,
    };
    logRequest(
      `GFPGAN fix faces requested: ${JSON.stringify(requestSummary)}`
    );
  };

  // The remaining emitters are plain pass-throughs to the socket.
  const emitDeleteImage = (imageToDelete: SDImage) => {
    socketio.emit('deleteImage', imageToDelete.url, imageToDelete.uuid);
  };

  const emitRequestAllImages = () => {
    socketio.emit('requestAllImages');
  };

  const emitCancelProcessing = () => {
    socketio.emit('cancel');
  };

  const emitUploadInitialImage = (file: File) => {
    socketio.emit('uploadInitialImage', file, file.name);
  };

  const emitUploadMaskImage = (file: File) => {
    socketio.emit('uploadMaskImage', file, file.name);
  };

  return {
    emitGenerateImage,
    emitRunESRGAN,
    emitRunGFPGAN,
    emitDeleteImage,
    emitRequestAllImages,
    emitCancelProcessing,
    emitUploadInitialImage,
    emitUploadMaskImage,
  };
};
export default makeSocketIOEmitters;

View File

@ -0,0 +1,338 @@
import { AnyAction, MiddlewareAPI, Dispatch } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import dateFormat from 'dateformat';
import {
addLogEntry,
setIsConnected,
setIsProcessing,
SystemStatus,
setSystemStatus,
setCurrentStatus,
} from '../../features/system/systemSlice';
import type {
ServerGenerationResult,
ServerESRGANResult,
ServerGFPGANResult,
ServerIntermediateResult,
ServerError,
ServerGalleryImages,
ServerImageUrlAndUuid,
ServerImageUrl,
} from './types';
import { backendToFrontendParameters } from '../../common/util/parameterTranslation';
import {
addImage,
clearIntermediateImage,
removeImage,
SDImage,
setGalleryImages,
setIntermediateImage,
} from '../../features/gallery/gallerySlice';
import { setInitialImagePath, setMaskPath } from '../../features/sd/sdSlice';
/**
 * Returns an object containing listener callbacks for socketio events.
 *
 * Each callback translates one server event into redux actions (gallery
 * updates, system status changes and log entries).
 *
 * TODO: This file is large, but simple. Should it be split up further?
 */
const makeSocketIOListeners = (
  store: MiddlewareAPI<Dispatch<AnyAction>, any>
) => {
  const { dispatch, getState } = store;

  /**
   * Appends a timestamped entry to the system log. `level` is only included
   * in the entry when one is given, so default-level entries keep the same
   * shape as before.
   */
  const log = (message: string, level?: 'warning' | 'error') => {
    dispatch(
      addLogEntry({
        timestamp: dateFormat(new Date(), 'isoDateTime'),
        message,
        ...(level ? { level } : {}),
      })
    );
  };

  /**
   * Returns a copy of the metadata of the gallery image with the given uuid.
   * If the image is no longer in the gallery (e.g. it was deleted while it
   * was being processed), an empty object is returned instead of throwing,
   * so the processed result can still be added.
   */
  const copyOriginalMetadata = (uuid: string) => {
    const originalImage = getState().gallery.images.find(
      (i: SDImage) => i.uuid === uuid
    );
    return { ...(originalImage ? originalImage.metadata : {}) };
  };

  return {
    /**
     * Callback to run when we receive a 'connect' event.
     */
    onConnect: () => {
      try {
        dispatch(setIsConnected(true));
        dispatch(setCurrentStatus('Connected'));
      } catch (e) {
        console.error(e);
      }
    },
    /**
     * Callback to run when we receive a 'disconnect' event.
     */
    onDisconnect: () => {
      try {
        dispatch(setIsConnected(false));
        dispatch(setIsProcessing(false));
        dispatch(setCurrentStatus('Disconnected'));
        log(`Disconnected from server`, 'warning');
      } catch (e) {
        console.error(e);
      }
    },
    /**
     * Callback to run when we receive a 'generationResult' event.
     */
    onGenerationResult: (data: ServerGenerationResult) => {
      try {
        const { url, metadata } = data;
        const newUuid = uuidv4();
        const translatedMetadata = backendToFrontendParameters(metadata);
        dispatch(
          addImage({
            uuid: newUuid,
            url,
            metadata: translatedMetadata,
          })
        );
        log(`Image generated: ${url}`);
        dispatch(setIsProcessing(false));
      } catch (e) {
        console.error(e);
      }
    },
    /**
     * Callback to run when we receive a 'intermediateResult' event.
     *
     * An intermediate image arrives while generation is still running, so
     * this handler deliberately does NOT clear the processing flag; that is
     * left to the final result, error, cancel and disconnect handlers.
     */
    onIntermediateResult: (data: ServerIntermediateResult) => {
      try {
        const uuid = uuidv4();
        const { url, metadata } = data;
        dispatch(
          setIntermediateImage({
            uuid,
            url,
            metadata,
          })
        );
        log(`Intermediate image generated: ${url}`);
      } catch (e) {
        console.error(e);
      }
    },
    /**
     * Callback to run when we receive an 'esrganResult' event.
     */
    onESRGANResult: (data: ServerESRGANResult) => {
      try {
        const { url, uuid, metadata } = data;
        const newUuid = uuidv4();
        // This image was only ESRGAN'd: retain the original image's metadata
        const newMetadata = copyOriginalMetadata(uuid);
        // Update the ESRGAN-related fields
        newMetadata.shouldRunESRGAN = true;
        newMetadata.upscalingLevel = metadata.upscale[0];
        newMetadata.upscalingStrength = metadata.upscale[1];
        dispatch(
          addImage({
            uuid: newUuid,
            url,
            metadata: newMetadata,
          })
        );
        log(`Upscaled: ${url}`);
        dispatch(setIsProcessing(false));
      } catch (e) {
        console.error(e);
      }
    },
    /**
     * Callback to run when we receive a 'gfpganResult' event.
     */
    onGFPGANResult: (data: ServerGFPGANResult) => {
      try {
        const { url, uuid, metadata } = data;
        const newUuid = uuidv4();
        // This image was only GFPGAN'd: retain the original image's metadata
        const newMetadata = copyOriginalMetadata(uuid);
        // Update the GFPGAN-related fields
        newMetadata.shouldRunGFPGAN = true;
        newMetadata.gfpganStrength = metadata.gfpgan_strength;
        dispatch(
          addImage({
            uuid: newUuid,
            url,
            metadata: newMetadata,
          })
        );
        log(`Fixed faces: ${url}`);
        // Processing is finished once the result arrives, same as for ESRGAN.
        dispatch(setIsProcessing(false));
      } catch (e) {
        console.error(e);
      }
    },
    /**
     * Callback to run when we receive a 'progressUpdate' event.
     * TODO: Add additional progress phases
     */
    onProgressUpdate: (data: SystemStatus) => {
      try {
        dispatch(setIsProcessing(true));
        dispatch(setSystemStatus(data));
      } catch (e) {
        console.error(e);
      }
    },
    /**
     * Callback to run when we receive an 'error' event.
     */
    onError: (data: ServerError) => {
      // TODO: handle more data than short message (`data.additionalData`)
      const { message } = data;
      try {
        log(`Server error: ${message}`, 'error');
        dispatch(setIsProcessing(false));
        dispatch(clearIntermediateImage());
      } catch (e) {
        console.error(e);
      }
    },
    /**
     * Callback to run when we receive a 'galleryImages' event.
     */
    onGalleryImages: (data: ServerGalleryImages) => {
      const { images } = data;
      // Each image gets a fresh client-side uuid and translated metadata.
      const preparedImages = images.map((image): SDImage => {
        return {
          uuid: uuidv4(),
          url: image.path,
          metadata: backendToFrontendParameters(image.metadata),
        };
      });
      dispatch(setGalleryImages(preparedImages));
      log(`Loaded ${images.length} images`);
    },
    /**
     * Callback to run when we receive a 'processingCanceled' event.
     */
    onProcessingCanceled: () => {
      dispatch(setIsProcessing(false));
      const { intermediateImage } = getState().gallery;
      // Keep whatever intermediate image we had by moving it into the gallery.
      if (intermediateImage) {
        dispatch(addImage(intermediateImage));
        log(`Intermediate image saved: ${intermediateImage.url}`);
        dispatch(clearIntermediateImage());
      }
      log(`Processing canceled`, 'warning');
    },
    /**
     * Callback to run when we receive a 'imageDeleted' event.
     */
    onImageDeleted: (data: ServerImageUrlAndUuid) => {
      const { url, uuid } = data;
      dispatch(removeImage(uuid));
      log(`Image deleted: ${url}`);
    },
    /**
     * Callback to run when we receive a 'initialImageUploaded' event.
     */
    onInitialImageUploaded: (data: ServerImageUrl) => {
      const { url } = data;
      dispatch(setInitialImagePath(url));
      log(`Initial image uploaded: ${url}`);
    },
    /**
     * Callback to run when we receive a 'maskImageUploaded' event.
     */
    onMaskImageUploaded: (data: ServerImageUrl) => {
      const { url } = data;
      dispatch(setMaskPath(url));
      log(`Mask image uploaded: ${url}`);
    },
  };
};
export default makeSocketIOListeners;

View File

@ -0,0 +1,157 @@
import { Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import makeSocketIOListeners from './listeners';
import makeSocketIOEmitters from './emitters';
import type {
ServerGenerationResult,
ServerESRGANResult,
ServerGFPGANResult,
ServerIntermediateResult,
ServerError,
ServerGalleryImages,
ServerImageUrlAndUuid,
ServerImageUrl,
} from './types';
import { SystemStatus } from '../../features/system/systemSlice';
/**
 * Creates the redux middleware that bridges the store and the socket.io
 * connection to the backend.
 *
 * - Incoming socket events are translated to redux actions by the listeners.
 * - `socketio/*` redux actions are translated to socket requests by the
 *   emitters.
 *
 * The socket connects to port 9090 on the host the page was loaded from.
 */
export const socketioMiddleware = () => {
  const { hostname } = new URL(window.location.href);
  const socketio = io(`http://${hostname}:9090`);
  let areListenersSet = false;

  const middleware: Middleware = (store) => {
    // Redux calls this outer function once per store, so the emitters are
    // built a single time instead of on every dispatched action.
    const {
      emitGenerateImage,
      emitRunESRGAN,
      emitRunGFPGAN,
      emitDeleteImage,
      emitRequestAllImages,
      emitCancelProcessing,
      emitUploadInitialImage,
      emitUploadMaskImage,
    } = makeSocketIOEmitters(store, socketio);

    return (next) => (action) => {
      /**
       * If this is the first time the middleware has been called (e.g. during store setup),
       * initialize all our socket.io listeners.
       */
      if (!areListenersSet) {
        const {
          onConnect,
          onDisconnect,
          onError,
          onESRGANResult,
          onGFPGANResult,
          onGenerationResult,
          onIntermediateResult,
          onProgressUpdate,
          onGalleryImages,
          onProcessingCanceled,
          onImageDeleted,
          onInitialImageUploaded,
          onMaskImageUploaded,
        } = makeSocketIOListeners(store);

        socketio.on('connect', () => onConnect());
        socketio.on('disconnect', () => onDisconnect());
        socketio.on('error', (data: ServerError) => onError(data));
        socketio.on('generationResult', (data: ServerGenerationResult) =>
          onGenerationResult(data)
        );
        socketio.on('esrganResult', (data: ServerESRGANResult) =>
          onESRGANResult(data)
        );
        socketio.on('gfpganResult', (data: ServerGFPGANResult) =>
          onGFPGANResult(data)
        );
        socketio.on('intermediateResult', (data: ServerIntermediateResult) =>
          onIntermediateResult(data)
        );
        socketio.on('progressUpdate', (data: SystemStatus) =>
          onProgressUpdate(data)
        );
        socketio.on('galleryImages', (data: ServerGalleryImages) =>
          onGalleryImages(data)
        );
        socketio.on('processingCanceled', () => {
          onProcessingCanceled();
        });
        socketio.on('imageDeleted', (data: ServerImageUrlAndUuid) => {
          onImageDeleted(data);
        });
        socketio.on('initialImageUploaded', (data: ServerImageUrl) => {
          onInitialImageUploaded(data);
        });
        socketio.on('maskImageUploaded', (data: ServerImageUrl) => {
          onMaskImageUploaded(data);
        });

        areListenersSet = true;
      }

      /**
       * Handle redux actions caught by middleware.
       */
      switch (action.type) {
        case 'socketio/generateImage': {
          emitGenerateImage();
          break;
        }
        case 'socketio/runESRGAN': {
          emitRunESRGAN(action.payload);
          break;
        }
        case 'socketio/runGFPGAN': {
          emitRunGFPGAN(action.payload);
          break;
        }
        case 'socketio/deleteImage': {
          emitDeleteImage(action.payload);
          break;
        }
        case 'socketio/requestAllImages': {
          emitRequestAllImages();
          break;
        }
        case 'socketio/cancelProcessing': {
          emitCancelProcessing();
          break;
        }
        case 'socketio/uploadInitialImage': {
          emitUploadInitialImage(action.payload);
          break;
        }
        case 'socketio/uploadMaskImage': {
          emitUploadMaskImage(action.payload);
          break;
        }
      }

      // Always pass the action on, and hand the result back so that
      // `dispatch()` keeps returning what the rest of the chain returns.
      return next(action);
    };
  };

  return middleware;
};

46
frontend/src/app/socketio/types.d.ts vendored Normal file
View File

@ -0,0 +1,46 @@
/**
 * Shapes of the payloads the server sends over socket.io.
 * Used by the socketio middleware and its listeners.
 */

/** Payload of a 'generationResult' event. */
export declare interface ServerGenerationResult {
  url: string;
  metadata: Record<string, any>;
}

/** Payload of an 'esrganResult' event; `uuid` identifies the source image. */
export declare interface ServerESRGANResult {
  url: string;
  uuid: string;
  metadata: Record<string, any>;
}

/** Payload of a 'gfpganResult' event; `uuid` identifies the source image. */
export declare interface ServerGFPGANResult {
  url: string;
  uuid: string;
  metadata: Record<string, any>;
}

/** Payload of an 'intermediateResult' event. */
export declare interface ServerIntermediateResult {
  url: string;
  metadata: Record<string, any>;
}

/** Payload of an 'error' event. */
export declare interface ServerError {
  message: string;
  additionalData?: string;
}

/** Payload of a 'galleryImages' event. */
export declare interface ServerGalleryImages {
  images: {
    path: string;
    metadata: Record<string, any>;
  }[];
}

/** Payload of an 'imageDeleted' event. */
export declare interface ServerImageUrlAndUuid {
  uuid: string;
  url: string;
}

/** Payload of the 'initialImageUploaded' and 'maskImageUploaded' events. */
export declare interface ServerImageUrl {
  url: string;
}

View File

@ -1,53 +1,78 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import sdReducer from '../features/sd/sdSlice';
import galleryReducer from '../features/gallery/gallerySlice';
import systemReducer from '../features/system/systemSlice';
import { socketioMiddleware } from './socketio';
import { socketioMiddleware } from './socketio/middleware';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
*
* While we definitely want generation parameters to be persisted, there are a number
* of things we do *not* want to be persisted across reloads:
* - Gallery/selected image (user may add/delete images from disk between page loads)
* - Connection/processing status
* - Availability of external libraries like ESRGAN/GFPGAN
*
* These can be blacklisted in redux-persist.
*
* The necessary nested persistors with blacklists are configured below.
*
* TODO: Do we blacklist initialImagePath? If the image is deleted from disk we get an
* ugly 404. But if we blacklist it, then this is a valuable parameter that is lost
* on reload. Need to figure out a good way to handle this.
*/
const rootPersistConfig = {
key: 'root',
storage,
blacklist: ['gallery', 'system'],
};
/**
 * redux-persist config for the `system` slice.
 * Every key in `blacklist` is excluded from persistence; each key is
 * listed exactly once.
 */
const systemPersistConfig = {
  key: 'system',
  storage,
  blacklist: [
    'isConnected',
    'isProcessing',
    'currentStep',
    'socketId',
    'isESRGANAvailable',
    'isGFPGANAvailable',
    'totalSteps',
    'currentIteration',
    'totalIterations',
    'currentStatus',
  ],
};
const reducers = combineReducers({
sd: sdReducer,
gallery: galleryReducer,
system: systemReducer,
system: persistReducer(systemPersistConfig, systemReducer),
});
const persistConfig = {
key: 'root',
storage,
};
const persistedReducer = persistReducer(persistConfig, reducers);
/*
The frontend needs to be distributed as a production build, so
we cannot reasonably ask users to edit the JS and specify the
host and port on which the socket.io server will run.
The solution is to allow server script to be run with arguments
(or just edited) providing the host and port. Then, the server
serves a route `/socketio_config` which responds with the host
and port.
When the frontend loads, it synchronously requests that route
and thus gets the host and port. This requires a suspicious
fetch somewhere, and the store setup seems like as good a place
as any to make this fetch request.
*/
const persistedReducer = persistReducer(rootPersistConfig, reducers);
// Continue with store setup
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
// redux-persist sometimes needs to have a function in redux, need to disable this check
// redux-persist sometimes needs to temporarily put a function in redux state, need to disable this check
serializableCheck: false,
}).concat(socketioMiddleware()),
});
// Infer the `RootState` and `AppDispatch` types from the store itself
export type RootState = ReturnType<typeof store.getState>;
// Inferred type: {posts: PostsState, comments: CommentsState, users: UsersState}
export type AppDispatch = typeof store.dispatch;
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -0,0 +1,171 @@
/**
 * Defines common parameters required to generate an image.
 * See #266 for the eventual maturation of this interface.
 */
interface CommonParameters {
  /**
   * The "txt2img" prompt. String. Minimum one character. No maximum.
   */
  prompt: string;
  /**
   * The number of sampler steps. Integer. Minimum value 1. No maximum.
   */
  steps: number;
  /**
   * Classifier-free guidance scale. Float. Minimum value 0.
   * TODO: the maximum value has not been determined yet.
   */
  cfgScale: number;
  /**
   * Height of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
   */
  height: number;
  /**
   * Width of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
   */
  width: number;
  /**
   * Name of the sampler to use. String. Restricted to the values listed below.
   */
  sampler:
    | 'ddim'
    | 'plms'
    | 'k_lms'
    | 'k_dpm_2'
    | 'k_dpm_2_a'
    | 'k_euler'
    | 'k_euler_a'
    | 'k_heun';
  /**
   * Seed used for randomness. Integer. 0 --> 4294967295, inclusive.
   */
  seed: number;
  /**
   * Flag to enable seamless tiling image generation. Boolean.
   */
  seamless: boolean;
}
/**
 * Defines parameters needed to use the "img2img" generation method.
 */
interface ImageToImageParameters {
  /**
   * Folder path to the image used as the initial image. String.
   */
  initialImagePath: string;
  /**
   * Flag to enable the use of a mask image during "img2img" generations.
   * Requires valid ImageToImageParameters. Boolean.
   */
  shouldUseMaskImage: boolean;
  /**
   * Folder path to the image used as a mask image. String.
   */
  maskImagePath: string;
  /**
   * Strength of adherence to the initial image. Float. 0 --> 1, exclusive.
   */
  img2imgStrength: number;
  /**
   * Flag to enable the stretching of the initial image to the desired output size. Boolean.
   */
  shouldFit: boolean;
}
/**
 * Defines the parameters needed to generate variations.
 */
interface VariationParameters {
  /**
   * Variation amount. Float. 0 --> 1, exclusive.
   * TODO: Document what this value actually controls.
   */
  variationAmount: number;
  /**
   * List of seed-weight pairs formatted as "seed:weight,...".
   * Each seed must be a valid seed. Each weight is a float, 0 --> 1, exclusive.
   * String, must be parseable into [[seed,weight],...] format.
   */
  seedWeights: string;
}
/**
 * Defines the parameters needed to use GFPGAN (face-fixing) postprocessing.
 */
interface GFPGANParameters {
  /**
   * GFPGAN strength, i.e. how strongly face-fixing is applied. Float. 0 --> 1, exclusive.
   */
  gfpganStrength: number;
}
/**
 * Defines the parameters needed to use ESRGAN (upscaling) postprocessing.
 */
interface ESRGANParameters {
  /**
   * ESRGAN strength, i.e. how strongly upscaling is applied. Float. 0 --> 1, exclusive.
   */
  esrganStrength: number;
  /**
   * ESRGAN upscaling scale. One of 2x | 4x, represented as the integer 2 or 4.
   */
  esrganScale: 2 | 4;
}
/**
 * Extends the generation and processing method parameters, adding flags to enable each.
 *
 * Every optional feature is described by a pair of members: a `should...`
 * flag that enables it and a parameters object that configures it.
 */
interface ProcessingParameters extends CommonParameters {
  /**
   * Flag to enable the generation of variations. Requires valid VariationParameters. Boolean.
   */
  shouldGenerateVariations: boolean;
  /**
   * Variation parameters.
   */
  variationParameters: VariationParameters;
  /**
   * Flag to enable the use of an initial image, i.e. to use "img2img" generation.
   * Requires valid ImageToImageParameters. Boolean.
   */
  shouldUseImageToImage: boolean;
  /**
   * ImageToImage parameters.
   */
  imageToImageParameters: ImageToImageParameters;
  /**
   * Flag to enable GFPGAN postprocessing. Requires valid GFPGANParameters. Boolean.
   */
  shouldRunGFPGAN: boolean;
  /**
   * GFPGAN parameters.
   */
  gfpganParameters: GFPGANParameters;
  /**
   * Flag to enable ESRGAN postprocessing. Requires valid ESRGANParameters. Boolean.
   */
  shouldRunESRGAN: boolean;
  /**
   * ESRGAN parameters.
   */
  esrganParameters: ESRGANParameters;
}
/**
 * Extends ProcessingParameters, adding the items needed to request processing
 * that are not parameters of a single image.
 */
interface ProcessingState extends ProcessingParameters {
  /**
   * Number of images to generate. Integer. Minimum 1.
   */
  iterations: number;
  /**
   * Flag to enable the randomization of the seed on each generation. Boolean.
   */
  shouldRandomizeSeed: boolean;
}
export {}

View File

@ -0,0 +1,21 @@
import { Button, ButtonProps } from '@chakra-ui/react';
interface Props extends ButtonProps {
label: string;
}
/**
 * Thin wrapper around Chakra's `Button` that takes its text from a `label`
 * prop and defaults `size` to 'sm'. All other props are forwarded unchanged.
 *
 * It used to carry more customization and is probably unnecessary now.
 *
 * TODO: Get rid of this.
 */
const SDButton = ({ label, size = 'sm', ...buttonProps }: Props) => (
  <Button size={size} {...buttonProps}>
    {label}
  </Button>
);
export default SDButton;

View File

@ -16,6 +16,9 @@ interface Props extends NumberInputProps {
width?: string | number;
}
/**
* Customized Chakra FormControl + NumberInput multi-part component.
*/
const SDNumberInput = (props: Props) => {
const {
label,
@ -31,7 +34,7 @@ const SDNumberInput = (props: Props) => {
<Flex gap={2} justifyContent={'space-between'} alignItems={'center'}>
{label && (
<FormLabel marginBottom={1}>
<Text fontSize={fontSize} whiteSpace='nowrap'>
<Text fontSize={fontSize} whiteSpace="nowrap">
{label}
</Text>
</FormLabel>
@ -42,7 +45,7 @@ const SDNumberInput = (props: Props) => {
keepWithinRange={false}
clampValueOnBlur={true}
>
<NumberInputField fontSize={'md'}/>
<NumberInputField fontSize={'md'} />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />

View File

@ -0,0 +1,56 @@
import {
Flex,
FormControl,
FormLabel,
Select,
SelectProps,
Text,
} from '@chakra-ui/react';
interface Props extends SelectProps {
label: string;
validValues:
| Array<number | string>
| Array<{ key: string; value: string | number }>;
}
/**
 * Labelled dropdown built from Chakra's FormControl, FormLabel and Select.
 *
 * `validValues` may hold plain strings/numbers (used as both value and
 * text) or `{ key, value }` objects (`key` is the text shown, `value` is
 * the option value). Props not consumed here are forwarded to `Select`.
 */
const SDSelect = ({
  label,
  isDisabled,
  validValues,
  size = 'sm',
  fontSize = 'md',
  marginBottom = 1,
  whiteSpace = 'nowrap',
  ...selectProps
}: Props) => {
  // Build the <option> elements up front to keep the markup below flat.
  const options = validValues.map((opt) => {
    if (typeof opt === 'string' || typeof opt === 'number') {
      return (
        <option key={opt} value={opt}>
          {opt}
        </option>
      );
    }
    return (
      <option key={opt.value} value={opt.value}>
        {opt.key}
      </option>
    );
  });

  return (
    <FormControl isDisabled={isDisabled}>
      <Flex justifyContent={'space-between'} alignItems={'center'}>
        <FormLabel marginBottom={marginBottom}>
          <Text fontSize={fontSize} whiteSpace={whiteSpace}>
            {label}
          </Text>
        </FormLabel>
        <Select fontSize={fontSize} size={size} {...selectProps}>
          {options}
        </Select>
      </Flex>
    </FormControl>
  );
};
export default SDSelect;

View File

@ -11,6 +11,9 @@ interface Props extends SwitchProps {
width?: string | number;
}
/**
* Customized Chakra FormControl + Switch multi-part component.
*/
const SDSwitch = (props: Props) => {
const {
label,
@ -28,7 +31,7 @@ const SDSwitch = (props: Props) => {
fontSize={fontSize}
marginBottom={1}
flexGrow={2}
whiteSpace='nowrap'
whiteSpace="nowrap"
>
{label}
</FormLabel>

View File

@ -0,0 +1,104 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useMemo } from 'react';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { SDState } from '../../features/sd/sdSlice';
import { SystemState } from '../../features/system/systemSlice';
import { validateSeedWeights } from '../util/seedWeightPairs';
// Selects only the pieces of `sd` state the parameter check depends on.
// Results are compared with lodash `isEqual`, so an equal result keeps the
// previously returned object.
const sdSelector = createSelector(
  (state: RootState) => state.sd,
  ({
    prompt,
    shouldGenerateVariations,
    seedWeights,
    maskPath,
    initialImagePath,
    seed,
  }: SDState) => ({
    prompt,
    shouldGenerateVariations,
    seedWeights,
    maskPath,
    initialImagePath,
    seed,
  }),
  {
    memoizeOptions: {
      resultEqualityCheck: isEqual,
    },
  }
);
// Selects the connection and processing flags from `system` state, with
// the same `isEqual` result comparison as the `sd` selector.
const systemSelector = createSelector(
  (state: RootState) => state.system,
  ({ isProcessing, isConnected }: SystemState) => ({
    isProcessing,
    isConnected,
  }),
  {
    memoizeOptions: {
      resultEqualityCheck: isEqual,
    },
  }
);
/**
 * Hook reporting whether a generation request can be made right now.
 *
 * Returns `false` when the current state would make generation fail
 * deterministically (or when the app is busy or offline), `true` otherwise.
 * Used to keep the 'Generate' button from being clicked.
 */
const useCheckParameters = (): boolean => {
  const sd = useAppSelector(sdSelector);
  const system = useAppSelector(systemSelector);

  const {
    prompt,
    shouldGenerateVariations,
    seedWeights,
    maskPath,
    initialImagePath,
    seed,
  } = sd;
  const { isProcessing, isConnected } = system;

  return useMemo(() => {
    const isBlocked =
      // Cannot generate without a prompt
      !prompt ||
      // Cannot generate with a mask without img2img
      (maskPath && !initialImagePath) ||
      // TODO: job queue
      // Cannot generate if already processing an image
      isProcessing ||
      // Cannot generate if not connected
      !isConnected;

    if (isBlocked) {
      return false;
    }

    // Nothing more to verify unless variations were requested
    if (!shouldGenerateVariations) {
      return true;
    }

    // Variations need valid (or empty) seed weights and a fixed seed
    return (
      (validateSeedWeights(seedWeights) || seedWeights === '') && seed !== -1
    );
  }, [
    prompt,
    maskPath,
    initialImagePath,
    isProcessing,
    isConnected,
    shouldGenerateVariations,
    seedWeights,
    seed,
  ]);
};
export default useCheckParameters;

View File

@ -1,17 +1,15 @@
import { SDState } from '../features/sd/sdSlice';
import randomInt from '../features/sd/util/randomInt';
import {
seedWeightsToString,
stringToSeedWeights,
} from '../features/sd/util/seedWeightPairs';
import { SystemState } from '../features/system/systemSlice';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from './constants';
/*
These functions translate frontend state into parameters
suitable for consumption by the backend, and vice-versa.
*/
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from "../../app/constants";
import { SDState } from "../../features/sd/sdSlice";
import { SystemState } from "../../features/system/systemSlice";
import randomInt from "./randomInt";
import { seedWeightsToString, stringToSeedWeights } from "./seedWeightPairs";
export const frontendToBackendParameters = (
sdState: SDState,
systemState: SystemState
@ -32,7 +30,7 @@ export const frontendToBackendParameters = (
maskPath,
shouldFitToWidthHeight,
shouldGenerateVariations,
variantAmount,
variationAmount,
seedWeights,
shouldRunESRGAN,
upscalingLevel,
@ -71,7 +69,7 @@ export const frontendToBackendParameters = (
}
if (shouldGenerateVariations) {
generationParameters.variation_amount = variantAmount;
generationParameters.variation_amount = variationAmount;
if (seedWeights) {
generationParameters.with_variations =
stringToSeedWeights(seedWeights);
@ -138,7 +136,7 @@ export const backendToFrontendParameters = (parameters: {
if (variation_amount > 0) {
sd.shouldGenerateVariations = true;
sd.variantAmount = variation_amount;
sd.variationAmount = variation_amount;
if (with_variations) {
sd.seedWeights = seedWeightsToString(with_variations);
}

View File

@ -1,16 +0,0 @@
import { Button, ButtonProps } from '@chakra-ui/react';
interface Props extends ButtonProps {
label: string;
}
/**
 * Thin wrapper around Chakra's Button that renders `label` as its content
 * and falls back to the 'sm' size when no size is given. Every other prop
 * is forwarded to the underlying Button unchanged.
 */
const SDButton = ({ label, size = 'sm', ...buttonProps }: Props) => (
  <Button size={size} {...buttonProps}>
    {label}
  </Button>
);
export default SDButton;

View File

@ -1,57 +0,0 @@
import {
Flex,
FormControl,
FormLabel,
Select,
SelectProps,
Text,
} from '@chakra-ui/react';
/**
 * Props accepted by SDSelect: everything Chakra's Select accepts,
 * plus a label and the list of selectable values.
 */
interface Props extends SelectProps {
  // Text shown in the form label next to the select.
  label: string;
  // Either plain values (used as both option value and option text),
  // or { key, value } pairs where `key` is the displayed text and
  // `value` is the option's value.
  validValues:
    | Array<number | string>
    | Array<{ key: string; value: string | number }>;
}
/**
 * Labelled dropdown: a FormControl containing a label on the left and a
 * Chakra Select on the right, with one <option> per entry of `validValues`.
 *
 * Defaults: size 'sm', fontSize 'md', marginBottom 1, whiteSpace 'nowrap'.
 * `isDisabled` is applied to the FormControl; remaining props go to Select.
 */
const SDSelect = ({
  label,
  isDisabled,
  validValues,
  size = 'sm',
  fontSize = 'md',
  marginBottom = 1,
  whiteSpace = 'nowrap',
  ...selectProps
}: Props) => {
  // Build the option elements up front. Primitive entries double as both
  // the option value and its text; object entries show `key` and submit `value`.
  const optionElements = validValues.map((entry) => {
    if (typeof entry === 'string' || typeof entry === 'number') {
      return (
        <option key={entry} value={entry}>
          {entry}
        </option>
      );
    }

    return (
      <option key={entry.value} value={entry.value}>
        {entry.key}
      </option>
    );
  });

  return (
    <FormControl isDisabled={isDisabled}>
      <Flex justifyContent={'space-between'} alignItems={'center'}>
        <FormLabel marginBottom={marginBottom}>
          <Text fontSize={fontSize} whiteSpace={whiteSpace}>
            {label}
          </Text>
        </FormLabel>
        <Select fontSize={fontSize} size={size} {...selectProps}>
          {optionElements}
        </Select>
      </Flex>
    </FormControl>
  );
};
export default SDSelect;

View File

@ -1,161 +0,0 @@
import { Center, Flex, Image, useColorModeValue } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
import { useState } from 'react';
import ImageMetadataViewer from './ImageMetadataViewer';
import DeleteImageModalButton from './DeleteImageModalButton';
import SDButton from '../../components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit';
import { SystemState } from '../system/systemSlice';
import { isEqual } from 'lodash';
const height = 'calc(100vh - 238px)';
// Memoized selector exposing only the system flags this component reads.
// The result is compared with lodash's deep `isEqual`, so an unchanged set
// of flags yields the same object reference.
const systemSelector = createSelector(
  (state: RootState) => state.system,
  ({
    isProcessing,
    isConnected,
    isGFPGANAvailable,
    isESRGANAvailable,
  }: SystemState) => ({
    isProcessing,
    isConnected,
    isGFPGANAvailable,
    isESRGANAvailable,
  }),
  { memoizeOptions: { resultEqualityCheck: isEqual } }
);
/**
 * Shows the image being worked on (the intermediate image when there is one,
 * otherwise the currently selected gallery image) below a row of action
 * buttons, with an optional overlay listing the image's metadata.
 *
 * The button row and the image are only rendered when there is an image.
 */
const CurrentImage = () => {
  const { currentImage, intermediateImage } = useAppSelector(
    (state: RootState) => state.gallery
  );

  const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
    useAppSelector(systemSelector);

  const dispatch = useAppDispatch();

  // Overlay background: light or dark translucent fill per color mode.
  const bgColor = useColorModeValue(
    'rgba(255, 255, 255, 0.85)',
    'rgba(0, 0, 0, 0.8)'
  );

  const [shouldShowImageDetails, setShouldShowImageDetails] =
    useState<boolean>(false);

  const imageToDisplay = intermediateImage || currentImage;

  return (
    <Flex direction={'column'} rounded={'md'} borderWidth={1} p={2} gap={2}>
      {imageToDisplay && (
        <Flex gap={2}>
          <SDButton
            label='Use as initial image'
            colorScheme={'gray'}
            flexGrow={1}
            variant={'outline'}
            onClick={() => dispatch(setInitialImagePath(imageToDisplay.url))}
          />

          <SDButton
            label='Use all'
            colorScheme={'gray'}
            flexGrow={1}
            variant={'outline'}
            onClick={() => dispatch(setAllParameters(imageToDisplay.metadata))}
          />

          <SDButton
            label='Use seed'
            colorScheme={'gray'}
            flexGrow={1}
            variant={'outline'}
            // Check for absence explicitly: a plain truthiness test would
            // also disable the button for a legitimate seed of 0.
            isDisabled={
              imageToDisplay.metadata.seed === undefined ||
              imageToDisplay.metadata.seed === null
            }
            onClick={() => dispatch(setSeed(imageToDisplay.metadata.seed!))}
          />

          <SDButton
            label='Upscale'
            colorScheme={'gray'}
            flexGrow={1}
            variant={'outline'}
            // Disabled while an intermediate image is shown, while
            // disconnected or processing, or when ESRGAN is unavailable.
            isDisabled={
              !isESRGANAvailable ||
              Boolean(intermediateImage) ||
              !(isConnected && !isProcessing)
            }
            onClick={() => dispatch(runESRGAN(imageToDisplay))}
          />

          <SDButton
            label='Fix faces'
            colorScheme={'gray'}
            flexGrow={1}
            variant={'outline'}
            // Same gating as "Upscale", but keyed on GFPGAN availability.
            isDisabled={
              !isGFPGANAvailable ||
              Boolean(intermediateImage) ||
              !(isConnected && !isProcessing)
            }
            onClick={() => dispatch(runGFPGAN(imageToDisplay))}
          />

          <SDButton
            label='Details'
            colorScheme={'gray'}
            variant={shouldShowImageDetails ? 'solid' : 'outline'}
            borderWidth={1}
            flexGrow={1}
            onClick={() => setShouldShowImageDetails(!shouldShowImageDetails)}
          />

          <DeleteImageModalButton image={imageToDisplay}>
            <SDButton
              label='Delete'
              colorScheme={'red'}
              flexGrow={1}
              variant={'outline'}
              isDisabled={Boolean(intermediateImage)}
            />
          </DeleteImageModalButton>
        </Flex>
      )}

      <Center height={height} position={'relative'}>
        {imageToDisplay && (
          <Image
            src={imageToDisplay.url}
            fit='contain'
            maxWidth={'100%'}
            maxHeight={'100%'}
          />
        )}

        {imageToDisplay && shouldShowImageDetails && (
          <Flex
            width={'100%'}
            height={'100%'}
            position={'absolute'}
            top={0}
            left={0}
            p={3}
            boxSizing='border-box'
            backgroundColor={bgColor}
            overflow='scroll'
          >
            <ImageMetadataViewer image={imageToDisplay} />
          </Flex>
        )}
      </Center>
    </Flex>
  );
};
export default CurrentImage;

View File

@ -0,0 +1,149 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
import DeleteImageModal from './DeleteImageModal';
import { createSelector } from '@reduxjs/toolkit';
import { SystemState } from '../system/systemSlice';
import { isEqual } from 'lodash';
import { SDImage } from './gallerySlice';
import SDButton from '../../common/components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio/actions';
// Memoized selector exposing only the system flags the button row reads.
// The result is compared with lodash's deep `isEqual`, so an unchanged set
// of flags yields the same object reference.
const systemSelector = createSelector(
  (state: RootState) => state.system,
  ({
    isProcessing,
    isConnected,
    isGFPGANAvailable,
    isESRGANAvailable,
  }: SystemState) => ({
    isProcessing,
    isConnected,
    isGFPGANAvailable,
    isESRGANAvailable,
  }),
  { memoizeOptions: { resultEqualityCheck: isEqual } }
);
/**
 * Props for CurrentImageButtons.
 */
type CurrentImageButtonsProps = {
  // The image the buttons act on.
  image: SDImage;
  // Whether the metadata overlay is currently shown (drives the
  // "Details" button's solid/outline variant).
  shouldShowImageDetails: boolean;
  // Setter used by the "Details" button to toggle the overlay.
  setShouldShowImageDetails: (b: boolean) => void;
};
/**
 * Row of buttons for common actions:
 * Use as init image, use all params, use seed, upscale, fix faces, details, delete.
 *
 * "Upscale" and "Fix faces" are disabled while an intermediate image exists,
 * while disconnected or processing, when the respective backend feature is
 * unavailable, or when the corresponding setting (upscalingLevel /
 * gfpganStrength) is falsy.
 */
const CurrentImageButtons = ({
  image,
  shouldShowImageDetails,
  setShouldShowImageDetails,
}: CurrentImageButtonsProps) => {
  const dispatch = useAppDispatch();

  const { intermediateImage } = useAppSelector(
    (state: RootState) => state.gallery
  );

  const { upscalingLevel, gfpganStrength } = useAppSelector(
    (state: RootState) => state.sd
  );

  const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
    useAppSelector(systemSelector);

  // Check for absence explicitly: a plain truthiness test would treat a
  // legitimate seed of 0 as "no seed" and disable the button.
  const hasSeed =
    image.metadata.seed !== undefined && image.metadata.seed !== null;

  const handleClickUseAsInitialImage = () =>
    dispatch(setInitialImagePath(image.url));

  const handleClickUseAllParameters = () =>
    dispatch(setAllParameters(image.metadata));

  // Non-null assertion: this button is disabled if there is no seed.
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
  const handleClickUseSeed = () => dispatch(setSeed(image.metadata.seed!));

  const handleClickUpscale = () => dispatch(runESRGAN(image));

  const handleClickFixFaces = () => dispatch(runGFPGAN(image));

  const handleClickShowImageDetails = () =>
    setShouldShowImageDetails(!shouldShowImageDetails);

  return (
    <Flex gap={2}>
      <SDButton
        label="Use as initial image"
        colorScheme={'gray'}
        flexGrow={1}
        variant={'outline'}
        onClick={handleClickUseAsInitialImage}
      />

      <SDButton
        label="Use all"
        colorScheme={'gray'}
        flexGrow={1}
        variant={'outline'}
        onClick={handleClickUseAllParameters}
      />

      <SDButton
        label="Use seed"
        colorScheme={'gray'}
        flexGrow={1}
        variant={'outline'}
        isDisabled={!hasSeed}
        onClick={handleClickUseSeed}
      />

      <SDButton
        label="Upscale"
        colorScheme={'gray'}
        flexGrow={1}
        variant={'outline'}
        isDisabled={
          !isESRGANAvailable ||
          Boolean(intermediateImage) ||
          !(isConnected && !isProcessing) ||
          !upscalingLevel
        }
        onClick={handleClickUpscale}
      />

      <SDButton
        label="Fix faces"
        colorScheme={'gray'}
        flexGrow={1}
        variant={'outline'}
        isDisabled={
          !isGFPGANAvailable ||
          Boolean(intermediateImage) ||
          !(isConnected && !isProcessing) ||
          !gfpganStrength
        }
        onClick={handleClickFixFaces}
      />

      <SDButton
        label="Details"
        colorScheme={'gray'}
        variant={shouldShowImageDetails ? 'solid' : 'outline'}
        borderWidth={1}
        flexGrow={1}
        onClick={handleClickShowImageDetails}
      />

      <DeleteImageModal image={image}>
        <SDButton
          label="Delete"
          colorScheme={'red'}
          flexGrow={1}
          variant={'outline'}
          isDisabled={Boolean(intermediateImage)}
        />
      </DeleteImageModal>
    </Flex>
  );
};
export default CurrentImageButtons;

View File

@ -0,0 +1,67 @@
import { Center, Flex, Image, Text, useColorModeValue } from '@chakra-ui/react';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { useState } from 'react';
import ImageMetadataViewer from './ImageMetadataViewer';
import CurrentImageButtons from './CurrentImageButtons';
// TODO: With CSS Grid I had a hard time centering the image in a grid item. This is needed for that.
const height = 'calc(100vh - 238px)';
/**
 * Renders the image to display (the intermediate image when present,
 * otherwise the current image) with its action buttons and an optional
 * metadata overlay. Falls back to a "No image selected" placeholder when
 * neither image exists.
 */
const CurrentImageDisplay = () => {
  const { currentImage, intermediateImage } = useAppSelector(
    (state: RootState) => state.gallery
  );

  // Overlay background: light or dark translucent fill per color mode.
  const overlayBackground = useColorModeValue(
    'rgba(255, 255, 255, 0.85)',
    'rgba(0, 0, 0, 0.8)'
  );

  const [shouldShowImageDetails, setShouldShowImageDetails] =
    useState<boolean>(false);

  const imageToDisplay = intermediateImage || currentImage;

  // Nothing to show yet: render the placeholder and stop here.
  if (!imageToDisplay) {
    return (
      <Center height={'100%'} position={'relative'}>
        {/* NOTE(review): `size` may have no effect on Text with the default
            Chakra theme; confirm whether `fontSize` was intended. */}
        <Text size={'xl'}>No image selected</Text>
      </Center>
    );
  }

  const metadataOverlay = shouldShowImageDetails && (
    <Flex
      width={'100%'}
      height={'100%'}
      position={'absolute'}
      top={0}
      left={0}
      p={3}
      boxSizing="border-box"
      backgroundColor={overlayBackground}
      overflow="scroll"
    >
      <ImageMetadataViewer image={imageToDisplay} />
    </Flex>
  );

  return (
    <Flex direction={'column'} borderWidth={1} rounded={'md'} p={2} gap={2}>
      <CurrentImageButtons
        image={imageToDisplay}
        shouldShowImageDetails={shouldShowImageDetails}
        setShouldShowImageDetails={setShouldShowImageDetails}
      />
      <Center height={height} position={'relative'}>
        <Image
          src={imageToDisplay.url}
          fit="contain"
          maxWidth={'100%'}
          maxHeight={'100%'}
        />
        {metadataOverlay}
      </Center>
    </Flex>
  );
};
export default CurrentImageDisplay;

View File

@ -0,0 +1,121 @@
import {
Text,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
useDisclosure,
Button,
Switch,
FormControl,
FormLabel,
Flex,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
ChangeEvent,
cloneElement,
ReactElement,
SyntheticEvent,
useRef,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { deleteImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import { SDImage } from './gallerySlice';
/**
 * Props for DeleteImageModal.
 */
interface DeleteImageModalProps {
  /**
   * Component which, on click, should delete the image/open the modal.
   * It is cloned with an `onClick` handler attached, so it must be a single
   * element that accepts `onClick`.
   */
  children: ReactElement;
  /**
   * The image to delete.
   */
  image: SDImage;
}
// Selects only the delete-confirmation preference from the system slice.
const systemSelector = createSelector(
  (state: RootState) => state.system,
  ({ shouldConfirmOnDelete }: SystemState) => shouldConfirmOnDelete
);
/**
 * Needs a child, which will act as the button to delete an image.
 * If system.shouldConfirmOnDelete is true, a confirmation modal is displayed.
 * If it is false, the image is deleted immediately.
 * The confirmation modal has a "Don't ask me again" switch to set the boolean.
 */
const DeleteImageModal = ({ image, children }: DeleteImageModalProps) => {
  const { isOpen, onOpen, onClose } = useDisclosure();
  const dispatch = useAppDispatch();
  const shouldConfirmOnDelete = useAppSelector(systemSelector);
  // AlertDialog focuses this (least destructive) button when it opens.
  const cancelRef = useRef<HTMLButtonElement>(null);

  // Dispatch the delete action and close the dialog.
  const handleDelete = () => {
    dispatch(deleteImage(image));
    onClose();
  };

  // Click handler injected into the child: stop the click from reaching
  // ancestors, then either ask for confirmation or delete straight away.
  const handleClickDelete = (e: SyntheticEvent) => {
    e.stopPropagation();
    if (shouldConfirmOnDelete) {
      onOpen();
    } else {
      handleDelete();
    }
  };

  // The switch reads "Don't ask me again", so its checked state is the
  // inverse of shouldConfirmOnDelete.
  const handleChangeShouldConfirmOnDelete = (
    e: ChangeEvent<HTMLInputElement>
  ) => dispatch(setShouldConfirmOnDelete(!e.target.checked));

  return (
    <>
      {cloneElement(children, {
        // TODO: This feels wrong.
        onClick: handleClickDelete,
      })}

      <AlertDialog
        isOpen={isOpen}
        leastDestructiveRef={cancelRef}
        onClose={onClose}
      >
        <AlertDialogOverlay>
          <AlertDialogContent>
            <AlertDialogHeader fontSize="lg" fontWeight="bold">
              Delete image
            </AlertDialogHeader>

            <AlertDialogBody>
              <Flex direction={'column'} gap={5}>
                <Text>
                  Are you sure? You can't undo this action afterwards.
                </Text>
                <FormControl>
                  <Flex alignItems={'center'}>
                    <FormLabel mb={0}>Don't ask me again</FormLabel>
                    {/* Chakra's Switch is controlled through `isChecked`;
                        a plain `checked` prop does not bind it to state. */}
                    <Switch
                      isChecked={!shouldConfirmOnDelete}
                      onChange={handleChangeShouldConfirmOnDelete}
                    />
                  </Flex>
                </FormControl>
              </Flex>
            </AlertDialogBody>
            <AlertDialogFooter>
              <Button ref={cancelRef} onClick={onClose}>
                Cancel
              </Button>
              <Button colorScheme="red" onClick={handleDelete} ml={3}>
                Delete
              </Button>
            </AlertDialogFooter>
          </AlertDialogContent>
        </AlertDialogOverlay>
      </AlertDialog>
    </>
  );
};
export default DeleteImageModal;

View File

@ -1,94 +0,0 @@
import {
IconButtonProps,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
cloneElement,
ReactElement,
SyntheticEvent,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { deleteImage } from '../../app/socketio';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import { SDImage } from './gallerySlice';
/**
 * Props for DeleteImageModalButton. Extends Chakra's IconButtonProps; the
 * component itself accepts this type with 'aria-label' omitted.
 */
interface Props extends IconButtonProps {
  // The image to delete.
  image: SDImage;
  'aria-label': string;
  // Single element that is cloned with an onClick handler attached.
  children: ReactElement;
}
// Selects only the delete-confirmation preference from the system slice.
const systemSelector = createSelector(
  (state: RootState) => state.system,
  ({ shouldConfirmOnDelete }: SystemState) => shouldConfirmOnDelete
);
/*
TODO: The modal and button to open it should be two different components,
but their state is closely related and I'm not sure how best to accomplish it.
*/
/**
 * Wraps a child element so that clicking it deletes `image`. When
 * shouldConfirmOnDelete is set, a confirmation modal is opened first;
 * otherwise the image is deleted immediately. The modal also offers a
 * "don't ask me again" option that turns the confirmation off.
 */
const DeleteImageModalButton = ({
  image,
  children,
}: Omit<Props, 'aria-label'>) => {
  const dispatch = useAppDispatch();
  const shouldConfirmOnDelete = useAppSelector(systemSelector);
  const { isOpen, onOpen, onClose } = useDisclosure();

  // Delete the image and close the modal.
  const handleDelete = () => {
    dispatch(deleteImage(image));
    onClose();
  };

  // Delete the image, disable future confirmations, and close the modal.
  const handleDeleteAndDontAsk = () => {
    dispatch(deleteImage(image));
    dispatch(setShouldConfirmOnDelete(false));
    onClose();
  };

  // Click handler injected into the child element.
  const handleClickDelete = (e: SyntheticEvent) => {
    e.stopPropagation();
    if (shouldConfirmOnDelete) {
      onOpen();
      return;
    }
    handleDelete();
  };

  const trigger = cloneElement(children, { onClick: handleClickDelete });

  return (
    <>
      {trigger}

      <Modal isOpen={isOpen} onClose={onClose}>
        <ModalOverlay />
        <ModalContent>
          <ModalHeader>Are you sure you want to delete this image?</ModalHeader>
          <ModalCloseButton />
          <ModalBody>
            <Text>It will be deleted forever!</Text>
          </ModalBody>

          <ModalFooter justifyContent={'space-between'}>
            <SDButton label={'Yes'} colorScheme='red' onClick={handleDelete} />
            <SDButton
              label={"Yes, and don't ask me again"}
              colorScheme='red'
              onClick={handleDeleteAndDontAsk}
            />
            <SDButton label='Cancel' colorScheme='blue' onClick={onClose} />
          </ModalFooter>
        </ModalContent>
      </Modal>
    </>
  );
};
export default DeleteImageModalButton;

View File

@ -0,0 +1,131 @@
import {
Box,
Flex,
Icon,
IconButton,
Image,
useColorModeValue,
} from '@chakra-ui/react';
import { useAppDispatch } from '../../app/store';
import { SDImage, setCurrentImage } from './gallerySlice';
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal';
import { memo, SyntheticEvent, useState } from 'react';
import { setAllParameters, setSeed } from '../sd/sdSlice';
/**
 * Props for HoverableImage.
 */
interface HoverableImageProps {
  // The gallery image to render.
  image: SDImage;
  // Whether this image is the currently selected one.
  isSelected: boolean;
}
// Props comparator for React.memo: props count as equal when both the
// image's uuid and the selection flag are unchanged.
const memoEqualityCheck = (
  prev: HoverableImageProps,
  next: HoverableImageProps
) => {
  const isSameImage = prev.image.uuid === next.image.uuid;
  const isSameSelection = prev.isSelected === next.isSelected;
  return isSameImage && isSameSelection;
};
/**
 * Gallery image component with delete/use all/use seed buttons on hover.
 *
 * Clicking the image makes it the current image. The selected image gets a
 * gradient overlay and a check icon. The "use seed" button is only rendered
 * when the image's metadata carries a seed.
 */
const HoverableImage = memo((props: HoverableImageProps) => {
  const [isHovered, setIsHovered] = useState<boolean>(false);
  const dispatch = useAppDispatch();

  const checkColor = useColorModeValue('green.600', 'green.300');
  const bgColor = useColorModeValue('gray.200', 'gray.700');
  const bgGradient = useColorModeValue(
    'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
    'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
  );

  const { image, isSelected } = props;
  const { url, uuid, metadata } = image;

  // Check for absence explicitly. With `seed && <Button />`, a seed of 0
  // would make React render a literal "0" and hide the button.
  const hasSeed = metadata.seed !== undefined && metadata.seed !== null;

  const handleMouseOver = () => setIsHovered(true);
  const handleMouseOut = () => setIsHovered(false);

  // The button handlers stop propagation so that using a button does not
  // also trigger the surrounding image click handler.
  const handleClickSetAllParameters = (e: SyntheticEvent) => {
    e.stopPropagation();
    dispatch(setAllParameters(metadata));
  };

  const handleClickSetSeed = (e: SyntheticEvent) => {
    e.stopPropagation();
    // Non-null assertion: this button is not rendered unless this exists
    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
    dispatch(setSeed(image.metadata.seed!));
  };

  const handleClickImage = () => dispatch(setCurrentImage(image));

  return (
    <Box position={'relative'} key={uuid}>
      <Image
        width={120}
        height={120}
        objectFit="cover"
        rounded={'md'}
        src={url}
        loading={'lazy'}
        backgroundColor={bgColor}
      />
      <Flex
        cursor={'pointer'}
        position={'absolute'}
        top={0}
        left={0}
        rounded={'md'}
        width="100%"
        height="100%"
        alignItems={'center'}
        justifyContent={'center'}
        background={isSelected ? bgGradient : undefined}
        onClick={handleClickImage}
        onMouseOver={handleMouseOver}
        onMouseOut={handleMouseOut}
      >
        {isSelected && (
          <Icon fill={checkColor} width={'50%'} height={'50%'} as={FaCheck} />
        )}
        {isHovered && (
          <Flex
            direction={'column'}
            gap={1}
            position={'absolute'}
            top={1}
            right={1}
          >
            <DeleteImageModal image={image}>
              <IconButton
                colorScheme="red"
                aria-label="Delete image"
                icon={<FaTrash />}
                size="xs"
                fontSize={15}
              />
            </DeleteImageModal>
            <IconButton
              aria-label="Use all parameters"
              colorScheme={'blue'}
              icon={<FaCopy />}
              size="xs"
              fontSize={15}
              onClickCapture={handleClickSetAllParameters}
            />
            {hasSeed && (
              <IconButton
                aria-label="Use seed"
                colorScheme={'blue'}
                icon={<FaSeedling />}
                size="xs"
                fontSize={16}
                onClickCapture={handleClickSetSeed}
              />
            )}
          </Flex>
        )}
      </Flex>
    </Box>
  );
}, memoEqualityCheck);
export default HoverableImage;

View File

@ -0,0 +1,39 @@
import { Center, Flex, Text } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppSelector } from '../../app/store';
import HoverableImage from './HoverableImage';
/**
 * Simple image gallery: renders every gallery image newest-first, or a
 * "No images in gallery" placeholder when the list is empty.
 */
const ImageGallery = () => {
  const { images, currentImageUuid } = useAppSelector(
    (state: RootState) => state.gallery
  );

  // Empty gallery: show the placeholder and stop here.
  if (!images.length) {
    return (
      <Center height={'100%'} position={'relative'}>
        <Text size={'xl'}>No images in gallery</Text>
      </Center>
    );
  }

  /**
   * I don't like that this needs to rerender whenever the current image is changed.
   * What if we have a large number of images? I suppose pagination (planned) will
   * mitigate this issue.
   *
   * TODO: Refactor if performance complaints, or after migrating to new API which supports pagination.
   */

  // Copy before reversing so the array held in the store is not mutated.
  const newestFirst = [...images].reverse();

  return (
    <Flex gap={2} wrap="wrap" pb={2}>
      {newestFirst.map((image) => (
        <HoverableImage
          key={image.uuid}
          image={image}
          isSelected={currentImageUuid === image.uuid}
        />
      ))}
    </Flex>
  );
};
export default ImageGallery;

View File

@ -1,124 +1,134 @@
import {
Center,
Flex,
IconButton,
Link,
List,
ListItem,
Text,
Center,
Flex,
IconButton,
Link,
List,
ListItem,
Text,
} from '@chakra-ui/react';
import { memo } from 'react';
import { FaPlus } from 'react-icons/fa';
import { PARAMETERS } from '../../app/constants';
import { useAppDispatch } from '../../app/hooks';
import SDButton from '../../components/SDButton';
import { useAppDispatch } from '../../app/store';
import SDButton from '../../common/components/SDButton';
import { setAllParameters, setParameter } from '../sd/sdSlice';
import { SDImage, SDMetadata } from './gallerySlice';
type Props = {
image: SDImage;
type ImageMetadataViewerProps = {
image: SDImage;
};
const ImageMetadataViewer = ({ image }: Props) => {
const dispatch = useAppDispatch();
// TODO: I don't know if this is needed.
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.uuid === next.image.uuid;
const keys = Object.keys(PARAMETERS);
// TODO: Show more interesting information in this component.
const metadata: Array<{
label: string;
key: string;
value: string | number | boolean;
}> = [];
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
keys.forEach((key) => {
const value = image.metadata[key as keyof SDMetadata];
if (value !== undefined) {
metadata.push({ label: PARAMETERS[key], key, value });
}
});
/**
* Build an array representing each item of metadata and a human-readable
* label for it e.g. "cfgScale" > "CFG Scale".
*
* This array is then used to render each item with a button to use that
* parameter in the processing settings.
*
* TODO: All this logic feels sloppy.
*/
const keys = Object.keys(PARAMETERS);
return (
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}>
<SDButton
label='Use all parameters'
colorScheme={'gray'}
padding={2}
isDisabled={metadata.length === 0}
onClick={() => dispatch(setAllParameters(image.metadata))}
/>
<Flex gap={2}>
<Text fontWeight={'semibold'}>File:</Text>
<Link href={image.url} isExternal>
<Text>{image.url}</Text>
</Link>
</Flex>
{metadata.length ? (
<>
<List>
{metadata.map((parameter, i) => {
const { label, key, value } = parameter;
return (
<ListItem key={i} pb={1}>
<Flex gap={2}>
<IconButton
aria-label='Use this parameter'
icon={<FaPlus />}
size={'xs'}
onClick={() =>
dispatch(
setParameter({
key,
value,
})
)
}
/>
<Text fontWeight={'semibold'}>
{label}:
</Text>
const metadata: Array<{
label: string;
key: string;
value: string | number | boolean;
}> = [];
{value === undefined ||
value === null ||
value === '' ||
value === 0 ? (
<Text
maxHeight={100}
fontStyle={'italic'}
>
None
</Text>
) : (
<Text
maxHeight={100}
overflowY={'scroll'}
>
{value.toString()}
</Text>
)}
</Flex>
</ListItem>
);
})}
</List>
<Flex gap={2}>
<Text fontWeight={'semibold'}>Raw:</Text>
<Text
maxHeight={100}
overflowY={'scroll'}
wordBreak={'break-all'}
>
{JSON.stringify(image.metadata)}
</Text>
</Flex>
</>
) : (
<Center width={'100%'} pt={10}>
<Text fontSize={'lg'} fontWeight='semibold'>
No metadata available
</Text>
</Center>
)}
</Flex>
);
};
keys.forEach((key) => {
const value = image.metadata[key as keyof SDMetadata];
if (value !== undefined) {
metadata.push({ label: PARAMETERS[key], key, value });
}
});
return (
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}>
<SDButton
label="Use all parameters"
colorScheme={'gray'}
padding={2}
isDisabled={metadata.length === 0}
onClick={() => dispatch(setAllParameters(image.metadata))}
/>
<Flex gap={2}>
<Text fontWeight={'semibold'}>File:</Text>
<Link href={image.url} isExternal>
<Text>{image.url}</Text>
</Link>
</Flex>
{metadata.length ? (
<>
<List>
{metadata.map((parameter, i) => {
const { label, key, value } = parameter;
return (
<ListItem key={i} pb={1}>
<Flex gap={2}>
<IconButton
aria-label="Use this parameter"
icon={<FaPlus />}
size={'xs'}
onClick={() =>
dispatch(
setParameter({
key,
value,
})
)
}
/>
<Text fontWeight={'semibold'}>{label}:</Text>
{value === undefined ||
value === null ||
value === '' ||
value === 0 ? (
<Text maxHeight={100} fontStyle={'italic'}>
None
</Text>
) : (
<Text maxHeight={100} overflowY={'scroll'}>
{value.toString()}
</Text>
)}
</Flex>
</ListItem>
);
})}
</List>
<Flex gap={2}>
<Text fontWeight={'semibold'}>Raw:</Text>
<Text maxHeight={100} overflowY={'scroll'} wordBreak={'break-all'}>
{JSON.stringify(image.metadata)}
</Text>
</Flex>
</>
) : (
<Center width={'100%'} pt={10}>
<Text fontSize={'lg'} fontWeight="semibold">
No metadata available
</Text>
</Center>
)}
</Flex>
);
}, memoEqualityCheck);
export default ImageMetadataViewer;

View File

@ -1,150 +0,0 @@
import {
Box,
Flex,
Icon,
IconButton,
Image,
useColorModeValue,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { SDImage, setCurrentImage } from './gallerySlice';
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
import DeleteImageModalButton from './DeleteImageModalButton';
import { memo, SyntheticEvent, useState } from 'react';
import { setAllParameters, setSeed } from '../sd/sdSlice';
/**
 * Props for HoverableImage.
 */
interface HoverableImageProps {
  // The gallery image to render.
  image: SDImage;
  // Whether this image is the currently selected one.
  isSelected: boolean;
}
/**
 * Gallery thumbnail with delete / use-all-parameters / use-seed buttons that
 * appear on hover. Clicking the thumbnail makes it the current image; the
 * selected thumbnail gets a gradient overlay and a check icon.
 *
 * Memoized: re-renders only when the image uuid or selection flag changes.
 */
const HoverableImage = memo(
  (props: HoverableImageProps) => {
    const [isHovered, setIsHovered] = useState<boolean>(false);
    const dispatch = useAppDispatch();

    const checkColor = useColorModeValue('green.600', 'green.300');
    const bgColor = useColorModeValue('gray.200', 'gray.700');
    const bgGradient = useColorModeValue(
      'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
      'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
    );

    const { image, isSelected } = props;
    const { url, uuid, metadata } = image;

    // Check for absence explicitly. With `seed && <Button />`, a seed of 0
    // would make React render a literal "0" and hide the button.
    const hasSeed = metadata.seed !== undefined && metadata.seed !== null;

    const handleMouseOver = () => setIsHovered(true);
    const handleMouseOut = () => setIsHovered(false);

    // The button handlers stop propagation so that using a button does not
    // also trigger the surrounding thumbnail click handler.
    const handleClickSetAllParameters = (e: SyntheticEvent) => {
      e.stopPropagation();
      dispatch(setAllParameters(metadata));
    };

    const handleClickSetSeed = (e: SyntheticEvent) => {
      e.stopPropagation();
      dispatch(setSeed(image.metadata.seed!)); // component not rendered unless this exists
    };

    return (
      <Box position={'relative'} key={uuid}>
        <Image
          width={120}
          height={120}
          objectFit='cover'
          rounded={'md'}
          src={url}
          loading={'lazy'}
          backgroundColor={bgColor}
        />
        <Flex
          cursor={'pointer'}
          position={'absolute'}
          top={0}
          left={0}
          rounded={'md'}
          width='100%'
          height='100%'
          alignItems={'center'}
          justifyContent={'center'}
          background={isSelected ? bgGradient : undefined}
          onClick={() => dispatch(setCurrentImage(image))}
          onMouseOver={handleMouseOver}
          onMouseOut={handleMouseOut}
        >
          {isSelected && (
            <Icon
              fill={checkColor}
              width={'50%'}
              height={'50%'}
              as={FaCheck}
            />
          )}
          {isHovered && (
            <Flex
              direction={'column'}
              gap={1}
              position={'absolute'}
              top={1}
              right={1}
            >
              <DeleteImageModalButton image={image}>
                <IconButton
                  colorScheme='red'
                  aria-label='Delete image'
                  icon={<FaTrash />}
                  size='xs'
                  fontSize={15}
                />
              </DeleteImageModalButton>
              <IconButton
                aria-label='Use all parameters'
                colorScheme={'blue'}
                icon={<FaCopy />}
                size='xs'
                fontSize={15}
                onClickCapture={handleClickSetAllParameters}
              />
              {hasSeed && (
                <IconButton
                  aria-label='Use seed'
                  colorScheme={'blue'}
                  icon={<FaSeedling />}
                  size='xs'
                  fontSize={16}
                  onClickCapture={handleClickSetSeed}
                />
              )}
            </Flex>
          )}
        </Flex>
      </Box>
    );
  },
  (prev, next) =>
    prev.image.uuid === next.image.uuid &&
    prev.isSelected === next.isSelected
);
/**
 * Renders every gallery image as a HoverableImage, newest first, marking
 * the one whose uuid matches the current image as selected.
 */
const ImageRoll = () => {
  const { images, currentImageUuid } = useAppSelector(
    (state: RootState) => state.gallery
  );

  // Copy before reversing so the array held in the store is not mutated.
  const newestFirst = [...images].reverse();

  return (
    <Flex gap={2} wrap='wrap' pb={2}>
      {newestFirst.map((image) => (
        <HoverableImage
          key={image.uuid}
          image={image}
          isSelected={currentImageUuid === image.uuid}
        />
      ))}
    </Flex>
  );
};
export default ImageRoll;

View File

@ -1,8 +1,7 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import { UpscalingLevel } from '../sd/sdSlice';
import { backendToFrontendParameters } from '../../app/parameterTranslation';
import { clamp } from 'lodash';
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
export interface SDMetadata {
@ -50,29 +49,48 @@ export const gallerySlice = createSlice({
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
removeImage: (state, action: PayloadAction<SDImage>) => {
const { uuid } = action.payload;
removeImage: (state, action: PayloadAction<string>) => {
const uuid = action.payload;
const newImages = state.images.filter((image) => image.uuid !== uuid);
const imageToDeleteIndex = state.images.findIndex(
(image) => image.uuid === uuid
);
if (uuid === state.currentImageUuid) {
/**
* We are deleting the currently selected image.
*
* We want the new currently selected image to be under the cursor in the
* gallery, so we need to do some finagling. The currently selected image
* is set by its UUID, not its index in the image list.
*
* Get the currently selected image's index.
*/
const imageToDeleteIndex = state.images.findIndex(
(image) => image.uuid === uuid
);
const newCurrentImageIndex = Math.min(
Math.max(imageToDeleteIndex, 0),
newImages.length - 1
);
/**
* New current image needs to be in the same spot, but because the gallery
* is sorted in reverse order, the new current image's index will actually be
* one less than the deleted image's index.
*
* Clamp the new index to ensure it is valid.
*/
const newCurrentImageIndex = clamp(
imageToDeleteIndex - 1,
0,
newImages.length - 1
);
state.currentImage = newImages.length
? newImages[newCurrentImageIndex]
: undefined;
state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid
: '';
}
state.images = newImages;
state.currentImage = newImages.length
? newImages[newCurrentImageIndex]
: undefined;
state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid
: '';
},
addImage: (state, action: PayloadAction<SDImage>) => {
state.images.push(action.payload);
@ -86,47 +104,13 @@ export const gallerySlice = createSlice({
clearIntermediateImage: (state) => {
state.intermediateImage = undefined;
},
setGalleryImages: (
state,
action: PayloadAction<
Array<{
path: string;
metadata: { [key: string]: string | number | boolean };
}>
>
) => {
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
const images = action.payload;
if (images.length === 0) {
// there are no images on disk, clear the gallery
state.images = [];
state.currentImageUuid = '';
state.currentImage = undefined;
} else {
// Filter image urls that are already in the rehydrated state
const filteredImages = action.payload.filter(
(image) => !state.images.find((i) => i.url === image.path)
);
const preparedImages = filteredImages.map((image): SDImage => {
return {
uuid: uuidv4(),
url: image.path,
metadata: backendToFrontendParameters(image.metadata),
};
});
const newImages = [...state.images].concat(preparedImages);
// if previous currentimage no longer exists, set a new one
if (!newImages.find((image) => image.uuid === state.currentImageUuid)) {
const newCurrentImage = newImages[newImages.length - 1];
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
setGalleryImages: (state, action: PayloadAction<Array<SDImage>>) => {
const newImages = action.payload;
if (newImages.length) {
const newCurrentImage = newImages[newImages.length - 1];
state.images = newImages;
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
},
},

View File

@ -1,35 +1,38 @@
import { Progress } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppSelector } from '../../app/hooks';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
import { SystemState } from '../system/systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
realSteps: sd.realSteps,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
currentStep: system.currentStep,
totalSteps: system.totalSteps,
currentStatusHasSteps: system.currentStatusHasSteps,
};
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
const ProgressBar = () => {
const { realSteps } = useAppSelector(sdSelector);
const { currentStep } = useAppSelector((state: RootState) => state.system);
const progress = Math.round((currentStep * 100) / realSteps);
return (
<Progress
height='10px'
value={progress}
isIndeterminate={progress < 0 || currentStep === realSteps}
/>
);
const { isProcessing, currentStep, totalSteps, currentStatusHasSteps } =
useAppSelector(systemSelector);
const value = currentStep ? Math.round((currentStep * 100) / totalSteps) : 0;
return (
<Progress
height="10px"
value={value}
isIndeterminate={isProcessing && !currentStatusHasSteps}
/>
);
};
export default ProgressBar;

View File

@ -12,39 +12,66 @@ import { isEqual } from 'lodash';
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
import { MdHelp, MdSettings } from 'react-icons/md';
import { useAppSelector } from '../../app/hooks';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SettingsModal from '../system/SettingsModal';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return { isConnected: system.isConnected };
return {
isConnected: system.isConnected,
isProcessing: system.isProcessing,
currentIteration: system.currentIteration,
totalIterations: system.totalIterations,
currentStatus: system.currentStatus,
};
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
/**
* Header, includes color mode toggle, settings button, status message.
*/
const SiteHeader = () => {
const { colorMode, toggleColorMode } = useColorMode();
const { isConnected } = useAppSelector(systemSelector);
const {
isConnected,
isProcessing,
currentIteration,
totalIterations,
currentStatus,
} = useAppSelector(systemSelector);
const statusMessageTextColor = isConnected ? 'green.500' : 'red.500';
const colorModeIcon = colorMode == 'light' ? <FaMoon /> : <FaSun />;
// Make FaMoon and FaSun icon apparent size consistent
const colorModeIconFontSize = colorMode == 'light' ? 18 : 20;
let statusMessage = currentStatus;
if (isProcessing) {
if (totalIterations > 1) {
statusMessage += ` [${currentIteration}/${totalIterations}]`;
}
}
return (
<Flex minWidth='max-content' alignItems='center' gap='1' pl={2} pr={1}>
<Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}>
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading>
<Spacer />
<Text textColor={isConnected ? 'green.500' : 'red.500'}>
{isConnected ? `Connected to server` : 'No connection to server'}
</Text>
<Text textColor={statusMessageTextColor}>{statusMessage}</Text>
<SettingsModal>
<IconButton
aria-label='Settings'
variant='link'
aria-label="Settings"
variant="link"
fontSize={24}
size={'sm'}
icon={<MdSettings />}
@ -52,14 +79,14 @@ const SiteHeader = () => {
</SettingsModal>
<IconButton
aria-label='Link to Github Issues'
variant='link'
aria-label="Link to Github Issues"
variant="link"
fontSize={23}
size={'sm'}
icon={
<Link
isExternal
href='http://github.com/lstein/stable-diffusion/issues'
href="http://github.com/lstein/stable-diffusion/issues"
>
<MdHelp />
</Link>
@ -67,24 +94,24 @@ const SiteHeader = () => {
/>
<IconButton
aria-label='Link to Github Repo'
variant='link'
aria-label="Link to Github Repo"
variant="link"
fontSize={20}
size={'sm'}
icon={
<Link isExternal href='http://github.com/lstein/stable-diffusion'>
<Link isExternal href="http://github.com/lstein/stable-diffusion">
<FaGithub />
</Link>
}
/>
<IconButton
aria-label='Toggle Dark Mode'
aria-label="Toggle Dark Mode"
onClick={toggleColorMode}
variant='link'
variant="link"
size={'sm'}
fontSize={colorMode == 'light' ? 18 : 20}
icon={colorMode == 'light' ? <FaMoon /> : <FaSun />}
fontSize={colorModeIconFontSize}
icon={colorModeIcon}
/>
</Flex>
);

View File

@ -1,84 +1,87 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setUpscalingLevel,
setUpscalingStrength,
UpscalingLevel,
SDState,
setUpscalingLevel,
setUpscalingStrength,
UpscalingLevel,
SDState,
} from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { UPSCALING_LEVELS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
upscalingLevel: sd.upscalingLevel,
upscalingStrength: sd.upscalingStrength,
};
(state: RootState) => state.sd,
(sd: SDState) => {
return {
upscalingLevel: sd.upscalingLevel,
upscalingStrength: sd.upscalingStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isESRGANAvailable: system.isESRGANAvailable,
};
(state: RootState) => state.system,
(system: SystemState) => {
return {
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
/**
* Displays upscaling/ESRGAN options (level and strength).
*/
const ESRGANOptions = () => {
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector);
const { isESRGANAvailable } = useAppSelector(systemSelector);
const { isESRGANAvailable } = useAppSelector(systemSelector);
const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setUpscalingLevel(Number(e.target.value) as UpscalingLevel));
const dispatch = useAppDispatch();
const handleChangeStrength = (v: string | number) =>
dispatch(setUpscalingStrength(Number(v)));
return (
<Flex direction={'column'} gap={2}>
<SDSelect
isDisabled={!isESRGANAvailable}
label='Scale'
value={upscalingLevel}
onChange={(e) =>
dispatch(
setUpscalingLevel(
Number(e.target.value) as UpscalingLevel
)
)
}
validValues={UPSCALING_LEVELS}
/>
<SDNumberInput
isDisabled={!isESRGANAvailable}
label='Strength'
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setUpscalingStrength(Number(v)))}
value={upscalingStrength}
/>
</Flex>
);
return (
<Flex direction={'column'} gap={2}>
<SDSelect
isDisabled={!isESRGANAvailable}
label="Scale"
value={upscalingLevel}
onChange={handleChangeLevel}
validValues={UPSCALING_LEVELS}
/>
<SDNumberInput
isDisabled={!isESRGANAvailable}
label="Strength"
step={0.05}
min={0}
max={1}
onChange={handleChangeStrength}
value={upscalingStrength}
/>
</Flex>
);
};
export default ESRGANOptions;

View File

@ -1,63 +1,68 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { SDState, setGfpganStrength } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
import SDNumberInput from '../../common/components/SDNumberInput';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
gfpganStrength: sd.gfpganStrength,
};
(state: RootState) => state.sd,
(sd: SDState) => {
return {
gfpganStrength: sd.gfpganStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
};
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
/**
* Displays face-fixing/GFPGAN options (strength).
*/
const GFPGANOptions = () => {
const { gfpganStrength } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
const { gfpganStrength } = useAppSelector(sdSelector);
const { isGFPGANAvailable } = useAppSelector(systemSelector);
const { isGFPGANAvailable } = useAppSelector(systemSelector);
const handleChangeStrength = (v: string | number) =>
dispatch(setGfpganStrength(Number(v)));
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!isGFPGANAvailable}
label='Strength'
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setGfpganStrength(Number(v)))}
value={gfpganStrength}
/>
</Flex>
);
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!isGFPGANAvailable}
label="Strength"
step={0.05}
min={0}
max={1}
onChange={handleChangeStrength}
value={gfpganStrength}
/>
</Flex>
);
};
export default GFPGANOptions;

View File

@ -1,54 +1,59 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { ChangeEvent } from 'react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import InitImage from './InitImage';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../common/components/SDSwitch';
import InitAndMaskImage from './InitAndMaskImage';
import {
SDState,
setImg2imgStrength,
setShouldFitToWidthHeight,
SDState,
setImg2imgStrength,
setShouldFitToWidthHeight,
} from './sdSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
img2imgStrength: sd.img2imgStrength,
shouldFitToWidthHeight: sd.shouldFitToWidthHeight,
};
}
(state: RootState) => state.sd,
(sd: SDState) => {
return {
img2imgStrength: sd.img2imgStrength,
shouldFitToWidthHeight: sd.shouldFitToWidthHeight,
};
}
);
/**
* Options for img2img generation (strength, fit, init/mask upload).
*/
const ImageToImageOptions = () => {
const { initialImagePath, img2imgStrength, shouldFitToWidthHeight } =
useAppSelector(sdSelector);
const dispatch = useAppDispatch();
const { img2imgStrength, shouldFitToWidthHeight } =
useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!initialImagePath}
label='Strength'
step={0.01}
min={0}
max={1}
onChange={(v) => dispatch(setImg2imgStrength(Number(v)))}
value={img2imgStrength}
/>
<SDSwitch
isDisabled={!initialImagePath}
label='Fit initial image to output size'
isChecked={shouldFitToWidthHeight}
onChange={(e) =>
dispatch(setShouldFitToWidthHeight(e.target.checked))
}
/>
<InitImage />
</Flex>
);
const handleChangeStrength = (v: string | number) =>
dispatch(setImg2imgStrength(Number(v)));
const handleChangeFit = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldFitToWidthHeight(e.target.checked));
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
label="Strength"
step={0.01}
min={0}
max={1}
onChange={handleChangeStrength}
value={img2imgStrength}
/>
<SDSwitch
label="Fit initial image to output size"
isChecked={shouldFitToWidthHeight}
onChange={handleChangeFit}
/>
<InitAndMaskImage />
</Flex>
);
};
export default ImageToImageOptions;

View File

@ -0,0 +1,63 @@
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
// Props accepted by the ImageUploader component.
type ImageUploaderProps = {
/**
* Single element that acts as the upload trigger. ImageUploader clones it
* and supplies its own onClick, so any onClick set on the child is replaced.
*/
children: ReactElement;
/**
* Called once for each file the dropzone accepts.
*/
fileAcceptedCallback: (file: File) => void;
/**
* Called once for each file the dropzone rejects.
*/
fileRejectionCallback: (rejection: FileRejection) => void;
};
/**
 * File upload using react-dropzone.
 *
 * Wraps a single child element (typically a button) in a dropzone root. The
 * child is cloned with an onClick that opens the file dialog; files can also
 * be dropped onto the wrapper.
 *
 * @param children Element that opens the upload interface when clicked.
 * @param fileAcceptedCallback Called once for every accepted file.
 * @param fileRejectionCallback Called once for every rejected file.
 */
const ImageUploader = ({
  children,
  fileAcceptedCallback,
  fileRejectionCallback,
}: ImageUploaderProps) => {
  // Rejections are reported before accepted files are handed off.
  const onDrop = useCallback(
    (acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
      fileRejections.forEach((rejection: FileRejection) => {
        fileRejectionCallback(rejection);
      });
      acceptedFiles.forEach((file: File) => {
        fileAcceptedCallback(file);
      });
    },
    [fileAcceptedCallback, fileRejectionCallback]
  );

  const { getRootProps, getInputProps, open } = useDropzone({
    onDrop,
    // Each MIME type maps to its own extensions; PNG is not a JPEG subtype.
    accept: {
      'image/jpeg': ['.jpg', '.jpeg'],
      'image/png': ['.png'],
    },
  });

  // Stop the click from also reaching the dropzone root, which would
  // otherwise handle the same click a second time.
  const handleClickUploadIcon = (e: SyntheticEvent) => {
    e.stopPropagation();
    open();
  };

  // NOTE(review): `multiple: false` is only applied to the <input>; whether a
  // multi-file drag-and-drop should also be refused is left as-is — confirm.
  return (
    <div {...getRootProps()}>
      <input {...getInputProps({ multiple: false })} />
      {cloneElement(children, {
        onClick: handleClickUploadIcon,
      })}
    </div>
  );
};
export default ImageUploader;

View File

@ -0,0 +1,57 @@
import { Flex, Image } from '@chakra-ui/react';
import { useState } from 'react';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { SDState } from '../../features/sd/sdSlice';
import './InitAndMaskImage.css';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import InitAndMaskUploadButtons from './InitAndMaskUploadButtons';
// Picks just the init image and mask paths out of the `sd` slice.
// The result is compared with lodash `isEqual`, so a freshly built object
// with the same two values does not count as a change.
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
maskPath: sd.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
/**
 * Shows the init/mask upload buttons and, when an init image is set, a
 * preview of it with the mask image optionally layered on top.
 */
const InitAndMaskImage = () => {
  // Toggled by the upload buttons' hover handlers.
  const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
  const { initialImagePath, maskPath } = useAppSelector(sdSelector);

  // Mask sits absolutely positioned over the init image, above it in z-order.
  const maskOverlay = shouldShowMask && maskPath && (
    <Image
      position={'absolute'}
      top={0}
      left={0}
      fit={'contain'}
      src={maskPath}
      rounded={'md'}
      zIndex={1}
    />
  );

  // Nothing is previewed until an init image path exists.
  const preview = initialImagePath && (
    <Flex position={'relative'} width={'100%'}>
      <Image
        fit={'contain'}
        src={initialImagePath}
        rounded={'md'}
        className={'checkerboard'}
      />
      {maskOverlay}
    </Flex>
  );

  return (
    <Flex direction={'column'} alignItems={'center'} gap={2}>
      <InitAndMaskUploadButtons setShouldShowMask={setShouldShowMask} />
      {preview}
    </Flex>
  );
};
export default InitAndMaskImage;

View File

@ -0,0 +1,131 @@
import { Button, Flex, IconButton, useToast } from '@chakra-ui/react';
import { SyntheticEvent, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import {
SDState,
setInitialImagePath,
setMaskPath,
} from '../../features/sd/sdSlice';
import { uploadInitialImage, uploadMaskImage } from '../../app/socketio/actions';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import ImageUploader from './ImageUploader';
import { FileRejection } from 'react-dropzone';
// Picks the init image and mask paths out of the `sd` slice, compared with
// lodash `isEqual` so an equal result object does not count as a change.
// NOTE(review): the component in this file only reads `initialImagePath`;
// `maskPath` looks unused here — confirm before trimming.
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
maskPath: sd.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
// Props accepted by the InitAndMaskUploadButtons component.
type InitAndMaskUploadButtonsProps = {
// Called from the buttons' mouse-over/mouse-out handlers with whether the
// mask should currently be shown.
setShouldShowMask: (b: boolean) => void;
};
/**
 * Init and mask image upload buttons, plus a button that clears both.
 *
 * @param setShouldShowMask Called from the hover handlers with whether the
 * mask should be shown.
 */
const InitAndMaskUploadButtons = ({
  setShouldShowMask,
}: InitAndMaskUploadButtonsProps) => {
  const dispatch = useAppDispatch();
  const { initialImagePath } = useAppSelector(sdSelector);

  // Use a toast to alert user when a file upload is rejected
  const toast = useToast();

  // Clear the init and mask images
  const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => {
    e.stopPropagation();
    dispatch(setInitialImagePath(''));
    dispatch(setMaskPath(''));
  };

  // Hover handlers: the mask is hidden only while the pointer is over the
  // "Upload Image" button; every other transition requests it to be shown.
  const handleMouseOverInitialImageUploadButton = () =>
    setShouldShowMask(false);
  const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);

  const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
  const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);

  // Callbacks for handling file upload attempts
  const initImageFileAcceptedCallback = useCallback(
    (file: File) => dispatch(uploadInitialImage(file)),
    [dispatch]
  );

  const maskImageFileAcceptedCallback = useCallback(
    (file: File) => dispatch(uploadMaskImage(file)),
    [dispatch]
  );

  const fileRejectionCallback = useCallback(
    (rejection: FileRejection) => {
      // One line per error, with no leading blank line.
      const msg = rejection.errors
        .map((error: { message: string }) => error.message)
        .join('\n');

      toast({
        title: 'Upload failed',
        description: msg,
        status: 'error',
        isClosable: true,
      });
    },
    [toast]
  );

  return (
    <Flex gap={2} justifyContent={'space-between'} width={'100%'}>
      <ImageUploader
        fileAcceptedCallback={initImageFileAcceptedCallback}
        fileRejectionCallback={fileRejectionCallback}
      >
        <Button
          size={'sm'}
          fontSize={'md'}
          fontWeight={'normal'}
          onMouseOver={handleMouseOverInitialImageUploadButton}
          onMouseOut={handleMouseOutInitialImageUploadButton}
        >
          Upload Image
        </Button>
      </ImageUploader>

      {/* Mask upload and reset are disabled until an init image exists. */}
      <ImageUploader
        fileAcceptedCallback={maskImageFileAcceptedCallback}
        fileRejectionCallback={fileRejectionCallback}
      >
        <Button
          isDisabled={!initialImagePath}
          size={'sm'}
          fontSize={'md'}
          fontWeight={'normal'}
          onMouseOver={handleMouseOverMaskUploadButton}
          onMouseOut={handleMouseOutMaskUploadButton}
        >
          Upload Mask
        </Button>
      </ImageUploader>

      <IconButton
        isDisabled={!initialImagePath}
        size={'sm'}
        aria-label={'Reset initial image and mask'}
        onClick={handleClickResetInitialImageAndMask}
        icon={<FaTrash />}
      />
    </Flex>
  );
};
export default InitAndMaskUploadButtons;

View File

@ -1,155 +0,0 @@
import {
Button,
Flex,
IconButton,
Image,
useToast,
} from '@chakra-ui/react';
import { SyntheticEvent, useCallback, useState } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { FaTrash } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import {
SDState,
setInitialImagePath,
setMaskPath,
} from '../../features/sd/sdSlice';
import MaskUploader from './MaskUploader';
import './InitImage.css';
import { uploadInitialImage } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
maskPath: sd.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
const InitImage = () => {
const toast = useToast();
const dispatch = useAppDispatch();
const { initialImagePath, maskPath } = useAppSelector(sdSelector);
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
});
acceptedFiles.forEach((file: File) => {
dispatch(uploadInitialImage(file));
});
},
[dispatch, toast]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setInitialImagePath(''));
dispatch(setMaskPath(''));
};
const handleMouseOverInitialImageUploadButton = () =>
setShouldShowMask(false);
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
return (
<Flex
{...getRootProps({
onClick: initialImagePath ? (e) => e.stopPropagation() : undefined,
})}
direction={'column'}
alignItems={'center'}
gap={2}
>
<input {...getInputProps({ multiple: false })} />
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onClick={handleClickUploadIcon}
onMouseOver={handleMouseOverInitialImageUploadButton}
onMouseOut={handleMouseOutInitialImageUploadButton}
>
Upload Image
</Button>
<MaskUploader>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onClick={handleClickUploadIcon}
onMouseOver={handleMouseOverMaskUploadButton}
onMouseOut={handleMouseOutMaskUploadButton}
>
Upload Mask
</Button>
</MaskUploader>
<IconButton
size={'sm'}
aria-label={'Reset initial image and mask'}
onClick={handleClickResetInitialImageAndMask}
icon={<FaTrash />}
/>
</Flex>
{initialImagePath && (
<Flex position={'relative'} width={'100%'}>
<Image
fit={'contain'}
src={initialImagePath}
rounded={'md'}
className={'checkerboard'}
/>
{shouldShowMask && maskPath && (
<Image
position={'absolute'}
top={0}
left={0}
fit={'contain'}
src={maskPath}
rounded={'md'}
zIndex={1}
className={'checkerboard'}
/>
)}
</Flex>
)}
</Flex>
);
};
export default InitImage;

View File

@ -1,61 +0,0 @@
import { useToast } from '@chakra-ui/react';
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useAppDispatch } from '../../app/hooks';
import { uploadMaskImage } from '../../app/socketio';
type Props = {
children: ReactElement;
};
const MaskUploader = ({ children }: Props) => {
const dispatch = useAppDispatch();
const toast = useToast();
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) =>
acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
});
acceptedFiles.forEach((file: File) => {
dispatch(uploadMaskImage(file));
});
},
[dispatch, toast]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
return (
<div {...getRootProps()}>
<input {...getInputProps({ multiple: false })} />
{cloneElement(children, {
onClick: handleClickUploadIcon,
})}
</div>
);
};
export default MaskUploader;

View File

@ -1,23 +1,24 @@
import {
Flex,
Box,
Text,
Accordion,
AccordionItem,
AccordionButton,
AccordionIcon,
AccordionPanel,
Switch,
Flex,
Box,
Text,
Accordion,
AccordionItem,
AccordionButton,
AccordionIcon,
AccordionPanel,
Switch,
ExpandedIndex,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setShouldRunGFPGAN,
setShouldRunESRGAN,
SDState,
setShouldUseInitImage,
setShouldRunGFPGAN,
setShouldRunESRGAN,
SDState,
setShouldUseInitImage,
} from '../sd/sdSlice';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
@ -28,184 +29,189 @@ import ESRGANOptions from './ESRGANOptions';
import GFPGANOptions from './GFPGANOptions';
import OutputOptions from './OutputOptions';
import ImageToImageOptions from './ImageToImageOptions';
import { ChangeEvent } from 'react';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
shouldUseInitImage: sd.shouldUseInitImage,
shouldRunESRGAN: sd.shouldRunESRGAN,
shouldRunGFPGAN: sd.shouldRunGFPGAN,
};
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
shouldUseInitImage: sd.shouldUseInitImage,
shouldRunESRGAN: sd.shouldRunESRGAN,
shouldRunGFPGAN: sd.shouldRunGFPGAN,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
openAccordions: system.openAccordions,
};
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
openAccordions: system.openAccordions,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
/**
* Main container for generation and processing parameters.
*/
const OptionsAccordion = () => {
const {
shouldRunESRGAN,
shouldRunGFPGAN,
shouldUseInitImage,
initialImagePath,
} = useAppSelector(sdSelector);
const {
shouldRunESRGAN,
shouldRunGFPGAN,
shouldUseInitImage,
initialImagePath,
} = useAppSelector(sdSelector);
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
useAppSelector(systemSelector);
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const dispatch = useAppDispatch();
return (
<Accordion
defaultIndex={openAccordions}
allowMultiple
reduceMotion
onChange={(openAccordions) =>
dispatch(setOpenAccordions(openAccordions))
}
>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Seed & Variation
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SeedVariationOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Sampler
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SamplerOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Upscale (ESRGAN)</Text>
<Switch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={(e) =>
dispatch(
setShouldRunESRGAN(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ESRGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Fix Faces (GFPGAN)</Text>
<Switch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunGFPGAN}
onChange={(e) =>
dispatch(
setShouldRunGFPGAN(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<GFPGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Image to Image</Text>
<Switch
isDisabled={!initialImagePath}
isChecked={shouldUseInitImage}
onChange={(e) =>
dispatch(
setShouldUseInitImage(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ImageToImageOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Output
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<OutputOptions />
</AccordionPanel>
</AccordionItem>
</Accordion>
);
/**
* Stores accordion state in redux so preferred UI setup is retained.
*/
const handleChangeAccordionState = (openAccordions: ExpandedIndex) =>
dispatch(setOpenAccordions(openAccordions));
const handleChangeShouldRunESRGAN = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunESRGAN(e.target.checked));
const handleChangeShouldRunGFPGAN = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunGFPGAN(e.target.checked));
const handleChangeShouldUseInitImage = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldUseInitImage(e.target.checked));
return (
<Accordion
defaultIndex={openAccordions}
allowMultiple
reduceMotion
onChange={handleChangeAccordionState}
>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex="1" textAlign="left">
Seed & Variation
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SeedVariationOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex="1" textAlign="left">
Sampler
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SamplerOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Upscale (ESRGAN)</Text>
<Switch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={handleChangeShouldRunESRGAN}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ESRGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Fix Faces (GFPGAN)</Text>
<Switch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunGFPGAN}
onChange={handleChangeShouldRunGFPGAN}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<GFPGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Image to Image</Text>
<Switch
isDisabled={!initialImagePath}
isChecked={shouldUseInitImage}
onChange={handleChangeShouldUseInitImage}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ImageToImageOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex="1" textAlign="left">
Output
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<OutputOptions />
</AccordionPanel>
</AccordionItem>
</Accordion>
);
};
export default OptionsAccordion;

View File

@ -1,66 +1,76 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice';
import SDSelect from '../../components/SDSelect';
import { HEIGHTS, WIDTHS } from '../../app/constants';
import SDSwitch from '../../components/SDSwitch';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import SDSelect from '../../common/components/SDSelect';
import SDSwitch from '../../common/components/SDSwitch';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
height: sd.height,
width: sd.width,
seamless: sd.seamless,
};
(state: RootState) => state.sd,
(sd: SDState) => {
return {
height: sd.height,
width: sd.width,
seamless: sd.seamless,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
/**
* Image output options. Includes width, height, seamless tiling.
*/
const OutputOptions = () => {
const { height, width, seamless } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
const { height, width, seamless } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
const handleChangeWidth = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setWidth(Number(e.target.value)));
return (
<Flex gap={2} direction={'column'}>
<Flex gap={2}>
<SDSelect
label='Width'
value={width}
flexGrow={1}
onChange={(e) => dispatch(setWidth(Number(e.target.value)))}
validValues={WIDTHS}
/>
<SDSelect
label='Height'
value={height}
flexGrow={1}
onChange={(e) =>
dispatch(setHeight(Number(e.target.value)))
}
validValues={HEIGHTS}
/>
</Flex>
<SDSwitch
label='Seamless tiling'
fontSize={'md'}
isChecked={seamless}
onChange={(e) => dispatch(setSeamless(e.target.checked))}
/>
</Flex>
);
const handleChangeHeight = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setHeight(Number(e.target.value)));
const handleChangeSeamless = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setSeamless(e.target.checked));
return (
<Flex gap={2} direction={'column'}>
<Flex gap={2}>
<SDSelect
label="Width"
value={width}
flexGrow={1}
onChange={handleChangeWidth}
validValues={WIDTHS}
/>
<SDSelect
label="Height"
value={height}
flexGrow={1}
onChange={handleChangeHeight}
validValues={HEIGHTS}
/>
</Flex>
<SDSwitch
label="Seamless tiling"
fontSize={'md'}
isChecked={seamless}
onChange={handleChangeSeamless}
/>
</Flex>
);
};
export default OutputOptions;

View File

@ -1,58 +1,68 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { cancelProcessing, generateImage } from '../../app/socketio';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { cancelProcessing, generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import SDButton from '../../common/components/SDButton';
import useCheckParameters from '../../common/hooks/useCheckParameters';
import { SystemState } from '../system/systemSlice';
import useCheckParameters from '../system/useCheckParameters';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
/**
* Buttons to start and cancel image generation.
*/
const ProcessButtons = () => {
const { isProcessing, isConnected } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const { isProcessing, isConnected } = useAppSelector(systemSelector);
const isReady = useCheckParameters();
const dispatch = useAppDispatch();
const handleClickGenerate = () => dispatch(generateImage());
const isReady = useCheckParameters();
const handleClickCancel = () => dispatch(cancelProcessing());
return (
<Flex gap={2} direction={'column'} alignItems={'space-between'} height={'100%'}>
<SDButton
label='Generate'
type='submit'
colorScheme='green'
flexGrow={1}
isDisabled={!isReady}
fontSize={'md'}
size={'md'}
onClick={() => dispatch(generateImage())}
/>
<SDButton
label='Cancel'
colorScheme='red'
flexGrow={1}
fontSize={'md'}
size={'md'}
isDisabled={!isConnected || !isProcessing}
onClick={() => dispatch(cancelProcessing())}
/>
</Flex>
);
return (
<Flex
gap={2}
direction={'column'}
alignItems={'space-between'}
height={'100%'}
>
<SDButton
label="Generate"
type="submit"
colorScheme="green"
flexGrow={1}
isDisabled={!isReady}
fontSize={'md'}
size={'md'}
onClick={handleClickGenerate}
/>
<SDButton
label="Cancel"
colorScheme="red"
flexGrow={1}
fontSize={'md'}
size={'md'}
isDisabled={!isConnected || !isProcessing}
onClick={handleClickCancel}
/>
</Flex>
);
};
export default ProcessButtons;

View File

@ -1,21 +1,40 @@
import { Textarea } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import {
ChangeEvent,
KeyboardEvent,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import { setPrompt } from '../sd/sdSlice';
/**
* Prompt input text area.
*/
const PromptInput = () => {
const { prompt } = useAppSelector((state: RootState) => state.sd);
const dispatch = useAppDispatch();
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) =>
dispatch(setPrompt(e.target.value));
/**
 * Key handler for the prompt textarea.
 *
 * Plain Enter starts image generation instead of inserting a newline.
 * Shift+Enter is not intercepted, so the textarea handles it normally.
 */
const handleKeyDown = (e: KeyboardEvent<HTMLTextAreaElement>) => {
  // While an IME composition is active, Enter only confirms the composed
  // text (e.g. CJK input); it must not be treated as "submit".
  if (e.nativeEvent.isComposing) {
    return;
  }

  if (e.key === 'Enter' && !e.shiftKey) {
    // Suppress the default newline before kicking off generation.
    e.preventDefault();
    dispatch(generateImage());
  }
};
return (
<Textarea
id='prompt'
name='prompt'
resize='none'
id="prompt"
name="prompt"
resize="none"
size={'lg'}
height={'100%'}
isInvalid={!prompt.length}
onChange={(e) => dispatch(setPrompt(e.target.value))}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
value={prompt}
placeholder="I'm dreaming of..."
/>

View File

@ -1,51 +0,0 @@
import {
Slider,
SliderTrack,
SliderFilledTrack,
SliderThumb,
FormControl,
FormLabel,
Text,
Flex,
SliderProps,
} from '@chakra-ui/react';
/**
 * Props accepted by SDSlider: everything from Chakra's SliderProps plus a
 * visible label and an optional label font size.
 */
interface Props extends SliderProps {
// Text rendered in the form label and reused as the slider's aria-label.
label: string;
// Current slider value; passed straight through to the underlying Slider.
value: number;
// Font size applied to the label text; SDSlider uses 'sm' when omitted.
fontSize?: number | string;
}
/**
 * Slider with a caption placed beside it inside a single form control.
 *
 * The caption text doubles as the slider's aria-label. Any extra slider
 * props supplied by the caller are forwarded to the underlying Slider.
 */
const SDSlider = (props: Props) => {
  const { label, value, fontSize = 'sm', onChange, ...forwarded } = props;

  // Forwarded props are spread last so they take precedence over the
  // defaults set here, exactly as when spreading them on the element.
  const sliderProps = {
    'aria-label': label,
    focusThumbOnChange: true,
    value,
    onChange,
    ...forwarded,
  };

  const caption = (
    <FormLabel marginInlineEnd={0} marginBottom={1}>
      <Text fontSize={fontSize} whiteSpace='nowrap'>
        {label}
      </Text>
    </FormLabel>
  );

  return (
    <FormControl>
      <Flex gap={2}>
        {caption}
        <Slider {...sliderProps}>
          <SliderTrack>
            <SliderFilledTrack />
          </SliderTrack>
          <SliderThumb />
        </Slider>
      </Flex>
    </FormControl>
  );
};
export default SDSlider;

View File

@ -1,62 +1,74 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { SAMPLERS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
steps: sd.steps,
cfgScale: sd.cfgScale,
sampler: sd.sampler,
};
(state: RootState) => state.sd,
(sd: SDState) => {
return {
steps: sd.steps,
cfgScale: sd.cfgScale,
sampler: sd.sampler,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
/**
* Sampler options. Includes steps, CFG scale, sampler.
*/
const SamplerOptions = () => {
const { steps, cfgScale, sampler } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
const { steps, cfgScale, sampler } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
const handleChangeSteps = (v: string | number) =>
dispatch(setSteps(Number(v)));
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label='Steps'
min={1}
step={1}
precision={0}
onChange={(v) => dispatch(setSteps(Number(v)))}
value={steps}
/>
<SDNumberInput
label='CFG scale'
step={0.5}
onChange={(v) => dispatch(setCfgScale(Number(v)))}
value={cfgScale}
/>
<SDSelect
label='Sampler'
value={sampler}
onChange={(e) => dispatch(setSampler(e.target.value))}
validValues={SAMPLERS}
/>
</Flex>
);
const handleChangeCfgScale = (v: string | number) =>
dispatch(setCfgScale(Number(v)));
const handleChangeSampler = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setSampler(e.target.value));
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label="Steps"
min={1}
step={1}
precision={0}
onChange={handleChangeSteps}
value={steps}
/>
<SDNumberInput
label="CFG scale"
step={0.5}
onChange={handleChangeCfgScale}
value={cfgScale}
/>
<SDSelect
label="Sampler"
value={sampler}
onChange={handleChangeSampler}
validValues={SAMPLERS}
/>
</Flex>
);
};
export default SamplerOptions;

View File

@ -1,144 +1,159 @@
import {
Flex,
Input,
HStack,
FormControl,
FormLabel,
Text,
Button,
Flex,
Input,
HStack,
FormControl,
FormLabel,
Text,
Button,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../common/components/SDSwitch';
import randomInt from '../../common/util/randomInt';
import { validateSeedWeights } from '../../common/util/seedWeightPairs';
import {
randomizeSeed,
SDState,
setIterations,
setSeed,
setSeedWeights,
setShouldGenerateVariations,
setShouldRandomizeSeed,
setVariantAmount,
SDState,
setIterations,
setSeed,
setSeedWeights,
setShouldGenerateVariations,
setShouldRandomizeSeed,
setVariationAmount,
} from './sdSlice';
import { validateSeedWeights } from './util/seedWeightPairs';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
variantAmount: sd.variantAmount,
seedWeights: sd.seedWeights,
shouldGenerateVariations: sd.shouldGenerateVariations,
shouldRandomizeSeed: sd.shouldRandomizeSeed,
seed: sd.seed,
iterations: sd.iterations,
};
(state: RootState) => state.sd,
(sd: SDState) => {
return {
variationAmount: sd.variationAmount,
seedWeights: sd.seedWeights,
shouldGenerateVariations: sd.shouldGenerateVariations,
shouldRandomizeSeed: sd.shouldRandomizeSeed,
seed: sd.seed,
iterations: sd.iterations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
/**
* Seed & variation options. Includes iteration, seed, seed randomization, variation options.
*/
const SeedVariationOptions = () => {
const {
shouldGenerateVariations,
variantAmount,
seedWeights,
shouldRandomizeSeed,
seed,
iterations,
} = useAppSelector(sdSelector);
const {
shouldGenerateVariations,
variationAmount,
seedWeights,
shouldRandomizeSeed,
seed,
iterations,
} = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label='Images to generate'
step={1}
min={1}
precision={0}
onChange={(v) => dispatch(setIterations(Number(v)))}
value={iterations}
/>
<SDSwitch
label='Randomize seed on generation'
isChecked={shouldRandomizeSeed}
onChange={(e) =>
dispatch(setShouldRandomizeSeed(e.target.checked))
}
/>
<Flex gap={2}>
<SDNumberInput
label='Seed'
step={1}
precision={0}
flexGrow={1}
min={NUMPY_RAND_MIN}
max={NUMPY_RAND_MAX}
isDisabled={shouldRandomizeSeed}
isInvalid={seed < 0 && shouldGenerateVariations}
onChange={(v) => dispatch(setSeed(Number(v)))}
value={seed}
/>
<Button
size={'sm'}
isDisabled={shouldRandomizeSeed}
onClick={() => dispatch(randomizeSeed())}
>
<Text pl={2} pr={2}>
Shuffle
</Text>
</Button>
</Flex>
<SDSwitch
label='Generate variations'
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={(e) =>
dispatch(setShouldGenerateVariations(e.target.checked))
}
/>
<SDNumberInput
label='Variation amount'
value={variantAmount}
step={0.01}
min={0}
max={1}
isDisabled={!shouldGenerateVariations}
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
isDisabled={!shouldGenerateVariations}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text whiteSpace='nowrap'>
Seed Weights
</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={(e) =>
dispatch(setSeedWeights(e.target.value))
}
/>
</HStack>
</FormControl>
</Flex>
);
const handleChangeIterations = (v: string | number) =>
dispatch(setIterations(Number(v)));
const handleChangeShouldRandomizeSeed = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRandomizeSeed(e.target.checked));
const handleChangeSeed = (v: string | number) => dispatch(setSeed(Number(v)));
const handleClickRandomizeSeed = () =>
dispatch(setSeed(randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)));
const handleChangeShouldGenerateVariations = (
e: ChangeEvent<HTMLInputElement>
) => dispatch(setShouldGenerateVariations(e.target.checked));
const handleChangevariationAmount = (v: string | number) =>
dispatch(setVariationAmount(Number(v)));
const handleChangeSeedWeights = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setSeedWeights(e.target.value));
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label="Images to generate"
step={1}
min={1}
precision={0}
onChange={handleChangeIterations}
value={iterations}
/>
<SDSwitch
label="Randomize seed on generation"
isChecked={shouldRandomizeSeed}
onChange={handleChangeShouldRandomizeSeed}
/>
<Flex gap={2}>
<SDNumberInput
label="Seed"
step={1}
precision={0}
flexGrow={1}
min={NUMPY_RAND_MIN}
max={NUMPY_RAND_MAX}
isDisabled={shouldRandomizeSeed}
isInvalid={seed < 0 && shouldGenerateVariations}
onChange={handleChangeSeed}
value={seed}
/>
<Button
size={'sm'}
isDisabled={shouldRandomizeSeed}
onClick={handleClickRandomizeSeed}
>
<Text pl={2} pr={2}>
Shuffle
</Text>
</Button>
</Flex>
<SDSwitch
label="Generate variations"
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={handleChangeShouldGenerateVariations}
/>
<SDNumberInput
label="Variation amount"
value={variationAmount}
step={0.01}
min={0}
max={1}
onChange={handleChangevariationAmount}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text whiteSpace="nowrap">Seed Weights</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={handleChangeSeedWeights}
/>
</HStack>
</FormControl>
</Flex>
);
};
export default SeedVariationOptions;

View File

@ -1,92 +0,0 @@
import {
Flex,
FormControl,
FormLabel,
HStack,
Input,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import {
SDState,
setSeedWeights,
setShouldGenerateVariations,
setVariantAmount,
} from './sdSlice';
import { validateSeedWeights } from './util/seedWeightPairs';
/**
 * Memoized selector exposing only the variation-related fields of the `sd`
 * slice. lodash's `isEqual` is supplied as the result equality check, so the
 * comparison of successive results is by value rather than by reference.
 */
const sdSelector = createSelector(
  (rootState: RootState) => rootState.sd,
  ({ variantAmount, seedWeights, shouldGenerateVariations }: SDState) => ({
    variantAmount,
    seedWeights,
    shouldGenerateVariations,
  }),
  { memoizeOptions: { resultEqualityCheck: isEqual } }
);
/**
 * Single-row variation controls: an on/off switch, the variation amount and
 * the seed-weights text field. The amount and seed-weights inputs are
 * disabled while variations are switched off.
 */
const Variant = () => {
  const variationState = useAppSelector(sdSelector);
  const dispatch = useAppDispatch();

  const { shouldGenerateVariations, variantAmount, seedWeights } =
    variationState;
  const variationsOff = !shouldGenerateVariations;

  // Only flag the field while variations are on and a non-empty value fails
  // validation; an empty string is always acceptable.
  const seedWeightsInvalid =
    shouldGenerateVariations &&
    !validateSeedWeights(seedWeights) &&
    seedWeights !== '';

  const variationsSwitch = (
    <SDSwitch
      label='Generate variations'
      isChecked={shouldGenerateVariations}
      width={'auto'}
      onChange={(event) =>
        dispatch(setShouldGenerateVariations(event.target.checked))
      }
    />
  );

  const amountInput = (
    <SDNumberInput
      label='Amount'
      value={variantAmount}
      step={0.01}
      min={0}
      max={1}
      width={240}
      isDisabled={variationsOff}
      onChange={(raw) => dispatch(setVariantAmount(Number(raw)))}
    />
  );

  const seedWeightsField = (
    <FormControl
      isInvalid={seedWeightsInvalid}
      flexGrow={1}
      isDisabled={variationsOff}
    >
      <HStack>
        <FormLabel marginInlineEnd={0} marginBottom={1}>
          <Text fontSize={'sm'} whiteSpace='nowrap'>
            Seed Weights
          </Text>
        </FormLabel>
        <Input
          size={'sm'}
          value={seedWeights}
          onChange={(event) => dispatch(setSeedWeights(event.target.value))}
        />
      </HStack>
    </FormControl>
  );

  return (
    <Flex gap={2} alignItems={'center'} pl={1}>
      {variationsSwitch}
      {amountInput}
      {seedWeightsField}
    </Flex>
  );
};
export default Variant;

View File

@ -1,24 +1,13 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import { SDMetadata } from '../gallery/gallerySlice';
import randomInt from './util/randomInt';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
const calculateRealSteps = (
steps: number,
strength: number,
hasInitImage: boolean
): number => {
return hasInitImage ? Math.floor(strength * steps) : steps;
};
export type UpscalingLevel = 0 | 2 | 3 | 4;
export type UpscalingLevel = 2 | 4;
export interface SDState {
prompt: string;
iterations: number;
steps: number;
realSteps: number;
cfgScale: number;
height: number;
width: number;
@ -34,7 +23,7 @@ export interface SDState {
seamless: boolean;
shouldFitToWidthHeight: boolean;
shouldGenerateVariations: boolean;
variantAmount: number;
variationAmount: number;
seedWeights: string;
shouldRunESRGAN: boolean;
shouldRunGFPGAN: boolean;
@ -45,7 +34,6 @@ const initialSDState: SDState = {
prompt: '',
iterations: 1,
steps: 50,
realSteps: 50,
cfgScale: 7.5,
height: 512,
width: 512,
@ -58,7 +46,7 @@ const initialSDState: SDState = {
maskPath: '',
shouldFitToWidthHeight: true,
shouldGenerateVariations: false,
variantAmount: 0.1,
variationAmount: 0.1,
seedWeights: '',
shouldRunESRGAN: false,
upscalingLevel: 4,
@ -81,14 +69,7 @@ export const sdSlice = createSlice({
state.iterations = action.payload;
},
setSteps: (state, action: PayloadAction<number>) => {
const { img2imgStrength, initialImagePath } = state;
const steps = action.payload;
state.steps = steps;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
state.steps = action.payload;
},
setCfgScale: (state, action: PayloadAction<number>) => {
state.cfgScale = action.payload;
@ -107,14 +88,7 @@ export const sdSlice = createSlice({
state.shouldRandomizeSeed = false;
},
setImg2imgStrength: (state, action: PayloadAction<number>) => {
const img2imgStrength = action.payload;
const { steps, initialImagePath } = state;
state.img2imgStrength = img2imgStrength;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
state.img2imgStrength = action.payload;
},
setGfpganStrength: (state, action: PayloadAction<number>) => {
state.gfpganStrength = action.payload;
@ -129,15 +103,9 @@ export const sdSlice = createSlice({
state.shouldUseInitImage = action.payload;
},
setInitialImagePath: (state, action: PayloadAction<string>) => {
const initialImagePath = action.payload;
const { steps, img2imgStrength } = state;
state.shouldUseInitImage = initialImagePath ? true : false;
state.initialImagePath = initialImagePath;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
const newInitialImagePath = action.payload;
state.shouldUseInitImage = newInitialImagePath ? true : false;
state.initialImagePath = newInitialImagePath;
},
setMaskPath: (state, action: PayloadAction<string>) => {
state.maskPath = action.payload;
@ -151,13 +119,11 @@ export const sdSlice = createSlice({
resetSeed: (state) => {
state.seed = -1;
},
randomizeSeed: (state) => {
state.seed = randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX);
},
setParameter: (
state,
action: PayloadAction<{ key: string; value: string | number | boolean }>
) => {
// TODO: This probably needs to be refactored.
const { key, value } = action.payload;
const temp = { ...state, [key]: value };
if (key === 'seed') {
@ -171,13 +137,14 @@ export const sdSlice = createSlice({
setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => {
state.shouldGenerateVariations = action.payload;
},
setVariantAmount: (state, action: PayloadAction<number>) => {
state.variantAmount = action.payload;
setVariationAmount: (state, action: PayloadAction<number>) => {
state.variationAmount = action.payload;
},
setSeedWeights: (state, action: PayloadAction<string>) => {
state.seedWeights = action.payload;
},
setAllParameters: (state, action: PayloadAction<SDMetadata>) => {
// TODO: This probably needs to be refactored.
const {
prompt,
steps,
@ -267,13 +234,12 @@ export const {
setInitialImagePath,
setMaskPath,
resetSeed,
randomizeSeed,
resetSDState,
setShouldFitToWidthHeight,
setParameter,
setShouldGenerateVariations,
setSeedWeights,
setVariantAmount,
setVariationAmount,
setAllParameters,
setShouldRunGFPGAN,
setShouldRunESRGAN,

View File

@ -1,11 +1,11 @@
import {
IconButton,
useColorModeValue,
Flex,
Text,
Tooltip,
IconButton,
useColorModeValue,
Flex,
Text,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { setShouldShowLogViewer, SystemState } from './systemSlice';
import { useLayoutEffect, useRef, useState } from 'react';
@ -14,112 +14,138 @@ import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const logSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => system.log,
{
memoizeOptions: {
resultEqualityCheck: (a, b) => a.length === b.length,
},
}
(state: RootState) => state.system,
(system: SystemState) => system.log,
{
memoizeOptions: {
// We don't need a deep equality check for this selector.
resultEqualityCheck: (a, b) => a.length === b.length,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return { shouldShowLogViewer: system.shouldShowLogViewer };
(state: RootState) => state.system,
(system: SystemState) => {
return { shouldShowLogViewer: system.shouldShowLogViewer };
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
/**
* Basic log viewer, floats on bottom of page.
*/
const LogViewer = () => {
const dispatch = useAppDispatch();
const bg = useColorModeValue('gray.50', 'gray.900');
const borderColor = useColorModeValue('gray.500', 'gray.500');
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
const dispatch = useAppDispatch();
const log = useAppSelector(logSelector);
const { shouldShowLogViewer } = useAppSelector(systemSelector);
const log = useAppSelector(logSelector);
const { shouldShowLogViewer } = useAppSelector(systemSelector);
// Set colors based on dark/light mode
const bg = useColorModeValue('gray.50', 'gray.900');
const borderColor = useColorModeValue('gray.500', 'gray.500');
const logTextColors = useColorModeValue(
{
info: undefined,
warning: 'yellow.500',
error: 'red.500',
},
{
info: undefined,
warning: 'yellow.300',
error: 'red.300',
}
);
const viewerRef = useRef<HTMLDivElement>(null);
// Rudimentary autoscroll
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
const viewerRef = useRef<HTMLDivElement>(null);
useLayoutEffect(() => {
if (viewerRef.current !== null && shouldAutoscroll) {
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
}
});
/**
* If autoscroll is on, scroll to the bottom when:
* - log updates
* - viewer is toggled
*
* Also scroll to the bottom whenever autoscroll is turned on.
*/
useLayoutEffect(() => {
if (viewerRef.current !== null && shouldAutoscroll) {
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
}
}, [shouldAutoscroll, log, shouldShowLogViewer]);
return (
<>
{shouldShowLogViewer && (
<Flex
position={'fixed'}
left={0}
bottom={0}
height='200px'
width='100vw'
overflow='auto'
direction='column'
fontFamily='monospace'
fontSize='sm'
pl={12}
pr={2}
pb={2}
borderTopWidth='4px'
borderColor={borderColor}
background={bg}
ref={viewerRef}
>
{log.map((entry, i) => (
<Flex gap={2} key={i}>
<Text fontSize='sm' fontWeight={'semibold'}>
{entry.timestamp}:
</Text>
<Text fontSize='sm' wordBreak={'break-all'}>
{entry.message}
</Text>
</Flex>
))}
</Flex>
)}
{shouldShowLogViewer && (
<Tooltip
label={
shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'
}
>
<IconButton
size='sm'
position={'fixed'}
left={2}
bottom={12}
aria-label='Toggle autoscroll'
variant={'solid'}
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
icon={<FaAngleDoubleDown />}
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
/>
</Tooltip>
)}
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
<IconButton
size='sm'
position={'fixed'}
left={2}
bottom={2}
variant={'solid'}
aria-label='Toggle Log Viewer'
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
onClick={() =>
dispatch(setShouldShowLogViewer(!shouldShowLogViewer))
}
/>
</Tooltip>
</>
);
const handleClickLogViewerToggle = () => {
dispatch(setShouldShowLogViewer(!shouldShowLogViewer));
};
return (
<>
{shouldShowLogViewer && (
<Flex
position={'fixed'}
left={0}
bottom={0}
height="200px" // TODO: Make the log viewer resizeable.
width="100vw"
overflow="auto"
direction="column"
fontFamily="monospace"
fontSize="sm"
pl={12}
pr={2}
pb={2}
borderTopWidth="4px"
borderColor={borderColor}
background={bg}
ref={viewerRef}
>
{log.map((entry, i) => {
const { timestamp, message, level } = entry;
return (
<Flex gap={2} key={i} textColor={logTextColors[level]}>
<Text fontSize="sm" fontWeight={'semibold'}>
{timestamp}:
</Text>
<Text fontSize="sm" wordBreak={'break-all'}>
{message}
</Text>
</Flex>
);
})}
</Flex>
)}
{shouldShowLogViewer && (
<Tooltip label={shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'}>
<IconButton
size="sm"
position={'fixed'}
left={2}
bottom={12}
aria-label="Toggle autoscroll"
variant={'solid'}
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
icon={<FaAngleDoubleDown />}
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
/>
</Tooltip>
)}
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
<IconButton
size="sm"
position={'fixed'}
left={2}
bottom={2}
variant={'solid'}
aria-label="Toggle Log Viewer"
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
onClick={handleClickLogViewerToggle}
/>
</Tooltip>
</>
);
};
export default LogViewer;

View File

@ -1,170 +1,164 @@
import {
Flex,
FormControl,
FormLabel,
Heading,
HStack,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Switch,
Text,
useDisclosure,
Button,
Flex,
FormControl,
FormLabel,
Heading,
HStack,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Switch,
Text,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setShouldConfirmOnDelete,
setShouldDisplayInProgress,
SystemState,
setShouldConfirmOnDelete,
setShouldDisplayInProgress,
SystemState,
} from './systemSlice';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { persistor } from '../../main';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { cloneElement, ReactElement } from 'react';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
return { shouldDisplayInProgress, shouldConfirmOnDelete };
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
(state: RootState) => state.system,
(system: SystemState) => {
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
return { shouldDisplayInProgress, shouldConfirmOnDelete };
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
type Props = {
children: ReactElement;
type SettingsModalProps = {
/* The button to open the Settings Modal */
children: ReactElement;
};
const SettingsModal = ({ children }: Props) => {
const {
isOpen: isSettingsModalOpen,
onOpen: onSettingsModalOpen,
onClose: onSettingsModalClose,
} = useDisclosure();
/**
* Modal for app settings. Also provides Reset functionality in which the
* app's localstorage is wiped via redux-persist.
*
* Secondary post-reset modal is included here.
*/
const SettingsModal = ({ children }: SettingsModalProps) => {
const {
isOpen: isSettingsModalOpen,
onOpen: onSettingsModalOpen,
onClose: onSettingsModalClose,
} = useDisclosure();
const {
isOpen: isRefreshModalOpen,
onOpen: onRefreshModalOpen,
onClose: onRefreshModalClose,
} = useDisclosure();
const {
isOpen: isRefreshModalOpen,
onOpen: onRefreshModalOpen,
onClose: onRefreshModalClose,
} = useDisclosure();
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
useAppSelector(systemSelector);
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const dispatch = useAppDispatch();
const handleClickResetWebUI = () => {
persistor.purge().then(() => {
onSettingsModalClose();
onRefreshModalOpen();
});
};
/**
* Resets localstorage, then opens a secondary modal informing user to
* refresh their browser.
* */
const handleClickResetWebUI = () => {
persistor.purge().then(() => {
onSettingsModalClose();
onRefreshModalOpen();
});
};
return (
<>
{cloneElement(children, {
onClick: onSettingsModalOpen,
})}
return (
<>
{cloneElement(children, {
onClick: onSettingsModalOpen,
})}
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Settings</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex gap={5} direction='column'>
<FormControl>
<HStack>
<FormLabel marginBottom={1}>
Display in-progress images (slower)
</FormLabel>
<Switch
isChecked={shouldDisplayInProgress}
onChange={(e) =>
dispatch(
setShouldDisplayInProgress(
e.target.checked
)
)
}
/>
</HStack>
</FormControl>
<FormControl>
<HStack>
<FormLabel marginBottom={1}>
Confirm on delete
</FormLabel>
<Switch
isChecked={shouldConfirmOnDelete}
onChange={(e) =>
dispatch(
setShouldConfirmOnDelete(
e.target.checked
)
)
}
/>
</HStack>
</FormControl>
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Settings</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex gap={5} direction="column">
<FormControl>
<HStack>
<FormLabel marginBottom={1}>
Display in-progress images (slower)
</FormLabel>
<Switch
isChecked={shouldDisplayInProgress}
onChange={(e) =>
dispatch(setShouldDisplayInProgress(e.target.checked))
}
/>
</HStack>
</FormControl>
<FormControl>
<HStack>
<FormLabel marginBottom={1}>Confirm on delete</FormLabel>
<Switch
isChecked={shouldConfirmOnDelete}
onChange={(e) =>
dispatch(setShouldConfirmOnDelete(e.target.checked))
}
/>
</HStack>
</FormControl>
<Heading size={'md'}>Reset Web UI</Heading>
<Text>
Resetting the web UI only resets the browser's
local cache of your images and remembered
settings. It does not delete any images from
disk.
</Text>
<Text>
If images aren't showing up in the gallery or
something else isn't working, please try
resetting before submitting an issue on GitHub.
</Text>
<SDButton
label='Reset Web UI'
colorScheme='red'
onClick={handleClickResetWebUI}
/>
</Flex>
</ModalBody>
<Heading size={'md'}>Reset Web UI</Heading>
<Text>
Resetting the web UI only resets the browser's local cache of
your images and remembered settings. It does not delete any
images from disk.
</Text>
<Text>
If images aren't showing up in the gallery or something else
isn't working, please try resetting before submitting an issue
on GitHub.
</Text>
<Button colorScheme="red" onClick={handleClickResetWebUI}>
Reset Web UI
</Button>
</Flex>
</ModalBody>
<ModalFooter>
<SDButton
label='Close'
onClick={onSettingsModalClose}
/>
</ModalFooter>
</ModalContent>
</Modal>
<ModalFooter>
<Button onClick={onSettingsModalClose}>Close</Button>
</ModalFooter>
</ModalContent>
</Modal>
<Modal
closeOnOverlayClick={false}
isOpen={isRefreshModalOpen}
onClose={onRefreshModalClose}
isCentered
>
<ModalOverlay bg='blackAlpha.300' backdropFilter='blur(40px)' />
<ModalContent>
<ModalBody pb={6} pt={6}>
<Flex justifyContent={'center'}>
<Text fontSize={'lg'}>
Web UI has been reset. Refresh the page to
reload.
</Text>
</Flex>
</ModalBody>
</ModalContent>
</Modal>
</>
);
<Modal
closeOnOverlayClick={false}
isOpen={isRefreshModalOpen}
onClose={onRefreshModalClose}
isCentered
>
<ModalOverlay bg="blackAlpha.300" backdropFilter="blur(40px)" />
<ModalContent>
<ModalBody pb={6} pt={6}>
<Flex justifyContent={'center'}>
<Text fontSize={'lg'}>
Web UI has been reset. Refresh the page to reload.
</Text>
</Flex>
</ModalBody>
</ModalContent>
</Modal>
</>
);
};
export default SettingsModal;

View File

@ -1,10 +1,12 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { ExpandedIndex } from '@chakra-ui/react';
export type LogLevel = 'info' | 'warning' | 'error';
export interface LogEntry {
timestamp: string;
level: LogLevel;
message: string;
}
@ -12,10 +14,18 @@ export interface Log {
[index: number]: LogEntry;
}
export interface SystemState {
shouldDisplayInProgress: boolean;
export interface SystemStatus {
isProcessing: boolean;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
}
export interface SystemState extends SystemStatus {
shouldDisplayInProgress: boolean;
log: Array<LogEntry>;
shouldShowLogViewer: boolean;
isGFPGANAvailable: boolean;
@ -24,12 +34,17 @@ export interface SystemState {
socketId: string;
shouldConfirmOnDelete: boolean;
openAccordions: ExpandedIndex;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
}
const initialSystemState = {
isConnected: false,
isProcessing: false,
currentStep: 0,
log: [],
shouldShowLogViewer: false,
shouldDisplayInProgress: false,
@ -38,6 +53,12 @@ const initialSystemState = {
socketId: '',
shouldConfirmOnDelete: true,
openAccordions: [0],
currentStep: 0,
totalSteps: 0,
currentIteration: 0,
totalIterations: 0,
currentStatus: '',
currentStatusHasSteps: false,
};
const initialState: SystemState = initialSystemState;
@ -51,18 +72,35 @@ export const systemSlice = createSlice({
},
setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload;
if (action.payload === false) {
state.currentStep = 0;
}
},
setCurrentStep: (state, action: PayloadAction<number>) => {
state.currentStep = action.payload;
setCurrentStatus: (state, action: PayloadAction<string>) => {
state.currentStatus = action.payload;
},
addLogEntry: (state, action: PayloadAction<string>) => {
setSystemStatus: (state, action: PayloadAction<SystemStatus>) => {
const currentStatus =
!action.payload.isProcessing && state.isConnected
? 'Connected'
: action.payload.currentStatus;
return { ...state, ...action.payload, currentStatus };
},
addLogEntry: (
state,
action: PayloadAction<{
timestamp: string;
message: string;
level?: LogLevel;
}>
) => {
const { timestamp, message, level } = action.payload;
const logLevel = level || 'info';
const entry: LogEntry = {
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: action.payload,
timestamp,
message,
level: logLevel,
};
state.log.push(entry);
},
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
@ -86,13 +124,14 @@ export const systemSlice = createSlice({
export const {
setShouldDisplayInProgress,
setIsProcessing,
setCurrentStep,
addLogEntry,
setShouldShowLogViewer,
setIsConnected,
setSocketId,
setShouldConfirmOnDelete,
setOpenAccordions,
setSystemStatus,
setCurrentStatus,
} = systemSlice.actions;
export default systemSlice.reducer;

View File

@ -1,108 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useMemo } from 'react';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
import { validateSeedWeights } from '../sd/util/seedWeightPairs';
import { SystemState } from './systemSlice';
// Memoized view of the `sd` slice holding only the fields needed to decide
// whether a generation request can be submitted. The deep `isEqual` result
// check keeps the returned object reference stable while its contents are
// unchanged.
const sdSelector = createSelector(
  (state: RootState) => state.sd,
  ({
    prompt,
    shouldGenerateVariations,
    seedWeights,
    maskPath,
    initialImagePath,
    seed,
  }: SDState) => ({
    prompt,
    shouldGenerateVariations,
    seedWeights,
    maskPath,
    initialImagePath,
    seed,
  }),
  {
    memoizeOptions: {
      resultEqualityCheck: isEqual,
    },
  }
);
// Memoized view of the `system` slice exposing just the processing and
// connection flags; `isEqual` prevents a new object from being handed out
// when neither flag changed.
const systemSelector = createSelector(
  (state: RootState) => state.system,
  ({ isProcessing, isConnected }: SystemState) => ({
    isProcessing,
    isConnected,
  }),
  {
    memoizeOptions: {
      resultEqualityCheck: isEqual,
    },
  }
);
/*
  Hook that reports whether a generation request is known to be viable.
  It inspects only the state that would make generation fail for certain and
  is used to disable the 'Generate' button. Any other invalid parameter value
  is left to input validation.
*/
const useCheckParameters = () => {
  const {
    prompt,
    shouldGenerateVariations,
    seedWeights,
    maskPath,
    initialImagePath,
    seed,
  } = useAppSelector(sdSelector);

  const { isProcessing, isConnected } = useAppSelector(systemSelector);

  return useMemo(() => {
    // Each clause is one reason generation cannot start. They are evaluated
    // in order and short-circuit, so seed weights are only validated when
    // every earlier check passed and variations are requested.
    const isBlocked =
      // No prompt to generate from
      !prompt ||
      // A mask is meaningless without an initial image (img2img)
      (maskPath && !initialImagePath) ||
      // TODO: job queue
      // An image is already being processed
      isProcessing ||
      // No connection to the backend
      !isConnected ||
      // Variations need valid (or empty) seed weights and a fixed seed
      (shouldGenerateVariations &&
        (!(validateSeedWeights(seedWeights) || seedWeights === '') ||
          seed === -1));

    return !isBlocked;
  }, [
    prompt,
    maskPath,
    initialImagePath,
    isProcessing,
    isConnected,
    shouldGenerateVariations,
    seedWeights,
    seed,
  ]);
};
export default useCheckParameters;

View File

@ -8,9 +8,9 @@ import { persistStore } from 'redux-persist';
export const persistor = persistStore(store);
import App from './App';
import { theme } from './app/theme';
import Loading from './Loading';
import App from './app/App';
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
<React.StrictMode>

View File

@ -2,7 +2,10 @@
The Args class parses both the command line (shell) arguments, as well as the
command string passed at the dream> prompt. It serves as the definitive repository
of all the arguments used by Generate and their default values.
of all the arguments used by Generate and their default values, and implements the
preliminary metadata standards discussed here:
https://github.com/lstein/stable-diffusion/issues/266
To use:
opt = Args()
@ -52,15 +55,38 @@ you wish to apply logic as to which one to use. For example:
To add new attributes, edit the _create_arg_parser() and
_create_dream_cmd_parser() methods.
We also export the function build_metadata
**Generating and retrieving sd-metadata**
To generate a dict representing RFC266 metadata:
metadata = metadata_dumps(opt,<seeds,model_hash,postprocessing>)
This will generate an RFC266 dictionary that can then be turned into JSON
and written to the PNG file. The optional seeds, model_hash and
postprocessing arguments are not available to the opt object and so must be
provided externally. See how dream.py does it.
Note that this function was originally called format_metadata() and a wrapper
is provided that issues a deprecation notice.
To retrieve a (series of) opt objects corresponding to the metadata, do this:
opt_list = metadata_loads(metadata)
The metadata should be pulled out of the PNG image. pngwriter has a method
retrieve_metadata that will do this.
"""
import argparse
from argparse import Namespace
import shlex
import json
import hashlib
import os
import copy
import base64
from ldm.dream.conditioning import split_weighted_subprompts
SAMPLER_CHOICES = [
@ -105,6 +131,7 @@ class Args(object):
try:
elements = shlex.split(command)
except ValueError:
import sys, traceback
print(traceback.format_exc(), file=sys.stderr)
return
switches = ['']
@ -141,24 +168,26 @@ class Args(object):
a = vars(self)
a.update(kwargs)
switches = list()
switches.append(f'"{a["prompt"]}')
switches.append(f'"{a["prompt"]}"')
switches.append(f'-s {a["steps"]}')
switches.append(f'-S {a["seed"]}')
switches.append(f'-W {a["width"]}')
switches.append(f'-H {a["height"]}')
switches.append(f'-C {a["cfg_scale"]}')
switches.append(f'-A {a["sampler_name"]}')
switches.append(f'-S {a["seed"]}')
if a['grid']:
switches.append('--grid')
if a['iterations'] and a['iterations']>0:
switches.append(f'-n {a["iterations"]}')
if a['seamless']:
switches.append('--seamless')
if a['init_img'] and len(a['init_img'])>0:
switches.append(f'-I {a["init_img"]}')
if a['init_mask'] and len(a['init_mask'])>0:
switches.append(f'-M {a["init_mask"]}')
if a['init_color'] and len(a['init_color'])>0:
switches.append(f'--init_color {a["init_color"]}')
if a['fit']:
switches.append(f'--fit')
if a['strength'] and a['strength']>0:
if a['init_img'] and a['strength'] and a['strength']>0:
switches.append(f'-f {a["strength"]}')
if a['gfpgan_strength']:
switches.append(f'-G {a["gfpgan_strength"]}')
@ -189,10 +218,10 @@ class Args(object):
pass
if cmd_switches and arg_switches and name=='__dict__':
a = arg_switches.__dict__
a.update(cmd_switches.__dict__)
return a
return self._merge_dict(
arg_switches.__dict__,
cmd_switches.__dict__,
)
try:
return object.__getattribute__(self,name)
except AttributeError:
@ -216,13 +245,8 @@ class Args(object):
# the arg value. For example, the --grid and --individual options are a little
# funny because of their push/pull relationship. This is how to handle it.
if name=='grid':
return value_arg or value_cmd # arg supersedes cmd
if name=='individual':
return value_cmd or value_arg # cmd supersedes arg
if value_cmd is not None:
return value_cmd
else:
return value_arg
return not cmd_switches.individual and value_arg # arg supersedes cmd
return value_cmd if value_cmd is not None else value_arg
def __setattr__(self,name,value):
if name.startswith('_'):
@ -230,6 +254,14 @@ class Args(object):
else:
self._cmd_switches.__dict__[name] = value
def _merge_dict(self,dict1,dict2):
new_dict = {}
for k in set(list(dict1.keys())+list(dict2.keys())):
value1 = dict1.get(k,None)
value2 = dict2.get(k,None)
new_dict[k] = value2 if value2 is not None else value1
return new_dict
def _create_arg_parser(self):
'''
This defines all the arguments used on the command line when you launch
@ -268,6 +300,17 @@ class Args(object):
default='stable-diffusion-1.4',
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
)
model_group.add_argument(
'--sampler',
'-A',
'-m',
dest='sampler_name',
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
default='k_lms',
)
model_group.add_argument(
'-F',
'--full_precision',
@ -294,11 +337,6 @@ class Args(object):
action='store_true',
help='Place images in subdirectories named after the prompt.',
)
render_group.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
render_group.add_argument(
'--grid',
'-g',
@ -393,14 +431,12 @@ class Args(object):
'--width',
type=int,
help='Image width, multiple of 64',
default=512
)
render_group.add_argument(
'-H',
'--height',
type=int,
help='Image height, multiple of 64',
default=512,
)
render_group.add_argument(
'-C',
@ -416,8 +452,8 @@ class Args(object):
help='generate a grid'
)
render_group.add_argument(
'--individual',
'-i',
'--individual',
action='store_true',
help='override command-line --grid setting and generate individual images'
)
@ -436,7 +472,6 @@ class Args(object):
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
default='k_lms',
)
render_group.add_argument(
'-t',
@ -448,7 +483,6 @@ class Args(object):
'--outdir',
'-o',
type=str,
default='outputs/img-samples',
help='Directory to save generated images and a log of prompts and seeds',
)
img2img_group.add_argument(
@ -463,6 +497,11 @@ class Args(object):
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
img2img_group.add_argument(
'--init_color',
type=str,
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
)
img2img_group.add_argument(
'-T',
'-fit',
@ -477,6 +516,12 @@ class Args(object):
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
default=0.75,
)
postprocessing_group.add_argument(
'-ft',
'--facetool',
type=str,
help='Select the face restoration AI to use: gfpgan, codeformer',
)
postprocessing_group.add_argument(
'-G',
'--gfpgan_strength',
@ -484,6 +529,13 @@ class Args(object):
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
default=0,
)
postprocessing_group.add_argument(
'-cf',
'--codeformer_fidelity',
type=float,
help='Takes values between 0 and 1. 0 produces high quality but low accuracy. 1 produces high accuracy but low quality.',
default=0.75
)
postprocessing_group.add_argument(
'-U',
'--upscale',
@ -535,18 +587,31 @@ class Args(object):
)
return parser
# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
# it does not write all the required top-level metadata, writes too much image
# data, and doesn't support grids yet. But you gotta start somewhere, no?
def format_metadata(opt,
seeds=[],
weights=None,
model_hash=None,
postprocessing=None):
def format_metadata(*args, **kwargs):
    '''
    Deprecated alias for metadata_dumps().

    Accepts the same positional and keyword arguments as metadata_dumps()
    and forwards them unchanged. The legacy `weights` keyword, which
    metadata_dumps() does not accept, is silently discarded so that
    older call sites keep working.
    '''
    print(f'format_metadata() is deprecated. Please use metadata_dumps()')
    # `weights` was part of the old signature only; drop it before forwarding
    kwargs.pop('weights', None)
    return metadata_dumps(*args, **kwargs)
def metadata_dumps(opt,
seeds=[],
model_hash=None,
postprocessing=None):
'''
Given an Args object, returns a partial implementation of
the stable diffusion metadata standard
Given an Args object, returns a dict containing the keys and
structure of the proposed stable diffusion metadata standard
https://github.com/lstein/stable-diffusion/discussions/392
This is intended to be turned into JSON and stored in the
"sd-metadata" field of the PNG file.
'''
# top-level metadata minus `image` or `images`
metadata = {
'model' : 'stable diffusion',
'model_id' : opt.model,
'model_hash' : model_hash,
'app_id' : APP_ID,
'app_version' : APP_VERSION,
}
# add some RFC266 fields that are generated internally, and not as
# user args
image_dict = opt.to_dict(
@ -587,24 +652,67 @@ def format_metadata(opt,
if opt.init_img:
rfc_dict['type'] = 'img2img'
rfc_dict['strength_steps'] = rfc_dict.pop('strength')
rfc_dict['orig_hash'] = sha256(image_dict['init_img'])
rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
else:
rfc_dict['type'] = 'txt2img'
images = []
for seed in seeds:
rfc_dict['seed'] = seed
images.append(copy.copy(rfc_dict))
if len(seeds)==0 and opt.seed:
seeds=[seed]
return {
'model' : 'stable diffusion',
'model_id' : opt.model,
'model_hash' : model_hash,
'app_id' : APP_ID,
'app_version' : APP_VERSION,
'images' : images,
}
if opt.grid:
images = []
for seed in seeds:
rfc_dict['seed'] = seed
images.append(copy.copy(rfc_dict))
metadata['images'] = images
else:
# there should only ever be a single seed if we did not generate a grid
assert len(seeds) == 1, 'Expected a single seed'
rfc_dict['seed'] = seeds[0]
metadata['image'] = rfc_dict
return metadata
def metadata_loads(metadata):
    '''
    Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
    and returns a series of opt objects for each of the images described in the dictionary.

    The image descriptions are read from metadata['sd-metadata'], either from
    its 'images' list (grid output) or from its single 'image' entry.
    The input dictionary is not modified.

    If a required key is missing, a message and traceback are printed to
    stderr and the opt objects built so far (possibly none) are returned.
    '''
    results = []
    try:
        sd_metadata = metadata['sd-metadata']
        # metadata_dumps() writes 'images' for grids and 'image' otherwise
        if 'images' in sd_metadata:
            images = sd_metadata['images']
        else:
            images = [sd_metadata['image']]
        for image in images:
            # work on a shallow copy so the caller's metadata stays intact
            # and the function can safely be called again on the same input
            image = dict(image)
            # repack the prompt and variations
            image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
            image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
            # fix a bit of semantic drift here
            image['sampler_name']=image.pop('sampler')
            opt = Args()
            opt._cmd_switches = Namespace(**image)
            results.append(opt)
    except KeyError as e:
        import sys, traceback
        print('>> badly-formatted metadata',file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
    return results
# image can either be a file path on disk or a base64-encoded
# representation of the file's contents
def calculate_init_img_hash(image_string):
    '''
    Return the SHA-256 hex digest identifying an init image.

    If image_string starts with 'data:image/png;base64,', the remainder is
    base64-decoded and the digest of the decoded bytes is returned.
    Otherwise image_string is treated as a path and handed to sha256().

    No files are written; the input is only read.
    '''
    prefix = 'data:image/png;base64,'
    if image_string.startswith(prefix):
        imagedata = base64.b64decode(image_string[len(prefix):])
        return hashlib.sha256(imagedata).hexdigest()
    return sha256(image_string)
# Bah. This should be moved somewhere else...
def sha256(path):

View File

@ -13,7 +13,20 @@ import re
import torch
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
uc = model.get_learned_conditioning([''])
# Extract Unconditioned Words From Prompt
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
unconditionals = re.findall(unconditional_regex, prompt)
if len(unconditionals) > 0:
unconditioned_words = ' '.join(unconditionals)
# Remove Unconditioned Words From Prompt
unconditional_regex_compile = re.compile(unconditional_regex)
clean_prompt = unconditional_regex_compile.sub(' ', prompt)
prompt = re.sub(' +', ' ', clean_prompt)
uc = model.get_learned_conditioning([unconditioned_words])
# get weighted sub-prompts
weighted_subprompts = split_weighted_subprompts(
@ -25,15 +38,16 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
c = torch.zeros_like(uc)
# normalize each "sub prompt" and add it
for subprompt, weight in weighted_subprompts:
log_tokenization(subprompt, model, log_tokens)
log_tokenization(subprompt, model, log_tokens, weight)
c = torch.add(
c,
model.get_learned_conditioning([subprompt]),
alpha=weight,
)
else: # just standard 1 prompt
log_tokenization(prompt, model, log_tokens)
log_tokenization(prompt, model, log_tokens, 1)
c = model.get_learned_conditioning([prompt])
uc = model.get_learned_conditioning([unconditioned_words])
return (uc, c)
def split_weighted_subprompts(text, skip_normalize=False)->list:
@ -72,7 +86,7 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False):
def log_tokenization(text, model, log=False, weight=1):
if not log:
return
tokens = model.cond_stage_model.tokenizer._tokenize(text)
@ -89,8 +103,8 @@ def log_tokenization(text, model, log=False):
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)

View File

@ -22,7 +22,8 @@ class Completer:
def complete(self, text, state):
buffer = readline.get_line_buffer()
if text.startswith(('-I', '--init_img','-M','--init_mask')):
if text.startswith(('-I', '--init_img','-M','--init_mask',
'--init_color')):
return self._path_completions(text, state, ('.png','.jpg','.jpeg'))
if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
@ -57,6 +58,8 @@ class Completer:
path = text.replace('--init_mask=', '', 1).lstrip()
elif text.startswith('-M'):
path = text.replace('-M', '', 1).lstrip()
elif text.startswith('--init_color='):
path = text.replace('--init_color=', '', 1).lstrip()
else:
path = text
@ -100,6 +103,7 @@ if readline_available:
'--individual','-i',
'--init_img','-I',
'--init_mask','-M',
'--init_color',
'--strength','-f',
'--variants','-v',
'--outdir','-o',

View File

@ -4,7 +4,7 @@ import copy
import base64
import mimetypes
import os
from ldm.dream.args import Args, format_metadata
from ldm.dream.args import Args, metadata_dumps
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter
from threading import Event
@ -76,7 +76,7 @@ class DreamServer(BaseHTTPRequestHandler):
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
with open("./static/dream_web/index.html", "rb") as content:
with open("./static/legacy_web/index.html", "rb") as content:
self.wfile.write(content.read())
elif self.path == "/config.js":
# unfortunately this import can't be at the top level, since that would cause a circular import
@ -94,7 +94,7 @@ class DreamServer(BaseHTTPRequestHandler):
self.end_headers()
output = []
log_file = os.path.join(self.outdir, "dream_web_log.txt")
log_file = os.path.join(self.outdir, "legacy_web_log.txt")
if os.path.exists(log_file):
with open(log_file, "r") as log:
for line in log:
@ -114,7 +114,7 @@ class DreamServer(BaseHTTPRequestHandler):
else:
path_dir = os.path.dirname(self.path)
out_dir = os.path.realpath(self.outdir.rstrip('/'))
if self.path.startswith('/static/dream_web/'):
if self.path.startswith('/static/legacy_web/'):
path = '.' + self.path
elif out_dir.replace('\\', '/').endswith(path_dir):
file = os.path.basename(self.path)
@ -145,7 +145,6 @@ class DreamServer(BaseHTTPRequestHandler):
opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
self.canceled.clear()
print(f">> Request to generate with prompt: {opt.prompt}")
# In order to handle upscaled images, the PngWriter needs to maintain state
# across images generated by each call to prompt2img(), so we define it in
# the outer scope of image_done()
@ -176,10 +175,9 @@ class DreamServer(BaseHTTPRequestHandler):
path = pngwriter.save_image_and_prompt_to_png(
image,
dream_prompt = formatted_prompt,
metadata = format_metadata(iter_opt,
seeds = [seed],
weights = self.model.weights,
model_hash = self.model.model_hash
metadata = metadata_dumps(iter_opt,
seeds = [seed],
model_hash = self.model.model_hash
),
name = name,
)
@ -188,7 +186,7 @@ class DreamServer(BaseHTTPRequestHandler):
config['seed'] = seed
# Append post_data to log, but only once!
if not upscaled:
with open(os.path.join(self.outdir, "dream_web_log.txt"), "a") as log:
with open(os.path.join(self.outdir, "legacy_web_log.txt"), "a") as log:
log.write(f"{path}: {json.dumps(config)}\n")
self.wfile.write(bytes(json.dumps(
@ -228,7 +226,8 @@ class DreamServer(BaseHTTPRequestHandler):
nonlocal step_index
if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
image = self.model.sample_to_image(sample)
name = f'{prefix}.{opt.seed}.{step_index}.png'
step_index_padded = str(step_index).rjust(len(str(opt.steps)), '0')
name = f'{prefix}.{opt.seed}.{step_index_padded}.png'
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
path = step_writer.save_image_and_prompt_to_png(image, dream_prompt=metadata, name=name)
step_index += 1

View File

@ -15,6 +15,8 @@ import traceback
import transformers
import io
import hashlib
import cv2
import skimage
from omegaconf import OmegaConf
from PIL import Image, ImageOps
@ -220,11 +222,14 @@ class Generate:
init_mask = None,
fit = False,
strength = None,
init_color = None,
# these are specific to embiggen (which also relies on img2img args)
embiggen = None,
embiggen_tiles = None,
# these are specific to GFPGAN/ESRGAN
facetool = None,
gfpgan_strength = 0,
codeformer_fidelity = None,
save_original = False,
upscale = None,
# Set this True to handle KeyboardInterrupt internally
@ -362,10 +367,17 @@ class Generate:
embiggen_tiles = embiggen_tiles,
)
if init_color:
self.correct_colors(image_list = results,
reference_image_path = init_color,
image_callback = image_callback)
if upscale is not None or gfpgan_strength > 0:
self.upscale_and_reconstruct(results,
upscale = upscale,
facetool = facetool,
strength = gfpgan_strength,
codeformer_fidelity = codeformer_fidelity,
save_original = save_original,
image_callback = image_callback)
@ -475,17 +487,44 @@ class Generate:
return self.model
def correct_colors(self,
                   image_list,
                   reference_image_path,
                   image_callback = None):
    '''
    Match the color histogram of each generated image to a reference image.

    image_list is iterated as (image, seed) pairs. Each image is converted
    to LAB, histogram-matched against the LAB version of the image opened
    from reference_image_path, and converted back to an RGB PIL image.
    The corrected image is passed to image_callback(image, seed) when a
    callback is given; otherwise it replaces element 0 of its entry in
    image_list.
    '''
    reference_image = Image.open(reference_image_path)
    correction_target = cv2.cvtColor(np.asarray(reference_image),
                                     cv2.COLOR_RGB2LAB)
    for entry in image_list:
        original, seed = entry
        lab = cv2.cvtColor(np.asarray(original), cv2.COLOR_RGB2LAB)
        matched = skimage.exposure.match_histograms(lab,
                                                    correction_target,
                                                    channel_axis=2)
        rgb = cv2.cvtColor(matched, cv2.COLOR_LAB2RGB).astype("uint8")
        corrected = Image.fromarray(rgb)
        if image_callback is None:
            # no callback: store the result back into the caller's list entry
            entry[0] = corrected
        else:
            image_callback(corrected, seed)
def upscale_and_reconstruct(self,
image_list,
facetool = 'gfpgan',
upscale = None,
strength = 0.0,
codeformer_fidelity = 0.75,
save_original = False,
image_callback = None):
try:
if upscale is not None:
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
if strength > 0:
from ldm.gfpgan.gfpgan_tools import run_gfpgan
if facetool == 'codeformer':
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
else:
from ldm.gfpgan.gfpgan_tools import run_gfpgan
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
@ -504,9 +543,12 @@ class Generate:
seed,
)
if strength > 0:
image = run_gfpgan(
image, strength, seed, 1
)
if facetool == 'codeformer':
image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
else:
image = run_gfpgan(
image, strength, seed, 1
)
except Exception as e:
print(
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'

View File

@ -90,7 +90,7 @@ class LinearAttention(nn.Module):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@ -167,101 +167,85 @@ class CrossAttention(nn.Module):
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
if torch.cuda.is_available():
self.einsum_op = self.einsum_op_cuda
else:
self.mem_total = psutil.virtual_memory().total / (1024**3)
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
def einsum_op_compvis(self, q, k, v, r1):
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1 = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_mps_v1(self, q, k, v, r1):
def einsum_op_compvis(self, q, k, v):
    # Unsliced attention: softmax over the last axis of q.k^T, applied to v.
    # No scaling is applied here.
    scores = einsum('b i d, b j d -> b i j', q, k)
    weights = scores.softmax(dim=-1, dtype=scores.dtype)
    return einsum('b i j, b j d -> b i d', weights, v)
def einsum_op_slice_0(self, q, k, v, slice_size):
    # Attention computed in chunks of `slice_size` along axis 0 of q, k and v;
    # each chunk is handled by einsum_op_compvis and written into a
    # preallocated result with the dtype and device of q.
    out = q.new_zeros((q.shape[0], q.shape[1], v.shape[2]))
    for start in range(0, q.shape[0], slice_size):
        stop = start + slice_size
        out[start:stop] = self.einsum_op_compvis(q[start:stop], k[start:stop], v[start:stop])
    return out
def einsum_op_slice_1(self, q, k, v, slice_size):
    # Attention computed in chunks of `slice_size` along axis 1 of q only;
    # k and v are passed whole to einsum_op_compvis for every chunk.
    out = q.new_zeros((q.shape[0], q.shape[1], v.shape[2]))
    for start in range(0, q.shape[1], slice_size):
        stop = start + slice_size
        out[:, start:stop] = self.einsum_op_compvis(q[:, start:stop], k, v)
    return out
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
r1 = self.einsum_op_compvis(q, k, v, r1)
return self.einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
return self.einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v, r1):
if self.mem_total >= 8 and q.shape[1] <= 4096:
r1 = self.einsum_op_compvis(q, k, v, r1)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_op_compvis(q, k, v)
else:
slice_size = 1
for i in range(0, q.shape[0], slice_size):
end = min(q.shape[0], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
return r1
def einsum_op_cuda(self, q, k, v, r1):
return self.einsum_op_slice_0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
    # Pick an attention strategy from the size (in MiB) of the
    # q.shape[0] x q.shape[1] x k.shape[1] score tensor.
    tensor_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if tensor_mb <= max_tensor_mb:
        # fits within the budget: compute in one piece
        return self.einsum_op_compvis(q, k, v)
    # power-of-two divisor derived from how far the tensor exceeds the budget
    divisor = 1 << int((tensor_mb - 1) / max_tensor_mb).bit_length()
    batch = q.shape[0]
    if divisor > batch:
        # axis 0 cannot be cut finely enough; slice along axis 1 instead
        return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // divisor, 1))
    return self.einsum_op_slice_0(q, k, v, batch // divisor)
def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
mem_required = tensor_size * 2.5
steps = 1
def einsum_op(self, q, k, v):
if q.device.type == 'cuda':
return self.einsum_op_cuda(q, k, v)
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
if q.device.type == 'mps':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = min(q.shape[1], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
q = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
device_type = 'mps' if x.device.type == 'mps' else 'cuda'
k = self.to_k(context) * self.scale
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
r1 = self.einsum_op(q, k, v, r1)
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = self.einsum_op(q, k, v)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
class BasicTransformerBlock(nn.Module):

View File

@ -3,6 +3,7 @@ import gc
import math
import torch
import torch.nn as nn
from torch.nn.functional import silu
import numpy as np
from einops import rearrange
@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
return emb
def nonlinearity(x):
    # swish activation: the input multiplied by its own sigmoid
    gate = torch.sigmoid(x)
    return x * gate
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
@ -122,14 +118,14 @@ class ResnetBlock(nn.Module):
def forward(self, x, temb):
h = self.norm1(x)
h = nonlinearity(h)
h = silu(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = h + self.temb_proj(silu(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = silu(h)
h = self.dropout(h)
h = self.conv2(h)
@ -368,7 +364,7 @@ class Model(nn.Module):
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = silu(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
@ -402,7 +398,7 @@ class Model(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h
@ -499,7 +495,7 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h
@ -611,7 +607,7 @@ class Decoder(nn.Module):
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
@ -649,7 +645,7 @@ class SimpleDecoder(nn.Module):
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
h = silu(h)
x = self.conv_out(h)
return x
@ -697,7 +693,7 @@ class UpsampleDecoder(nn.Module):
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h
@ -873,7 +869,7 @@ class FirstStagePostProcessor(nn.Module):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
z = silu(z)
for submodel, downmodel in zip(self.model,self.downsampler):
z = submodel(z,temb=None)

View File

@ -252,12 +252,6 @@ def normalization(channels):
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Module form of x * sigmoid(x)."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)

View File

@ -82,7 +82,9 @@ class EmbeddingManager(nn.Module):
get_embedding_for_clip_token,
embedder.transformer.text_model.embeddings,
)
token_dim = 1280
# per bug report #572
#token_dim = 1280
token_dim = 768
else: # using LDM's BERT encoder
self.is_clip = False
get_token_for_string = partial(

Some files were not shown because too many files have changed in this diff Show More