mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge remote-tracking branch 'upstream/development' into development
This commit is contained in:
commit
1e3200801f
64
.github/workflows/cache-model.yml
vendored
64
.github/workflows/cache-model.yml
vendored
@ -1,64 +0,0 @@
|
|||||||
name: Cache Model
|
|
||||||
on:
|
|
||||||
workflow_dispatch
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
os: [ macos-12 ]
|
|
||||||
name: Create Caches using ${{ matrix.os }}
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
steps:
|
|
||||||
- name: Checkout sources
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
- name: Cache model
|
|
||||||
id: cache-sd-v1-4
|
|
||||||
uses: actions/cache@v3
|
|
||||||
env:
|
|
||||||
cache-name: cache-sd-v1-4
|
|
||||||
with:
|
|
||||||
path: models/ldm/stable-diffusion-v1/model.ckpt
|
|
||||||
key: ${{ env.cache-name }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ env.cache-name }}
|
|
||||||
- name: Download Stable Diffusion v1.4 model
|
|
||||||
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
|
|
||||||
mkdir -p models/ldm/stable-diffusion-v1
|
|
||||||
fi
|
|
||||||
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
|
|
||||||
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
|
||||||
fi
|
|
||||||
# Uncomment this when we no longer make changes to environment-mac.yaml
|
|
||||||
# - name: Cache environment
|
|
||||||
# id: cache-conda-env-ldm
|
|
||||||
# uses: actions/cache@v3
|
|
||||||
# env:
|
|
||||||
# cache-name: cache-conda-env-ldm
|
|
||||||
# with:
|
|
||||||
# path: ~/.conda/envs/ldm
|
|
||||||
# key: ${{ env.cache-name }}
|
|
||||||
# restore-keys: |
|
|
||||||
# ${{ env.cache-name }}
|
|
||||||
- name: Install dependencies
|
|
||||||
# if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
|
|
||||||
run: |
|
|
||||||
conda env create -f environment-mac.yaml
|
|
||||||
- name: Cache hugginface and torch models
|
|
||||||
id: cache-hugginface-torch
|
|
||||||
uses: actions/cache@v3
|
|
||||||
env:
|
|
||||||
cache-name: cache-hugginface-torch
|
|
||||||
with:
|
|
||||||
path: ~/.cache
|
|
||||||
key: ${{ env.cache-name }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ env.cache-name }}
|
|
||||||
- name: Download Huggingface and Torch models
|
|
||||||
if: ${{ steps.cache-hugginface-torch.outputs.cache-hit != 'true' }}
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
export PYTHON_BIN=/usr/local/miniconda/envs/ldm/bin/python
|
|
||||||
$PYTHON_BIN scripts/preload_models.py
|
|
70
.github/workflows/create-caches.yml
vendored
Normal file
70
.github/workflows/create-caches.yml
vendored
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
name: Create Caches
|
||||||
|
on:
|
||||||
|
workflow_dispatch
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ ubuntu-latest, macos-12 ]
|
||||||
|
name: Create Caches on ${{ matrix.os }} conda
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
steps:
|
||||||
|
- name: Set platform variables
|
||||||
|
id: vars
|
||||||
|
run: |
|
||||||
|
if [ "$RUNNER_OS" = "macOS" ]; then
|
||||||
|
echo "::set-output name=ENV_FILE::environment-mac.yaml"
|
||||||
|
echo "::set-output name=PYTHON_BIN::/usr/local/miniconda/envs/ldm/bin/python"
|
||||||
|
elif [ "$RUNNER_OS" = "Linux" ]; then
|
||||||
|
echo "::set-output name=ENV_FILE::environment.yaml"
|
||||||
|
echo "::set-output name=PYTHON_BIN::/usr/share/miniconda/envs/ldm/bin/python"
|
||||||
|
fi
|
||||||
|
- name: Checkout sources
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: Use Cached Stable Diffusion v1.4 Model
|
||||||
|
id: cache-sd-v1-4
|
||||||
|
uses: actions/cache@v3
|
||||||
|
env:
|
||||||
|
cache-name: cache-sd-v1-4
|
||||||
|
with:
|
||||||
|
path: models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
|
key: ${{ env.cache-name }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ env.cache-name }}
|
||||||
|
- name: Download Stable Diffusion v1.4 Model
|
||||||
|
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
|
||||||
|
run: |
|
||||||
|
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
|
||||||
|
mkdir -p models/ldm/stable-diffusion-v1
|
||||||
|
fi
|
||||||
|
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
|
||||||
|
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
||||||
|
fi
|
||||||
|
- name: Use Cached Dependencies
|
||||||
|
id: cache-conda-env-ldm
|
||||||
|
uses: actions/cache@v3
|
||||||
|
env:
|
||||||
|
cache-name: cache-conda-env-ldm
|
||||||
|
with:
|
||||||
|
path: ~/.conda/envs/ldm
|
||||||
|
key: ${{ env.cache-name }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(steps.vars.outputs.ENV_FILE) }}
|
||||||
|
- name: Install Dependencies
|
||||||
|
if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
|
||||||
|
run: |
|
||||||
|
conda env create -f ${{ steps.vars.outputs.ENV_FILE }}
|
||||||
|
- name: Use Cached Huggingface and Torch models
|
||||||
|
id: cache-huggingface-torch
|
||||||
|
uses: actions/cache@v3
|
||||||
|
env:
|
||||||
|
cache-name: cache-huggingface-torch
|
||||||
|
with:
|
||||||
|
path: ~/.cache
|
||||||
|
key: ${{ env.cache-name }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ env.cache-name }}-${{ hashFiles('scripts/preload_models.py') }}
|
||||||
|
- name: Download Huggingface and Torch models
|
||||||
|
if: ${{ steps.cache-huggingface-torch.outputs.cache-hit != 'true' }}
|
||||||
|
run: |
|
||||||
|
${{ steps.vars.outputs.PYTHON_BIN }} scripts/preload_models.py
|
80
.github/workflows/macos12-miniconda.yml
vendored
80
.github/workflows/macos12-miniconda.yml
vendored
@ -1,80 +0,0 @@
|
|||||||
name: Build
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ main ]
|
|
||||||
pull_request:
|
|
||||||
branches: [ main ]
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
os: [ macos-12 ]
|
|
||||||
name: Build on ${{ matrix.os }} miniconda
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
steps:
|
|
||||||
- name: Checkout sources
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
- name: Cache model
|
|
||||||
id: cache-sd-v1-4
|
|
||||||
uses: actions/cache@v3
|
|
||||||
env:
|
|
||||||
cache-name: cache-sd-v1-4
|
|
||||||
with:
|
|
||||||
path: models/ldm/stable-diffusion-v1/model.ckpt
|
|
||||||
key: ${{ env.cache-name }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ env.cache-name }}
|
|
||||||
- name: Download Stable Diffusion v1.4 model
|
|
||||||
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
|
|
||||||
mkdir -p models/ldm/stable-diffusion-v1
|
|
||||||
fi
|
|
||||||
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
|
|
||||||
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
|
||||||
fi
|
|
||||||
# Uncomment this when we no longer make changes to environment-mac.yaml
|
|
||||||
# - name: Cache environment
|
|
||||||
# id: cache-conda-env-ldm
|
|
||||||
# uses: actions/cache@v3
|
|
||||||
# env:
|
|
||||||
# cache-name: cache-conda-env-ldm
|
|
||||||
# with:
|
|
||||||
# path: ~/.conda/envs/ldm
|
|
||||||
# key: ${{ env.cache-name }}
|
|
||||||
# restore-keys: |
|
|
||||||
# ${{ env.cache-name }}
|
|
||||||
- name: Install dependencies
|
|
||||||
# if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
|
|
||||||
run: |
|
|
||||||
conda env create -f environment-mac.yaml
|
|
||||||
- name: Cache hugginface and torch models
|
|
||||||
id: cache-hugginface-torch
|
|
||||||
uses: actions/cache@v3
|
|
||||||
env:
|
|
||||||
cache-name: cache-hugginface-torch
|
|
||||||
with:
|
|
||||||
path: ~/.cache
|
|
||||||
key: ${{ env.cache-name }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ env.cache-name }}
|
|
||||||
- name: Download Huggingface and Torch models
|
|
||||||
if: ${{ steps.cache-hugginface-torch.outputs.cache-hit != 'true' }}
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
export PYTHON_BIN=/usr/local/miniconda/envs/ldm/bin/python
|
|
||||||
$PYTHON_BIN scripts/preload_models.py
|
|
||||||
- name: Run the tests
|
|
||||||
run: |
|
|
||||||
# Note, can't "activate" via automation, and activation is just env vars and path
|
|
||||||
export PYTHON_BIN=/usr/local/miniconda/envs/ldm/bin/python
|
|
||||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
|
||||||
$PYTHON_BIN scripts/preload_models.py
|
|
||||||
mkdir -p outputs/img-samples
|
|
||||||
time $PYTHON_BIN scripts/dream.py --from_file tests/prompts.txt </dev/null 2> outputs/img-samples/err.log > outputs/img-samples/out.log
|
|
||||||
- name: Archive results
|
|
||||||
uses: actions/upload-artifact@v3
|
|
||||||
with:
|
|
||||||
name: results
|
|
||||||
path: outputs/img-samples
|
|
97
.github/workflows/test-dream-conda.yml
vendored
Normal file
97
.github/workflows/test-dream-conda.yml
vendored
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
name: Test Dream with Conda
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- 'main'
|
||||||
|
- 'development'
|
||||||
|
jobs:
|
||||||
|
os_matrix:
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ ubuntu-latest, macos-12 ]
|
||||||
|
name: Test dream.py on ${{ matrix.os }} with conda
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
steps:
|
||||||
|
- run: |
|
||||||
|
echo The PR was merged
|
||||||
|
- name: Set platform variables
|
||||||
|
id: vars
|
||||||
|
run: |
|
||||||
|
# Note, can't "activate" via github action; specifying the env's python has the same effect
|
||||||
|
if [ "$RUNNER_OS" = "macOS" ]; then
|
||||||
|
echo "::set-output name=ENV_FILE::environment-mac.yaml"
|
||||||
|
echo "::set-output name=PYTHON_BIN::/usr/local/miniconda/envs/ldm/bin/python"
|
||||||
|
elif [ "$RUNNER_OS" = "Linux" ]; then
|
||||||
|
echo "::set-output name=ENV_FILE::environment.yaml"
|
||||||
|
echo "::set-output name=PYTHON_BIN::/usr/share/miniconda/envs/ldm/bin/python"
|
||||||
|
fi
|
||||||
|
- name: Checkout sources
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: Use Cached Stable Diffusion v1.4 Model
|
||||||
|
id: cache-sd-v1-4
|
||||||
|
uses: actions/cache@v3
|
||||||
|
env:
|
||||||
|
cache-name: cache-sd-v1-4
|
||||||
|
with:
|
||||||
|
path: models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
|
key: ${{ env.cache-name }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ env.cache-name }}
|
||||||
|
- name: Download Stable Diffusion v1.4 Model
|
||||||
|
if: ${{ steps.cache-sd-v1-4.outputs.cache-hit != 'true' }}
|
||||||
|
run: |
|
||||||
|
if [ ! -e models/ldm/stable-diffusion-v1 ]; then
|
||||||
|
mkdir -p models/ldm/stable-diffusion-v1
|
||||||
|
fi
|
||||||
|
if [ ! -e models/ldm/stable-diffusion-v1/model.ckpt ]; then
|
||||||
|
curl -o models/ldm/stable-diffusion-v1/model.ckpt ${{ secrets.SD_V1_4_URL }}
|
||||||
|
fi
|
||||||
|
- name: Use Cached Dependencies
|
||||||
|
id: cache-conda-env-ldm
|
||||||
|
uses: actions/cache@v3
|
||||||
|
env:
|
||||||
|
cache-name: cache-conda-env-ldm
|
||||||
|
with:
|
||||||
|
path: ~/.conda/envs/ldm
|
||||||
|
key: ${{ env.cache-name }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles(steps.vars.outputs.ENV_FILE) }}
|
||||||
|
- name: Install Dependencies
|
||||||
|
if: ${{ steps.cache-conda-env-ldm.outputs.cache-hit != 'true' }}
|
||||||
|
run: |
|
||||||
|
conda env create -f ${{ steps.vars.outputs.ENV_FILE }}
|
||||||
|
- name: Use Cached Huggingface and Torch models
|
||||||
|
id: cache-hugginface-torch
|
||||||
|
uses: actions/cache@v3
|
||||||
|
env:
|
||||||
|
cache-name: cache-hugginface-torch
|
||||||
|
with:
|
||||||
|
path: ~/.cache
|
||||||
|
key: ${{ env.cache-name }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ env.cache-name }}-${{ hashFiles('scripts/preload_models.py') }}
|
||||||
|
- name: Download Huggingface and Torch models
|
||||||
|
if: ${{ steps.cache-hugginface-torch.outputs.cache-hit != 'true' }}
|
||||||
|
run: |
|
||||||
|
${{ steps.vars.outputs.PYTHON_BIN }} scripts/preload_models.py
|
||||||
|
# - name: Run tmate
|
||||||
|
# uses: mxschmitt/action-tmate@v3
|
||||||
|
# timeout-minutes: 30
|
||||||
|
- name: Run the tests
|
||||||
|
run: |
|
||||||
|
# Note, can't "activate" via github action; specifying the env's python has the same effect
|
||||||
|
if [ $(uname) = "Darwin" ]; then
|
||||||
|
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||||
|
fi
|
||||||
|
# Utterly hacky, but I don't know how else to do this
|
||||||
|
if [[ ${{ github.ref }} == 'refs/heads/master' ]]; then
|
||||||
|
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/preflight_prompts.txt --full_precision
|
||||||
|
elif [[ ${{ github.ref }} == 'refs/heads/development' ]]; then
|
||||||
|
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/dev_prompts.txt --full_precision
|
||||||
|
fi
|
||||||
|
mkdir -p outputs/img-samples
|
||||||
|
- name: Archive results
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: results
|
||||||
|
path: outputs/img-samples
|
12
.gitignore
vendored
12
.gitignore
vendored
@ -77,9 +77,6 @@ db.sqlite3-journal
|
|||||||
instance/
|
instance/
|
||||||
.webassets-cache
|
.webassets-cache
|
||||||
|
|
||||||
# WebUI temp files:
|
|
||||||
img2img-tmp.png
|
|
||||||
|
|
||||||
# Scrapy stuff:
|
# Scrapy stuff:
|
||||||
.scrapy
|
.scrapy
|
||||||
|
|
||||||
@ -186,3 +183,12 @@ testtube
|
|||||||
checkpoints
|
checkpoints
|
||||||
# If it's a Mac
|
# If it's a Mac
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# Let the frontend manage its own gitignore
|
||||||
|
!frontend/*
|
||||||
|
|
||||||
|
# Scratch folder
|
||||||
|
.scratch/
|
||||||
|
.vscode/
|
||||||
|
gfpgan/
|
||||||
|
models/ldm/stable-diffusion-v1/model.sha256
|
||||||
|
218
backend/modules/parameters.py
Normal file
218
backend/modules/parameters.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
from modules.parse_seed_weights import parse_seed_weights
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
SAMPLER_CHOICES = [
|
||||||
|
'ddim',
|
||||||
|
'k_dpm_2_a',
|
||||||
|
'k_dpm_2',
|
||||||
|
'k_euler_a',
|
||||||
|
'k_euler',
|
||||||
|
'k_heun',
|
||||||
|
'k_lms',
|
||||||
|
'plms',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def parameters_to_command(params):
|
||||||
|
"""
|
||||||
|
Converts dict of parameters into a `dream.py` REPL command.
|
||||||
|
"""
|
||||||
|
|
||||||
|
switches = list()
|
||||||
|
|
||||||
|
if 'prompt' in params:
|
||||||
|
switches.append(f'"{params["prompt"]}"')
|
||||||
|
if 'steps' in params:
|
||||||
|
switches.append(f'-s {params["steps"]}')
|
||||||
|
if 'seed' in params:
|
||||||
|
switches.append(f'-S {params["seed"]}')
|
||||||
|
if 'width' in params:
|
||||||
|
switches.append(f'-W {params["width"]}')
|
||||||
|
if 'height' in params:
|
||||||
|
switches.append(f'-H {params["height"]}')
|
||||||
|
if 'cfg_scale' in params:
|
||||||
|
switches.append(f'-C {params["cfg_scale"]}')
|
||||||
|
if 'sampler_name' in params:
|
||||||
|
switches.append(f'-A {params["sampler_name"]}')
|
||||||
|
if 'seamless' in params and params["seamless"] == True:
|
||||||
|
switches.append(f'--seamless')
|
||||||
|
if 'init_img' in params and len(params['init_img']) > 0:
|
||||||
|
switches.append(f'-I {params["init_img"]}')
|
||||||
|
if 'init_mask' in params and len(params['init_mask']) > 0:
|
||||||
|
switches.append(f'-M {params["init_mask"]}')
|
||||||
|
if 'strength' in params and 'init_img' in params:
|
||||||
|
switches.append(f'-f {params["strength"]}')
|
||||||
|
if 'fit' in params and params["fit"] == True:
|
||||||
|
switches.append(f'--fit')
|
||||||
|
if 'gfpgan_strength' in params and params["gfpgan_strength"]:
|
||||||
|
switches.append(f'-G {params["gfpgan_strength"]}')
|
||||||
|
if 'upscale' in params and params["upscale"]:
|
||||||
|
switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}')
|
||||||
|
if 'variation_amount' in params and params['variation_amount'] > 0:
|
||||||
|
switches.append(f'-v {params["variation_amount"]}')
|
||||||
|
if 'with_variations' in params:
|
||||||
|
seed_weight_pairs = ','.join(f'{seed}:{weight}' for seed, weight in params["with_variations"])
|
||||||
|
switches.append(f'-V {seed_weight_pairs}')
|
||||||
|
|
||||||
|
return ' '.join(switches)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def create_cmd_parser():
|
||||||
|
"""
|
||||||
|
This is simply a copy of the parser from `dream.py` with a change to give
|
||||||
|
prompt a default value. This is a temporary hack pending merge of #587 which
|
||||||
|
provides a better way to do this.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12',
|
||||||
|
exit_on_error=True,
|
||||||
|
)
|
||||||
|
parser.add_argument('prompt', nargs='?', default='')
|
||||||
|
parser.add_argument('-s', '--steps', type=int, help='Number of steps')
|
||||||
|
parser.add_argument(
|
||||||
|
'-S',
|
||||||
|
'--seed',
|
||||||
|
type=int,
|
||||||
|
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-n',
|
||||||
|
'--iterations',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-W', '--width', type=int, help='Image width, multiple of 64'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-H', '--height', type=int, help='Image height, multiple of 64'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-C',
|
||||||
|
'--cfg_scale',
|
||||||
|
default=7.5,
|
||||||
|
type=float,
|
||||||
|
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-g', '--grid', action='store_true', help='generate a grid'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--outdir',
|
||||||
|
'-o',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='Directory to save generated images and a log of prompts and seeds',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--seamless',
|
||||||
|
action='store_true',
|
||||||
|
help='Change the model to seamless tiling (circular) mode',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-i',
|
||||||
|
'--individual',
|
||||||
|
action='store_true',
|
||||||
|
help='Generate individual files (default)',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-I',
|
||||||
|
'--init_img',
|
||||||
|
type=str,
|
||||||
|
help='Path to input image for img2img mode (supersedes width and height)',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-M',
|
||||||
|
'--init_mask',
|
||||||
|
type=str,
|
||||||
|
help='Path to input mask for inpainting mode (supersedes width and height)',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-T',
|
||||||
|
'-fit',
|
||||||
|
'--fit',
|
||||||
|
action='store_true',
|
||||||
|
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-f',
|
||||||
|
'--strength',
|
||||||
|
default=0.75,
|
||||||
|
type=float,
|
||||||
|
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-G',
|
||||||
|
'--gfpgan_strength',
|
||||||
|
default=0,
|
||||||
|
type=float,
|
||||||
|
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-U',
|
||||||
|
'--upscale',
|
||||||
|
nargs='+',
|
||||||
|
default=None,
|
||||||
|
type=float,
|
||||||
|
help='Scale factor (2, 4) for upscaling followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-save_orig',
|
||||||
|
'--save_original',
|
||||||
|
action='store_true',
|
||||||
|
help='Save original. Use it when upscaling to save both versions.',
|
||||||
|
)
|
||||||
|
# variants is going to be superseded by a generalized "prompt-morph" function
|
||||||
|
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
|
||||||
|
parser.add_argument(
|
||||||
|
'-x',
|
||||||
|
'--skip_normalize',
|
||||||
|
action='store_true',
|
||||||
|
help='Skip subprompt weight normalization',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-A',
|
||||||
|
'-m',
|
||||||
|
'--sampler',
|
||||||
|
dest='sampler_name',
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
choices=SAMPLER_CHOICES,
|
||||||
|
metavar='SAMPLER_NAME',
|
||||||
|
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-t',
|
||||||
|
'--log_tokenization',
|
||||||
|
action='store_true',
|
||||||
|
help='shows how the prompt is split into tokens'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--threshold',
|
||||||
|
default=0.0,
|
||||||
|
type=float,
|
||||||
|
help='Add threshold value aka perform clipping.',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--perlin',
|
||||||
|
default=0.0,
|
||||||
|
type=float,
|
||||||
|
help='Add perlin noise.',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-v',
|
||||||
|
'--variation_amount',
|
||||||
|
default=0.0,
|
||||||
|
type=float,
|
||||||
|
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-V',
|
||||||
|
'--with_variations',
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
|
||||||
|
)
|
||||||
|
return parser
|
47
backend/modules/parse_seed_weights.py
Normal file
47
backend/modules/parse_seed_weights.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
def parse_seed_weights(seed_weights):
|
||||||
|
"""
|
||||||
|
Accepts seed weights as string in "12345:0.1,23456:0.2,3456:0.3" format
|
||||||
|
Validates them
|
||||||
|
If valid: returns as [[12345, 0.1], [23456, 0.2], [3456, 0.3]]
|
||||||
|
If invalid: returns False
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Must be a string
|
||||||
|
if not isinstance(seed_weights, str):
|
||||||
|
return False
|
||||||
|
# String must not be empty
|
||||||
|
if len(seed_weights) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
pairs = []
|
||||||
|
|
||||||
|
for pair in seed_weights.split(","):
|
||||||
|
split_values = pair.split(":")
|
||||||
|
|
||||||
|
# Seed and weight are required
|
||||||
|
if len(split_values) != 2:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if len(split_values[0]) == 0 or len(split_values[1]) == 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Try casting the seed to int and weight to float
|
||||||
|
try:
|
||||||
|
seed = int(split_values[0])
|
||||||
|
weight = float(split_values[1])
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Seed must be 0 or above
|
||||||
|
if not seed >= 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Weight must be between 0 and 1
|
||||||
|
if not (weight >= 0 and weight <= 1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# This pair is valid
|
||||||
|
pairs.append([seed, weight])
|
||||||
|
|
||||||
|
# All pairs are valid
|
||||||
|
return pairs
|
397
backend/server.py
Normal file
397
backend/server.py
Normal file
@ -0,0 +1,397 @@
|
|||||||
|
import mimetypes
|
||||||
|
import transformers
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
import eventlet
|
||||||
|
import glob
|
||||||
|
import shlex
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from flask_socketio import SocketIO
|
||||||
|
from flask import Flask, send_from_directory, url_for, jsonify
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
from pytorch_lightning import logging
|
||||||
|
from threading import Event
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
|
||||||
|
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
||||||
|
from ldm.generate import Generate
|
||||||
|
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
|
||||||
|
|
||||||
|
from modules.parameters import parameters_to_command, create_cmd_parser
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
USER CONFIG
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_dir = "outputs/" # Base output directory for images
|
||||||
|
#host = 'localhost' # Web & socket.io host
|
||||||
|
host = '0.0.0.0' # Web & socket.io host
|
||||||
|
port = 9090 # Web & socket.io port
|
||||||
|
verbose = False # enables copious socket.io logging
|
||||||
|
additional_allowed_origins = ['http://localhost:9090'] # additional CORS allowed origins
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
END USER CONFIG
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
SERVER SETUP
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# fix missing mimetypes on windows due to registry wonkiness
|
||||||
|
mimetypes.add_type('application/javascript', '.js')
|
||||||
|
mimetypes.add_type('text/css', '.css')
|
||||||
|
|
||||||
|
app = Flask(__name__, static_url_path='', static_folder='../frontend/dist/')
|
||||||
|
|
||||||
|
|
||||||
|
app.config['OUTPUTS_FOLDER'] = "../outputs"
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/outputs/<path:filename>')
|
||||||
|
def outputs(filename):
|
||||||
|
return send_from_directory(
|
||||||
|
app.config['OUTPUTS_FOLDER'],
|
||||||
|
filename
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/", defaults={'path': ''})
|
||||||
|
def serve(path):
|
||||||
|
return send_from_directory(app.static_folder, 'index.html')
|
||||||
|
|
||||||
|
|
||||||
|
logger = True if verbose else False
|
||||||
|
engineio_logger = True if verbose else False
|
||||||
|
|
||||||
|
# default 1,000,000, needs to be higher for socketio to accept larger images
|
||||||
|
max_http_buffer_size = 10000000
|
||||||
|
|
||||||
|
cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins
|
||||||
|
|
||||||
|
socketio = SocketIO(
|
||||||
|
app,
|
||||||
|
logger=logger,
|
||||||
|
engineio_logger=engineio_logger,
|
||||||
|
max_http_buffer_size=max_http_buffer_size,
|
||||||
|
cors_allowed_origins=cors_allowed_origins,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
END SERVER SETUP
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
APP SETUP
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class CanceledException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
canceled = Event()
|
||||||
|
|
||||||
|
# reduce logging outputs to error
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
# Initialize and load model
|
||||||
|
model = Generate()
|
||||||
|
model.load_model()
|
||||||
|
|
||||||
|
|
||||||
|
# location for "finished" images
|
||||||
|
result_path = os.path.join(output_dir, 'img-samples/')
|
||||||
|
|
||||||
|
# temporary path for intermediates
|
||||||
|
intermediate_path = os.path.join(result_path, 'intermediates/')
|
||||||
|
|
||||||
|
# path for user-uploaded init images and masks
|
||||||
|
init_path = os.path.join(result_path, 'init-images/')
|
||||||
|
mask_path = os.path.join(result_path, 'mask-images/')
|
||||||
|
|
||||||
|
# txt log
|
||||||
|
log_path = os.path.join(result_path, 'dream_log.txt')
|
||||||
|
|
||||||
|
# make all output paths
|
||||||
|
[os.makedirs(path, exist_ok=True)
|
||||||
|
for path in [result_path, intermediate_path, init_path, mask_path]]
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
END APP SETUP
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
SOCKET.IO LISTENERS
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@socketio.on('requestAllImages')
|
||||||
|
def handle_request_all_images():
|
||||||
|
print(f'>> All images requested')
|
||||||
|
parser = create_cmd_parser()
|
||||||
|
paths = list(filter(os.path.isfile, glob.glob(result_path + "*.png")))
|
||||||
|
paths.sort(key=lambda x: os.path.getmtime(x))
|
||||||
|
image_array = []
|
||||||
|
for path in paths:
|
||||||
|
# image = Image.open(path)
|
||||||
|
all_metadata = retrieve_metadata(path)
|
||||||
|
if 'Dream' in all_metadata and not all_metadata['sd-metadata']:
|
||||||
|
metadata = vars(parser.parse_args(shlex.split(all_metadata['Dream'])))
|
||||||
|
else:
|
||||||
|
metadata = all_metadata['sd-metadata']
|
||||||
|
image_array.append({'path': path, 'metadata': metadata})
|
||||||
|
return make_response("OK", data=image_array)
|
||||||
|
|
||||||
|
|
||||||
|
@socketio.on('generateImage')
|
||||||
|
def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan_parameters):
|
||||||
|
print(f'>> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}')
|
||||||
|
generate_images(
|
||||||
|
generation_parameters,
|
||||||
|
esrgan_parameters,
|
||||||
|
gfpgan_parameters
|
||||||
|
)
|
||||||
|
return make_response("OK")
|
||||||
|
|
||||||
|
|
||||||
|
@socketio.on('runESRGAN')
|
||||||
|
def handle_run_esrgan_event(original_image, esrgan_parameters):
|
||||||
|
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}')
|
||||||
|
image = Image.open(original_image["url"])
|
||||||
|
|
||||||
|
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
|
||||||
|
|
||||||
|
image = real_esrgan_upscale(
|
||||||
|
image=image,
|
||||||
|
upsampler_scale=esrgan_parameters['upscale'][0],
|
||||||
|
strength=esrgan_parameters['upscale'][1],
|
||||||
|
seed=seed
|
||||||
|
)
|
||||||
|
|
||||||
|
esrgan_parameters['seed'] = seed
|
||||||
|
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan')
|
||||||
|
command = parameters_to_command(esrgan_parameters)
|
||||||
|
|
||||||
|
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
|
||||||
|
|
||||||
|
socketio.emit(
|
||||||
|
'result', {'url': os.path.relpath(path), 'type': 'esrgan', 'uuid': original_image['uuid'],'metadata': esrgan_parameters})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@socketio.on('runGFPGAN')
|
||||||
|
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
|
||||||
|
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}')
|
||||||
|
image = Image.open(original_image["url"])
|
||||||
|
|
||||||
|
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
|
||||||
|
|
||||||
|
image = run_gfpgan(
|
||||||
|
image=image,
|
||||||
|
strength=gfpgan_parameters['gfpgan_strength'],
|
||||||
|
seed=seed,
|
||||||
|
upsampler_scale=1
|
||||||
|
)
|
||||||
|
|
||||||
|
gfpgan_parameters['seed'] = seed
|
||||||
|
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan')
|
||||||
|
command = parameters_to_command(gfpgan_parameters)
|
||||||
|
|
||||||
|
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
|
||||||
|
|
||||||
|
socketio.emit(
|
||||||
|
'result', {'url': os.path.relpath(path), 'type': 'gfpgan', 'uuid': original_image['uuid'],'metadata': gfpgan_parameters})
|
||||||
|
|
||||||
|
|
||||||
|
@socketio.on('cancel')
|
||||||
|
def handle_cancel():
|
||||||
|
print(f'>> Cancel processing requested')
|
||||||
|
canceled.set()
|
||||||
|
return make_response("OK")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: I think this needs a safety mechanism.
|
||||||
|
@socketio.on('deleteImage')
|
||||||
|
def handle_delete_image(path):
|
||||||
|
print(f'>> Delete requested "{path}"')
|
||||||
|
Path(path).unlink()
|
||||||
|
return make_response("OK")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: I think this needs a safety mechanism.
|
||||||
|
@socketio.on('uploadInitialImage')
|
||||||
|
def handle_upload_initial_image(bytes, name):
|
||||||
|
print(f'>> Init image upload requested "{name}"')
|
||||||
|
uuid = uuid4().hex
|
||||||
|
split = os.path.splitext(name)
|
||||||
|
name = f'{split[0]}.{uuid}{split[1]}'
|
||||||
|
file_path = os.path.join(init_path, name)
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
newFile = open(file_path, "wb")
|
||||||
|
newFile.write(bytes)
|
||||||
|
return make_response("OK", data=file_path)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: I think this needs a safety mechanism.
|
||||||
|
@socketio.on('uploadMaskImage')
|
||||||
|
def handle_upload_mask_image(bytes, name):
|
||||||
|
print(f'>> Mask image upload requested "{name}"')
|
||||||
|
uuid = uuid4().hex
|
||||||
|
split = os.path.splitext(name)
|
||||||
|
name = f'{split[0]}.{uuid}{split[1]}'
|
||||||
|
file_path = os.path.join(mask_path, name)
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
newFile = open(file_path, "wb")
|
||||||
|
newFile.write(bytes)
|
||||||
|
return make_response("OK", data=file_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
END SOCKET.IO LISTENERS
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
ADDITIONAL FUNCTIONS
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def write_log_message(message, log_path=log_path):
|
||||||
|
"""Logs the filename and parameters used to generate or process that image to log file"""
|
||||||
|
message = f'{message}\n'
|
||||||
|
with open(log_path, 'a', encoding='utf-8') as file:
|
||||||
|
file.writelines(message)
|
||||||
|
|
||||||
|
|
||||||
|
def make_response(status, message=None, data=None):
|
||||||
|
response = {'status': status}
|
||||||
|
if message is not None:
|
||||||
|
response['message'] = message
|
||||||
|
if data is not None:
|
||||||
|
response['data'] = data
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False):
|
||||||
|
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed'
|
||||||
|
|
||||||
|
pngwriter = PngWriter(output_dir)
|
||||||
|
prefix = pngwriter.unique_prefix()
|
||||||
|
|
||||||
|
filename = f'{prefix}.{seed}'
|
||||||
|
|
||||||
|
if step_index:
|
||||||
|
filename += f'.{step_index}'
|
||||||
|
if postprocessing:
|
||||||
|
filename += f'.postprocessed'
|
||||||
|
|
||||||
|
filename += '.png'
|
||||||
|
|
||||||
|
command = parameters_to_command(parameters)
|
||||||
|
|
||||||
|
path = pngwriter.save_image_and_prompt_to_png(image, command, metadata=parameters, name=filename)
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
|
||||||
|
canceled.clear()
|
||||||
|
|
||||||
|
step_index = 1
|
||||||
|
|
||||||
|
def image_progress(sample, step):
|
||||||
|
if canceled.is_set():
|
||||||
|
raise CanceledException
|
||||||
|
nonlocal step_index
|
||||||
|
nonlocal generation_parameters
|
||||||
|
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1:
|
||||||
|
image = model.sample_to_image(sample)
|
||||||
|
path = save_image(image, generation_parameters, intermediate_path, step_index)
|
||||||
|
|
||||||
|
step_index += 1
|
||||||
|
socketio.emit('intermediateResult', {
|
||||||
|
'url': os.path.relpath(path), 'metadata': generation_parameters})
|
||||||
|
socketio.emit('progress', {'step': step + 1})
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
|
def image_done(image, seed):
|
||||||
|
nonlocal generation_parameters
|
||||||
|
nonlocal esrgan_parameters
|
||||||
|
nonlocal gfpgan_parameters
|
||||||
|
|
||||||
|
all_parameters = generation_parameters
|
||||||
|
postprocessing = False
|
||||||
|
|
||||||
|
if esrgan_parameters:
|
||||||
|
image = real_esrgan_upscale(
|
||||||
|
image=image,
|
||||||
|
strength=esrgan_parameters['strength'],
|
||||||
|
upsampler_scale=esrgan_parameters['level'],
|
||||||
|
seed=seed
|
||||||
|
)
|
||||||
|
postprocessing = True
|
||||||
|
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']]
|
||||||
|
|
||||||
|
if gfpgan_parameters:
|
||||||
|
image = run_gfpgan(
|
||||||
|
image=image,
|
||||||
|
strength=gfpgan_parameters['strength'],
|
||||||
|
seed=seed,
|
||||||
|
upsampler_scale=1,
|
||||||
|
)
|
||||||
|
postprocessing = True
|
||||||
|
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength']
|
||||||
|
|
||||||
|
all_parameters['seed'] = seed
|
||||||
|
|
||||||
|
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing)
|
||||||
|
command = parameters_to_command(all_parameters)
|
||||||
|
|
||||||
|
print(f'Image generated: "{path}"')
|
||||||
|
write_log_message(f'[Generated] "{path}": {command}')
|
||||||
|
|
||||||
|
socketio.emit(
|
||||||
|
'result', {'url': os.path.relpath(path), 'type': 'generation', 'metadata': all_parameters})
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model.prompt2image(
|
||||||
|
**generation_parameters,
|
||||||
|
step_callback=image_progress,
|
||||||
|
image_callback=image_done
|
||||||
|
)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise
|
||||||
|
except CanceledException:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
socketio.emit('error', (str(e)))
|
||||||
|
print("\n")
|
||||||
|
traceback.print_exc()
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
END ADDITIONAL FUNCTIONS
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print(f'Starting server at http://{host}:{port}')
|
||||||
|
socketio.run(app, host=host, port=port)
|
19
docs/index.html
Normal file
19
docs/index.html
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
<!-- HTML for static distribution bundle build -->
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Swagger UI</title>
|
||||||
|
<link rel="stylesheet" type="text/css" href="swagger-ui/swagger-ui.css" />
|
||||||
|
<link rel="stylesheet" type="text/css" href="swagger-ui/index.css" />
|
||||||
|
<link rel="icon" type="image/png" href="swagger-ui/favicon-32x32.png" sizes="32x32" />
|
||||||
|
<link rel="icon" type="image/png" href="swagger-ui/favicon-16x16.png" sizes="16x16" />
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<div id="swagger-ui"></div>
|
||||||
|
<script src="swagger-ui/swagger-ui-bundle.js" charset="UTF-8"> </script>
|
||||||
|
<script src="swagger-ui/swagger-ui-standalone-preset.js" charset="UTF-8"> </script>
|
||||||
|
<script src="swagger-ui/swagger-initializer.js" charset="UTF-8"> </script>
|
||||||
|
</body>
|
||||||
|
</html>
|
@ -7,10 +7,7 @@ title: macOS
|
|||||||
- macOS 12.3 Monterey or later
|
- macOS 12.3 Monterey or later
|
||||||
- Python
|
- Python
|
||||||
- Patience
|
- Patience
|
||||||
- Apple Silicon\*
|
- Apple Silicon or Intel Mac
|
||||||
|
|
||||||
\*I haven't tested any of this on Intel Macs but I have read that one person got
|
|
||||||
it to work, so Apple Silicon might not be requried.
|
|
||||||
|
|
||||||
Things have moved really fast and so these instructions change often and are
|
Things have moved really fast and so these instructions change often and are
|
||||||
often out-of-date. One of the problems is that there are so many different ways
|
often out-of-date. One of the problems is that there are so many different ways
|
||||||
@ -59,9 +56,13 @@ First get the weights checkpoint download started - it's big:
|
|||||||
# install python 3, git, cmake, protobuf:
|
# install python 3, git, cmake, protobuf:
|
||||||
brew install cmake protobuf rust
|
brew install cmake protobuf rust
|
||||||
|
|
||||||
# install miniconda (M1 arm64 version):
|
# install miniconda for M1 arm64:
|
||||||
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o Miniconda3-latest-MacOSX-arm64.sh
|
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o Miniconda3-latest-MacOSX-arm64.sh
|
||||||
/bin/bash Miniconda3-latest-MacOSX-arm64.sh
|
/bin/bash Miniconda3-latest-MacOSX-arm64.sh
|
||||||
|
|
||||||
|
# OR install miniconda for Intel:
|
||||||
|
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o Miniconda3-latest-MacOSX-x86_64.sh
|
||||||
|
/bin/bash Miniconda3-latest-MacOSX-x86_64.sh
|
||||||
|
|
||||||
|
|
||||||
# EITHER WAY,
|
# EITHER WAY,
|
||||||
@ -82,15 +83,22 @@ brew install cmake protobuf rust
|
|||||||
|
|
||||||
ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" models/ldm/stable-diffusion-v1/model.ckpt
|
ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
|
|
||||||
# install packages
|
# install packages for arm64
|
||||||
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
|
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
|
||||||
conda activate ldm
|
conda activate ldm
|
||||||
|
|
||||||
|
# OR install packages for x86_64
|
||||||
|
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-x86_64 conda env create -f environment-mac.yaml
|
||||||
|
conda activate ldm
|
||||||
|
|
||||||
# only need to do this once
|
# only need to do this once
|
||||||
python scripts/preload_models.py
|
python scripts/preload_models.py
|
||||||
|
|
||||||
# run SD!
|
# run SD!
|
||||||
python scripts/dream.py --full_precision # half-precision requires autocast and won't work
|
python scripts/dream.py --full_precision # half-precision requires autocast and won't work
|
||||||
|
|
||||||
|
# or run the web interface!
|
||||||
|
python scripts/dream.py --web
|
||||||
```
|
```
|
||||||
|
|
||||||
The original scripts should work as well.
|
The original scripts should work as well.
|
||||||
@ -181,7 +189,12 @@ There are several causes of these errors.
|
|||||||
- Third, if it says you're missing taming you need to rebuild your virtual
|
- Third, if it says you're missing taming you need to rebuild your virtual
|
||||||
environment.
|
environment.
|
||||||
|
|
||||||
`conda env remove -n ldm conda env create -f environment-mac.yaml`
|
````bash
|
||||||
|
conda deactivate
|
||||||
|
|
||||||
|
conda env remove -n ldm
|
||||||
|
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
|
||||||
|
```
|
||||||
|
|
||||||
Fourth, If you have activated the ldm virtual environment and tried rebuilding
|
Fourth, If you have activated the ldm virtual environment and tried rebuilding
|
||||||
it, maybe the problem could be that I have something installed that you don't
|
it, maybe the problem could be that I have something installed that you don't
|
||||||
|
73
docs/openapi3_0.yaml
Normal file
73
docs/openapi3_0.yaml
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
openapi: 3.0.3
|
||||||
|
info:
|
||||||
|
title: Stable Diffusion
|
||||||
|
description: |-
|
||||||
|
TODO: Description Here
|
||||||
|
|
||||||
|
Some useful links:
|
||||||
|
- [Stable Diffusion Dream Server](https://github.com/lstein/stable-diffusion)
|
||||||
|
|
||||||
|
license:
|
||||||
|
name: MIT License
|
||||||
|
url: https://github.com/lstein/stable-diffusion/blob/main/LICENSE
|
||||||
|
version: 1.0.0
|
||||||
|
servers:
|
||||||
|
- url: http://localhost:9090/api
|
||||||
|
tags:
|
||||||
|
- name: images
|
||||||
|
description: Retrieve and manage generated images
|
||||||
|
paths:
|
||||||
|
/images/{imageId}:
|
||||||
|
get:
|
||||||
|
tags:
|
||||||
|
- images
|
||||||
|
summary: Get image by ID
|
||||||
|
description: Returns a single image
|
||||||
|
operationId: getImageById
|
||||||
|
parameters:
|
||||||
|
- name: imageId
|
||||||
|
in: path
|
||||||
|
description: ID of image to return
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: successful operation
|
||||||
|
content:
|
||||||
|
image/png:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: binary
|
||||||
|
'404':
|
||||||
|
description: Image not found
|
||||||
|
/intermediates/{intermediateId}/{step}:
|
||||||
|
get:
|
||||||
|
tags:
|
||||||
|
- images
|
||||||
|
summary: Get intermediate image by ID
|
||||||
|
description: Returns a single intermediate image
|
||||||
|
operationId: getIntermediateById
|
||||||
|
parameters:
|
||||||
|
- name: intermediateId
|
||||||
|
in: path
|
||||||
|
description: ID of intermediate to return
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: step
|
||||||
|
in: path
|
||||||
|
description: The generation step of the intermediate
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: successful operation
|
||||||
|
content:
|
||||||
|
image/png:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: binary
|
||||||
|
'404':
|
||||||
|
description: Intermediate not found
|
@ -51,6 +51,7 @@ We thank them for all of their time and hard work.
|
|||||||
- [Any Winter](https://github.com/any-winter-4079)
|
- [Any Winter](https://github.com/any-winter-4079)
|
||||||
- [Doggettx](https://github.com/doggettx)
|
- [Doggettx](https://github.com/doggettx)
|
||||||
- [Matthias Wild](https://github.com/mauwii)
|
- [Matthias Wild](https://github.com/mauwii)
|
||||||
|
- [Kyle Schouviller](https://github.com/kyle0654)
|
||||||
|
|
||||||
## __Original CompVis Authors:__
|
## __Original CompVis Authors:__
|
||||||
|
|
||||||
|
BIN
docs/swagger-ui/favicon-16x16.png
Normal file
BIN
docs/swagger-ui/favicon-16x16.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 665 B |
BIN
docs/swagger-ui/favicon-32x32.png
Normal file
BIN
docs/swagger-ui/favicon-32x32.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 628 B |
16
docs/swagger-ui/index.css
Normal file
16
docs/swagger-ui/index.css
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
html {
|
||||||
|
box-sizing: border-box;
|
||||||
|
overflow: -moz-scrollbars-vertical;
|
||||||
|
overflow-y: scroll;
|
||||||
|
}
|
||||||
|
|
||||||
|
*,
|
||||||
|
*:before,
|
||||||
|
*:after {
|
||||||
|
box-sizing: inherit;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
background: #fafafa;
|
||||||
|
}
|
79
docs/swagger-ui/oauth2-redirect.html
Normal file
79
docs/swagger-ui/oauth2-redirect.html
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="en-US">
|
||||||
|
<head>
|
||||||
|
<title>Swagger UI: OAuth2 Redirect</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<script>
|
||||||
|
'use strict';
|
||||||
|
function run () {
|
||||||
|
var oauth2 = window.opener.swaggerUIRedirectOauth2;
|
||||||
|
var sentState = oauth2.state;
|
||||||
|
var redirectUrl = oauth2.redirectUrl;
|
||||||
|
var isValid, qp, arr;
|
||||||
|
|
||||||
|
if (/code|token|error/.test(window.location.hash)) {
|
||||||
|
qp = window.location.hash.substring(1).replace('?', '&');
|
||||||
|
} else {
|
||||||
|
qp = location.search.substring(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
arr = qp.split("&");
|
||||||
|
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
|
||||||
|
qp = qp ? JSON.parse('{' + arr.join() + '}',
|
||||||
|
function (key, value) {
|
||||||
|
return key === "" ? value : decodeURIComponent(value);
|
||||||
|
}
|
||||||
|
) : {};
|
||||||
|
|
||||||
|
isValid = qp.state === sentState;
|
||||||
|
|
||||||
|
if ((
|
||||||
|
oauth2.auth.schema.get("flow") === "accessCode" ||
|
||||||
|
oauth2.auth.schema.get("flow") === "authorizationCode" ||
|
||||||
|
oauth2.auth.schema.get("flow") === "authorization_code"
|
||||||
|
) && !oauth2.auth.code) {
|
||||||
|
if (!isValid) {
|
||||||
|
oauth2.errCb({
|
||||||
|
authId: oauth2.auth.name,
|
||||||
|
source: "auth",
|
||||||
|
level: "warning",
|
||||||
|
message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server."
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (qp.code) {
|
||||||
|
delete oauth2.state;
|
||||||
|
oauth2.auth.code = qp.code;
|
||||||
|
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
|
||||||
|
} else {
|
||||||
|
let oauthErrorMsg;
|
||||||
|
if (qp.error) {
|
||||||
|
oauthErrorMsg = "["+qp.error+"]: " +
|
||||||
|
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
|
||||||
|
(qp.error_uri ? "More info: "+qp.error_uri : "");
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth2.errCb({
|
||||||
|
authId: oauth2.auth.name,
|
||||||
|
source: "auth",
|
||||||
|
level: "error",
|
||||||
|
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server."
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
|
||||||
|
}
|
||||||
|
window.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (document.readyState !== 'loading') {
|
||||||
|
run();
|
||||||
|
} else {
|
||||||
|
document.addEventListener('DOMContentLoaded', function () {
|
||||||
|
run();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
20
docs/swagger-ui/swagger-initializer.js
Normal file
20
docs/swagger-ui/swagger-initializer.js
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
window.onload = function() {
|
||||||
|
//<editor-fold desc="Changeable Configuration Block">
|
||||||
|
|
||||||
|
// the following lines will be replaced by docker/configurator, when it runs in a docker-container
|
||||||
|
window.ui = SwaggerUIBundle({
|
||||||
|
url: "openapi3_0.yaml",
|
||||||
|
dom_id: '#swagger-ui',
|
||||||
|
deepLinking: true,
|
||||||
|
presets: [
|
||||||
|
SwaggerUIBundle.presets.apis,
|
||||||
|
SwaggerUIStandalonePreset
|
||||||
|
],
|
||||||
|
plugins: [
|
||||||
|
SwaggerUIBundle.plugins.DownloadUrl
|
||||||
|
],
|
||||||
|
layout: "StandaloneLayout"
|
||||||
|
});
|
||||||
|
|
||||||
|
//</editor-fold>
|
||||||
|
};
|
3
docs/swagger-ui/swagger-ui-bundle.js
Normal file
3
docs/swagger-ui/swagger-ui-bundle.js
Normal file
File diff suppressed because one or more lines are too long
1
docs/swagger-ui/swagger-ui-bundle.js.map
Normal file
1
docs/swagger-ui/swagger-ui-bundle.js.map
Normal file
File diff suppressed because one or more lines are too long
3
docs/swagger-ui/swagger-ui-es-bundle-core.js
Normal file
3
docs/swagger-ui/swagger-ui-es-bundle-core.js
Normal file
File diff suppressed because one or more lines are too long
1
docs/swagger-ui/swagger-ui-es-bundle-core.js.map
Normal file
1
docs/swagger-ui/swagger-ui-es-bundle-core.js.map
Normal file
File diff suppressed because one or more lines are too long
3
docs/swagger-ui/swagger-ui-es-bundle.js
Normal file
3
docs/swagger-ui/swagger-ui-es-bundle.js
Normal file
File diff suppressed because one or more lines are too long
1
docs/swagger-ui/swagger-ui-es-bundle.js.map
Normal file
1
docs/swagger-ui/swagger-ui-es-bundle.js.map
Normal file
File diff suppressed because one or more lines are too long
3
docs/swagger-ui/swagger-ui-standalone-preset.js
Normal file
3
docs/swagger-ui/swagger-ui-standalone-preset.js
Normal file
File diff suppressed because one or more lines are too long
1
docs/swagger-ui/swagger-ui-standalone-preset.js.map
Normal file
1
docs/swagger-ui/swagger-ui-standalone-preset.js.map
Normal file
File diff suppressed because one or more lines are too long
4
docs/swagger-ui/swagger-ui.css
Normal file
4
docs/swagger-ui/swagger-ui.css
Normal file
File diff suppressed because one or more lines are too long
1
docs/swagger-ui/swagger-ui.css.map
Normal file
1
docs/swagger-ui/swagger-ui.css.map
Normal file
File diff suppressed because one or more lines are too long
2
docs/swagger-ui/swagger-ui.js
Normal file
2
docs/swagger-ui/swagger-ui.js
Normal file
File diff suppressed because one or more lines are too long
1
docs/swagger-ui/swagger-ui.js.map
Normal file
1
docs/swagger-ui/swagger-ui.js.map
Normal file
File diff suppressed because one or more lines are too long
@ -40,6 +40,11 @@ dependencies:
|
|||||||
- tensorboard==2.9.0
|
- tensorboard==2.9.0
|
||||||
- torchmetrics==0.9.3
|
- torchmetrics==0.9.3
|
||||||
- pip:
|
- pip:
|
||||||
|
- flask==2.1.3
|
||||||
|
- flask_socketio==5.3.0
|
||||||
|
- flask_cors==3.0.10
|
||||||
|
- dependency_injector==4.40.0
|
||||||
|
- eventlet
|
||||||
- opencv-python==4.6.0
|
- opencv-python==4.6.0
|
||||||
- protobuf==3.20.1
|
- protobuf==3.20.1
|
||||||
- realesrgan==0.2.5.0
|
- realesrgan==0.2.5.0
|
||||||
|
@ -3,7 +3,7 @@ channels:
|
|||||||
- pytorch
|
- pytorch
|
||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
- python=3.8.5
|
- python>=3.9
|
||||||
- pip=20.3
|
- pip=20.3
|
||||||
- cudatoolkit=11.3
|
- cudatoolkit=11.3
|
||||||
- pytorch=1.11.0
|
- pytorch=1.11.0
|
||||||
@ -20,11 +20,16 @@ dependencies:
|
|||||||
- realesrgan==0.2.5.0
|
- realesrgan==0.2.5.0
|
||||||
- test-tube>=0.7.5
|
- test-tube>=0.7.5
|
||||||
- streamlit==1.12.0
|
- streamlit==1.12.0
|
||||||
- pillow==9.2.0
|
- pillow==6.2.0
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- torch-fidelity==0.3.0
|
- torch-fidelity==0.3.0
|
||||||
- transformers==4.19.2
|
- transformers==4.19.2
|
||||||
- torchmetrics==0.6.0
|
- torchmetrics==0.6.0
|
||||||
|
- flask==2.1.3
|
||||||
|
- flask_socketio==5.3.0
|
||||||
|
- flask_cors==3.0.10
|
||||||
|
- dependency_injector==4.40.0
|
||||||
|
- eventlet
|
||||||
- kornia==0.6.0
|
- kornia==0.6.0
|
||||||
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||||
|
6
frontend/.eslintrc.cjs
Normal file
6
frontend/.eslintrc.cjs
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
module.exports = {
|
||||||
|
extends: ['eslint:recommended', 'plugin:@typescript-eslint/recommended', 'plugin:react-hooks/recommended'],
|
||||||
|
parser: '@typescript-eslint/parser',
|
||||||
|
plugins: ['@typescript-eslint', 'eslint-plugin-react-hooks'],
|
||||||
|
root: true,
|
||||||
|
};
|
25
frontend/.gitignore
vendored
Normal file
25
frontend/.gitignore
vendored
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
node_modules
|
||||||
|
# We want to distribute the repo
|
||||||
|
# dist
|
||||||
|
dist-ssr
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
85
frontend/README.md
Normal file
85
frontend/README.md
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
# Stable Diffusion Web UI
|
||||||
|
|
||||||
|
Demo at https://peaceful-otter-7a427f.netlify.app/ (not connected to back end)
|
||||||
|
|
||||||
|
much of this readme is just notes for myself during dev work
|
||||||
|
|
||||||
|
numpy rand: 0 to 4294967295
|
||||||
|
|
||||||
|
## Test and Build
|
||||||
|
|
||||||
|
from `frontend/`:
|
||||||
|
|
||||||
|
- `yarn dev` runs `tsc-watch`, which runs `vite build` on successful `tsc` transpilation
|
||||||
|
|
||||||
|
from `.`:
|
||||||
|
|
||||||
|
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
`backend/server.py` serves the UI and provides a [socket.io](https://github.com/socketio/socket.io) API via [flask-socketio](https://github.com/miguelgrinberg/flask-socketio).
|
||||||
|
|
||||||
|
### Server Listeners
|
||||||
|
|
||||||
|
The server listens for these socket.io events:
|
||||||
|
|
||||||
|
`cancel`
|
||||||
|
|
||||||
|
- Cancels in-progress image generation
|
||||||
|
- Returns ack only
|
||||||
|
|
||||||
|
`generateImage`
|
||||||
|
|
||||||
|
- Accepts object of image parameters
|
||||||
|
- Generates an image
|
||||||
|
- Returns ack only (image generation function sends progress and result via separate events)
|
||||||
|
|
||||||
|
`deleteImage`
|
||||||
|
|
||||||
|
- Accepts file path to image
|
||||||
|
- Deletes image
|
||||||
|
- Returns ack only
|
||||||
|
|
||||||
|
`deleteAllImages` WIP
|
||||||
|
|
||||||
|
- Deletes all images in `outputs/`
|
||||||
|
- Returns ack only
|
||||||
|
|
||||||
|
`requestAllImages`
|
||||||
|
|
||||||
|
- Returns array of all images in `outputs/`
|
||||||
|
|
||||||
|
`requestCapabilities` WIP
|
||||||
|
|
||||||
|
- Returns capabilities of server (torch device, GFPGAN and ESRGAN availability, ???)
|
||||||
|
|
||||||
|
`sendImage` WIP
|
||||||
|
|
||||||
|
- Accepts a File and attributes
|
||||||
|
- Saves image
|
||||||
|
- Used to save init images which are not generated images
|
||||||
|
|
||||||
|
### Server Emitters
|
||||||
|
|
||||||
|
`progress`
|
||||||
|
|
||||||
|
- Emitted during each step in generation
|
||||||
|
- Sends a number from 0 to 1 representing percentage of steps completed
|
||||||
|
|
||||||
|
`result` WIP
|
||||||
|
|
||||||
|
- Emitted when an image generation has completed
|
||||||
|
- Sends a object:
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
url: relative_file_path,
|
||||||
|
metadata: image_metadata_object
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
|
||||||
|
- Search repo for "TODO"
|
||||||
|
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on `framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on this. Need to check in on this issue periodically.
|
1
frontend/dist/assets/index.447eb2a9.css
vendored
Normal file
1
frontend/dist/assets/index.447eb2a9.css
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
.checkerboard{background-position:0px 0px,10px 10px;background-size:20px 20px;background-image:linear-gradient(45deg,#eee 25%,transparent 25%,transparent 75%,#eee 75%,#eee 100%),linear-gradient(45deg,#eee 25%,white 25%,white 75%,#eee 75%,#eee 100%)}
|
695
frontend/dist/assets/index.cc5cde43.js
vendored
Normal file
695
frontend/dist/assets/index.cc5cde43.js
vendored
Normal file
File diff suppressed because one or more lines are too long
14
frontend/dist/index.html
vendored
Normal file
14
frontend/dist/index.html
vendored
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>Stable Diffusion Dream Server</title>
|
||||||
|
<script type="module" crossorigin src="/assets/index.cc5cde43.js"></script>
|
||||||
|
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
|
||||||
|
</body>
|
||||||
|
</html>
|
1
frontend/index.d.ts
vendored
Normal file
1
frontend/index.d.ts
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
declare module 'redux-socket.io-middleware';
|
12
frontend/index.html
Normal file
12
frontend/index.html
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>Stable Diffusion Dream Server</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.tsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
46
frontend/package.json
Normal file
46
frontend/package.json
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
{
|
||||||
|
"name": "sdui",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "tsc-watch --onSuccess 'yarn run vite build -m development'",
|
||||||
|
"hmr": "vite dev",
|
||||||
|
"build": "tsc && vite build",
|
||||||
|
"build-dev": "tsc && vite build -m development",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"@chakra-ui/react": "^2.3.1",
|
||||||
|
"@emotion/react": "^11.10.4",
|
||||||
|
"@emotion/styled": "^11.10.4",
|
||||||
|
"@reduxjs/toolkit": "^1.8.5",
|
||||||
|
"@types/uuid": "^8.3.4",
|
||||||
|
"dateformat": "^5.0.3",
|
||||||
|
"framer-motion": "^7.2.1",
|
||||||
|
"lodash": "^4.17.21",
|
||||||
|
"react": "^18.2.0",
|
||||||
|
"react-dom": "^18.2.0",
|
||||||
|
"react-dropzone": "^14.2.2",
|
||||||
|
"react-icons": "^4.4.0",
|
||||||
|
"react-redux": "^8.0.2",
|
||||||
|
"redux-persist": "^6.0.0",
|
||||||
|
"socket.io-client": "^4.5.2",
|
||||||
|
"uuid": "^9.0.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/dateformat": "^5.0.0",
|
||||||
|
"@types/react": "^18.0.17",
|
||||||
|
"@types/react-dom": "^18.0.6",
|
||||||
|
"@typescript-eslint/eslint-plugin": "^5.36.2",
|
||||||
|
"@typescript-eslint/parser": "^5.36.2",
|
||||||
|
"@vitejs/plugin-react": "^2.0.1",
|
||||||
|
"eslint": "^8.23.0",
|
||||||
|
"eslint-plugin-prettier": "^4.2.1",
|
||||||
|
"eslint-plugin-react-hooks": "^4.6.0",
|
||||||
|
"tsc-watch": "^5.0.3",
|
||||||
|
"typescript": "^4.6.4",
|
||||||
|
"vite": "^3.0.7",
|
||||||
|
"vite-plugin-eslint": "^1.8.1"
|
||||||
|
}
|
||||||
|
}
|
60
frontend/src/App.tsx
Normal file
60
frontend/src/App.tsx
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import { Grid, GridItem } from '@chakra-ui/react';
|
||||||
|
import CurrentImage from './features/gallery/CurrentImage';
|
||||||
|
import LogViewer from './features/system/LogViewer';
|
||||||
|
import PromptInput from './features/sd/PromptInput';
|
||||||
|
import ProgressBar from './features/header/ProgressBar';
|
||||||
|
import { useEffect } from 'react';
|
||||||
|
import { useAppDispatch } from './app/hooks';
|
||||||
|
import { requestAllImages } from './app/socketio';
|
||||||
|
import ProcessButtons from './features/sd/ProcessButtons';
|
||||||
|
import ImageRoll from './features/gallery/ImageRoll';
|
||||||
|
import SiteHeader from './features/header/SiteHeader';
|
||||||
|
import OptionsAccordion from './features/sd/OptionsAccordion';
|
||||||
|
|
||||||
|
const App = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
useEffect(() => {
|
||||||
|
dispatch(requestAllImages());
|
||||||
|
}, [dispatch]);
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Grid
|
||||||
|
width='100vw'
|
||||||
|
height='100vh'
|
||||||
|
templateAreas={`
|
||||||
|
"header header header header"
|
||||||
|
"progressBar progressBar progressBar progressBar"
|
||||||
|
"menu prompt processButtons imageRoll"
|
||||||
|
"menu currentImage currentImage imageRoll"`}
|
||||||
|
gridTemplateRows={'36px 10px 100px auto'}
|
||||||
|
gridTemplateColumns={'350px auto 100px 388px'}
|
||||||
|
gap={2}
|
||||||
|
>
|
||||||
|
<GridItem area={'header'} pt={1}>
|
||||||
|
<SiteHeader />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem area={'progressBar'}>
|
||||||
|
<ProgressBar />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem pl='2' area={'menu'} overflowY='scroll'>
|
||||||
|
<OptionsAccordion />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem area={'prompt'}>
|
||||||
|
<PromptInput />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem area={'processButtons'}>
|
||||||
|
<ProcessButtons />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem area={'currentImage'}>
|
||||||
|
<CurrentImage />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem pr='2' area={'imageRoll'} overflowY='scroll'>
|
||||||
|
<ImageRoll />
|
||||||
|
</GridItem>
|
||||||
|
</Grid>
|
||||||
|
<LogViewer />
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default App;
|
22
frontend/src/Loading.tsx
Normal file
22
frontend/src/Loading.tsx
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import { Flex, Spinner } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
const Loading = () => {
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
width={'100vw'}
|
||||||
|
height={'100vh'}
|
||||||
|
alignItems='center'
|
||||||
|
justifyContent='center'
|
||||||
|
>
|
||||||
|
<Spinner
|
||||||
|
thickness='2px'
|
||||||
|
speed='1s'
|
||||||
|
emptyColor='gray.200'
|
||||||
|
color='gray.400'
|
||||||
|
size='xl'
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default Loading;
|
55
frontend/src/app/constants.ts
Normal file
55
frontend/src/app/constants.ts
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
// TODO: use Enums?
|
||||||
|
|
||||||
|
// Valid samplers
|
||||||
|
export const SAMPLERS: Array<string> = [
|
||||||
|
'ddim',
|
||||||
|
'plms',
|
||||||
|
'k_lms',
|
||||||
|
'k_dpm_2',
|
||||||
|
'k_dpm_2_a',
|
||||||
|
'k_euler',
|
||||||
|
'k_euler_a',
|
||||||
|
'k_heun',
|
||||||
|
];
|
||||||
|
|
||||||
|
// Valid image widths
|
||||||
|
export const WIDTHS: Array<number> = [
|
||||||
|
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||||
|
1024,
|
||||||
|
];
|
||||||
|
|
||||||
|
// Valid image heights
|
||||||
|
export const HEIGHTS: Array<number> = [
|
||||||
|
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||||
|
1024,
|
||||||
|
];
|
||||||
|
|
||||||
|
// Valid upscaling levels
|
||||||
|
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
|
||||||
|
{ key: '2x', value: 2 },
|
||||||
|
{ key: '4x', value: 4 },
|
||||||
|
];
|
||||||
|
|
||||||
|
// Internal to human-readable parameters
|
||||||
|
export const PARAMETERS: { [key: string]: string } = {
|
||||||
|
prompt: 'Prompt',
|
||||||
|
iterations: 'Iterations',
|
||||||
|
steps: 'Steps',
|
||||||
|
cfgScale: 'CFG Scale',
|
||||||
|
height: 'Height',
|
||||||
|
width: 'Width',
|
||||||
|
sampler: 'Sampler',
|
||||||
|
seed: 'Seed',
|
||||||
|
img2imgStrength: 'img2img Strength',
|
||||||
|
gfpganStrength: 'GFPGAN Strength',
|
||||||
|
upscalingLevel: 'Upscaling Level',
|
||||||
|
upscalingStrength: 'Upscaling Strength',
|
||||||
|
initialImagePath: 'Initial Image',
|
||||||
|
maskPath: 'Initial Image Mask',
|
||||||
|
shouldFitToWidthHeight: 'Fit Initial Image',
|
||||||
|
seamless: 'Seamless Tiling',
|
||||||
|
};
|
||||||
|
|
||||||
|
export const NUMPY_RAND_MIN = 0;
|
||||||
|
|
||||||
|
export const NUMPY_RAND_MAX = 4294967295;
|
7
frontend/src/app/hooks.ts
Normal file
7
frontend/src/app/hooks.ts
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
import { useDispatch, useSelector } from 'react-redux';
|
||||||
|
import type { TypedUseSelectorHook } from 'react-redux';
|
||||||
|
import type { RootState, AppDispatch } from './store';
|
||||||
|
|
||||||
|
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||||
|
export const useAppDispatch: () => AppDispatch = useDispatch;
|
||||||
|
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
182
frontend/src/app/parameterTranslation.ts
Normal file
182
frontend/src/app/parameterTranslation.ts
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
import { SDState } from '../features/sd/sdSlice';
|
||||||
|
import randomInt from '../features/sd/util/randomInt';
|
||||||
|
import {
|
||||||
|
seedWeightsToString,
|
||||||
|
stringToSeedWeights,
|
||||||
|
} from '../features/sd/util/seedWeightPairs';
|
||||||
|
import { SystemState } from '../features/system/systemSlice';
|
||||||
|
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from './constants';
|
||||||
|
|
||||||
|
/*
|
||||||
|
These functions translate frontend state into parameters
|
||||||
|
suitable for consumption by the backend, and vice-versa.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export const frontendToBackendParameters = (
|
||||||
|
sdState: SDState,
|
||||||
|
systemState: SystemState
|
||||||
|
): { [key: string]: any } => {
|
||||||
|
const {
|
||||||
|
prompt,
|
||||||
|
iterations,
|
||||||
|
steps,
|
||||||
|
cfgScale,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
sampler,
|
||||||
|
seed,
|
||||||
|
seamless,
|
||||||
|
shouldUseInitImage,
|
||||||
|
img2imgStrength,
|
||||||
|
initialImagePath,
|
||||||
|
maskPath,
|
||||||
|
shouldFitToWidthHeight,
|
||||||
|
shouldGenerateVariations,
|
||||||
|
variantAmount,
|
||||||
|
seedWeights,
|
||||||
|
shouldRunESRGAN,
|
||||||
|
upscalingLevel,
|
||||||
|
upscalingStrength,
|
||||||
|
shouldRunGFPGAN,
|
||||||
|
gfpganStrength,
|
||||||
|
shouldRandomizeSeed,
|
||||||
|
} = sdState;
|
||||||
|
|
||||||
|
const { shouldDisplayInProgress } = systemState;
|
||||||
|
|
||||||
|
const generationParameters: { [k: string]: any } = {
|
||||||
|
prompt,
|
||||||
|
iterations,
|
||||||
|
steps,
|
||||||
|
cfg_scale: cfgScale,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
sampler_name: sampler,
|
||||||
|
seed,
|
||||||
|
seamless,
|
||||||
|
progress_images: shouldDisplayInProgress,
|
||||||
|
};
|
||||||
|
|
||||||
|
generationParameters.seed = shouldRandomizeSeed
|
||||||
|
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
|
||||||
|
: seed;
|
||||||
|
|
||||||
|
if (shouldUseInitImage) {
|
||||||
|
generationParameters.init_img = initialImagePath;
|
||||||
|
generationParameters.strength = img2imgStrength;
|
||||||
|
generationParameters.fit = shouldFitToWidthHeight;
|
||||||
|
if (maskPath) {
|
||||||
|
generationParameters.init_mask = maskPath;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldGenerateVariations) {
|
||||||
|
generationParameters.variation_amount = variantAmount;
|
||||||
|
if (seedWeights) {
|
||||||
|
generationParameters.with_variations =
|
||||||
|
stringToSeedWeights(seedWeights);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
generationParameters.variation_amount = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let esrganParameters: false | { [k: string]: any } = false;
|
||||||
|
let gfpganParameters: false | { [k: string]: any } = false;
|
||||||
|
|
||||||
|
if (shouldRunESRGAN) {
|
||||||
|
esrganParameters = {
|
||||||
|
level: upscalingLevel,
|
||||||
|
strength: upscalingStrength,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldRunGFPGAN) {
|
||||||
|
gfpganParameters = {
|
||||||
|
strength: gfpganStrength,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
generationParameters,
|
||||||
|
esrganParameters,
|
||||||
|
gfpganParameters,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export const backendToFrontendParameters = (parameters: {
|
||||||
|
[key: string]: any;
|
||||||
|
}) => {
|
||||||
|
const {
|
||||||
|
prompt,
|
||||||
|
iterations,
|
||||||
|
steps,
|
||||||
|
cfg_scale,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
sampler_name,
|
||||||
|
seed,
|
||||||
|
seamless,
|
||||||
|
progress_images,
|
||||||
|
variation_amount,
|
||||||
|
with_variations,
|
||||||
|
gfpgan_strength,
|
||||||
|
upscale,
|
||||||
|
init_img,
|
||||||
|
init_mask,
|
||||||
|
strength,
|
||||||
|
} = parameters;
|
||||||
|
|
||||||
|
const sd: { [key: string]: any } = {
|
||||||
|
shouldDisplayInProgress: progress_images,
|
||||||
|
// init
|
||||||
|
shouldGenerateVariations: false,
|
||||||
|
shouldRunESRGAN: false,
|
||||||
|
shouldRunGFPGAN: false,
|
||||||
|
initialImagePath: '',
|
||||||
|
maskPath: '',
|
||||||
|
};
|
||||||
|
|
||||||
|
if (variation_amount > 0) {
|
||||||
|
sd.shouldGenerateVariations = true;
|
||||||
|
sd.variantAmount = variation_amount;
|
||||||
|
if (with_variations) {
|
||||||
|
sd.seedWeights = seedWeightsToString(with_variations);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gfpgan_strength > 0) {
|
||||||
|
sd.shouldRunGFPGAN = true;
|
||||||
|
sd.gfpganStrength = gfpgan_strength;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (upscale) {
|
||||||
|
sd.shouldRunESRGAN = true;
|
||||||
|
sd.upscalingLevel = upscale[0];
|
||||||
|
sd.upscalingStrength = upscale[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (init_img) {
|
||||||
|
sd.shouldUseInitImage = true
|
||||||
|
sd.initialImagePath = init_img;
|
||||||
|
sd.strength = strength;
|
||||||
|
if (init_mask) {
|
||||||
|
sd.maskPath = init_mask;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we had a prompt, add all the metadata, but if we don't have a prompt,
|
||||||
|
// we must have only done ESRGAN or GFPGAN so do not add that metadata
|
||||||
|
if (prompt) {
|
||||||
|
sd.prompt = prompt;
|
||||||
|
sd.iterations = iterations;
|
||||||
|
sd.steps = steps;
|
||||||
|
sd.cfgScale = cfg_scale;
|
||||||
|
sd.height = height;
|
||||||
|
sd.width = width;
|
||||||
|
sd.sampler = sampler_name;
|
||||||
|
sd.seed = seed;
|
||||||
|
sd.seamless = seamless;
|
||||||
|
}
|
||||||
|
|
||||||
|
return sd;
|
||||||
|
};
|
393
frontend/src/app/socketio.ts
Normal file
393
frontend/src/app/socketio.ts
Normal file
@ -0,0 +1,393 @@
|
|||||||
|
import { createAction, Middleware } from '@reduxjs/toolkit';
|
||||||
|
import { io } from 'socket.io-client';
|
||||||
|
import {
|
||||||
|
addImage,
|
||||||
|
clearIntermediateImage,
|
||||||
|
removeImage,
|
||||||
|
SDImage,
|
||||||
|
SDMetadata,
|
||||||
|
setGalleryImages,
|
||||||
|
setIntermediateImage,
|
||||||
|
} from '../features/gallery/gallerySlice';
|
||||||
|
import {
|
||||||
|
addLogEntry,
|
||||||
|
setCurrentStep,
|
||||||
|
setIsConnected,
|
||||||
|
setIsProcessing,
|
||||||
|
} from '../features/system/systemSlice';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { setInitialImagePath, setMaskPath } from '../features/sd/sdSlice';
|
||||||
|
import {
|
||||||
|
backendToFrontendParameters,
|
||||||
|
frontendToBackendParameters,
|
||||||
|
} from './parameterTranslation';
|
||||||
|
|
||||||
|
export interface SocketIOResponse {
|
||||||
|
status: 'OK' | 'ERROR';
|
||||||
|
message?: string;
|
||||||
|
data?: any;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const socketioMiddleware = () => {
|
||||||
|
const { hostname, port } = new URL(window.location.href);
|
||||||
|
|
||||||
|
const socketio = io(`http://${hostname}:9090`);
|
||||||
|
|
||||||
|
let areListenersSet = false;
|
||||||
|
|
||||||
|
const middleware: Middleware = (store) => (next) => (action) => {
|
||||||
|
const { dispatch, getState } = store;
|
||||||
|
if (!areListenersSet) {
|
||||||
|
// CONNECT
|
||||||
|
socketio.on('connect', () => {
|
||||||
|
try {
|
||||||
|
dispatch(setIsConnected(true));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// DISCONNECT
|
||||||
|
socketio.on('disconnect', () => {
|
||||||
|
try {
|
||||||
|
dispatch(setIsConnected(false));
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
dispatch(addLogEntry(`Disconnected from server`));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// PROCESSING RESULT
|
||||||
|
socketio.on(
|
||||||
|
'result',
|
||||||
|
(data: {
|
||||||
|
url: string;
|
||||||
|
type: 'generation' | 'esrgan' | 'gfpgan';
|
||||||
|
uuid?: string;
|
||||||
|
metadata: { [key: string]: any };
|
||||||
|
}) => {
|
||||||
|
try {
|
||||||
|
const newUuid = uuidv4();
|
||||||
|
const { type, url, uuid, metadata } = data;
|
||||||
|
switch (type) {
|
||||||
|
case 'generation': {
|
||||||
|
const translatedMetadata =
|
||||||
|
backendToFrontendParameters(metadata);
|
||||||
|
dispatch(
|
||||||
|
addImage({
|
||||||
|
uuid: newUuid,
|
||||||
|
url,
|
||||||
|
metadata: translatedMetadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(`Image generated: ${url}`)
|
||||||
|
);
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 'esrgan': {
|
||||||
|
const originalImage =
|
||||||
|
getState().gallery.images.find(
|
||||||
|
(i: SDImage) => i.uuid === uuid
|
||||||
|
);
|
||||||
|
const newMetadata = {
|
||||||
|
...originalImage.metadata,
|
||||||
|
};
|
||||||
|
newMetadata.shouldRunESRGAN = true;
|
||||||
|
newMetadata.upscalingLevel =
|
||||||
|
metadata.upscale[0];
|
||||||
|
newMetadata.upscalingStrength =
|
||||||
|
metadata.upscale[1];
|
||||||
|
dispatch(
|
||||||
|
addImage({
|
||||||
|
uuid: newUuid,
|
||||||
|
url,
|
||||||
|
metadata: newMetadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(`ESRGAN upscaled: ${url}`)
|
||||||
|
);
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 'gfpgan': {
|
||||||
|
const originalImage =
|
||||||
|
getState().gallery.images.find(
|
||||||
|
(i: SDImage) => i.uuid === uuid
|
||||||
|
);
|
||||||
|
const newMetadata = {
|
||||||
|
...originalImage.metadata,
|
||||||
|
};
|
||||||
|
newMetadata.shouldRunGFPGAN = true;
|
||||||
|
newMetadata.gfpganStrength =
|
||||||
|
metadata.gfpgan_strength;
|
||||||
|
dispatch(
|
||||||
|
addImage({
|
||||||
|
uuid: newUuid,
|
||||||
|
url,
|
||||||
|
metadata: newMetadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(`GFPGAN fixed faces: ${url}`)
|
||||||
|
);
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// PROGRESS UPDATE
|
||||||
|
socketio.on('progress', (data: { step: number }) => {
|
||||||
|
try {
|
||||||
|
dispatch(setIsProcessing(true));
|
||||||
|
dispatch(setCurrentStep(data.step));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// INTERMEDIATE IMAGE
|
||||||
|
socketio.on(
|
||||||
|
'intermediateResult',
|
||||||
|
(data: { url: string; metadata: SDMetadata }) => {
|
||||||
|
try {
|
||||||
|
const uuid = uuidv4();
|
||||||
|
const { url, metadata } = data;
|
||||||
|
dispatch(
|
||||||
|
setIntermediateImage({
|
||||||
|
uuid,
|
||||||
|
url,
|
||||||
|
metadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(`Intermediate image generated: ${url}`)
|
||||||
|
);
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// ERROR FROM BACKEND
|
||||||
|
socketio.on('error', (message) => {
|
||||||
|
try {
|
||||||
|
dispatch(addLogEntry(`Server error: ${message}`));
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
dispatch(clearIntermediateImage());
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
areListenersSet = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// HANDLE ACTIONS
|
||||||
|
|
||||||
|
switch (action.type) {
|
||||||
|
// GENERATE IMAGE
|
||||||
|
case 'socketio/generateImage': {
|
||||||
|
dispatch(setIsProcessing(true));
|
||||||
|
dispatch(setCurrentStep(-1));
|
||||||
|
|
||||||
|
const {
|
||||||
|
generationParameters,
|
||||||
|
esrganParameters,
|
||||||
|
gfpganParameters,
|
||||||
|
} = frontendToBackendParameters(
|
||||||
|
getState().sd,
|
||||||
|
getState().system
|
||||||
|
);
|
||||||
|
|
||||||
|
socketio.emit(
|
||||||
|
'generateImage',
|
||||||
|
generationParameters,
|
||||||
|
esrganParameters,
|
||||||
|
gfpganParameters
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(
|
||||||
|
`Image generation requested: ${JSON.stringify({
|
||||||
|
...generationParameters,
|
||||||
|
...esrganParameters,
|
||||||
|
...gfpganParameters,
|
||||||
|
})}`
|
||||||
|
)
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// RUN ESRGAN (UPSCALING)
|
||||||
|
case 'socketio/runESRGAN': {
|
||||||
|
const imageToProcess = action.payload;
|
||||||
|
dispatch(setIsProcessing(true));
|
||||||
|
dispatch(setCurrentStep(-1));
|
||||||
|
const { upscalingLevel, upscalingStrength } = getState().sd;
|
||||||
|
const esrganParameters = {
|
||||||
|
upscale: [upscalingLevel, upscalingStrength],
|
||||||
|
};
|
||||||
|
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(
|
||||||
|
`ESRGAN upscale requested: ${JSON.stringify({
|
||||||
|
file: imageToProcess.url,
|
||||||
|
...esrganParameters,
|
||||||
|
})}`
|
||||||
|
)
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// RUN GFPGAN (FIX FACES)
|
||||||
|
case 'socketio/runGFPGAN': {
|
||||||
|
const imageToProcess = action.payload;
|
||||||
|
dispatch(setIsProcessing(true));
|
||||||
|
dispatch(setCurrentStep(-1));
|
||||||
|
const { gfpganStrength } = getState().sd;
|
||||||
|
|
||||||
|
const gfpganParameters = {
|
||||||
|
gfpgan_strength: gfpganStrength,
|
||||||
|
};
|
||||||
|
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(
|
||||||
|
`GFPGAN fix faces requested: ${JSON.stringify({
|
||||||
|
file: imageToProcess.url,
|
||||||
|
...gfpganParameters,
|
||||||
|
})}`
|
||||||
|
)
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// DELETE IMAGE
|
||||||
|
case 'socketio/deleteImage': {
|
||||||
|
const imageToDelete = action.payload;
|
||||||
|
const { url } = imageToDelete;
|
||||||
|
socketio.emit(
|
||||||
|
'deleteImage',
|
||||||
|
url,
|
||||||
|
(response: SocketIOResponse) => {
|
||||||
|
if (response.status === 'OK') {
|
||||||
|
dispatch(removeImage(imageToDelete));
|
||||||
|
dispatch(addLogEntry(`Image deleted: ${url}`));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// GET ALL IMAGES FOR GALLERY
|
||||||
|
case 'socketio/requestAllImages': {
|
||||||
|
socketio.emit(
|
||||||
|
'requestAllImages',
|
||||||
|
(response: SocketIOResponse) => {
|
||||||
|
dispatch(setGalleryImages(response.data));
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(`Loaded ${response.data.length} images`)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// CANCEL PROCESSING
|
||||||
|
case 'socketio/cancelProcessing': {
|
||||||
|
socketio.emit('cancel', (response: SocketIOResponse) => {
|
||||||
|
const { intermediateImage } = getState().gallery;
|
||||||
|
if (response.status === 'OK') {
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
if (intermediateImage) {
|
||||||
|
dispatch(addImage(intermediateImage));
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(
|
||||||
|
`Intermediate image saved: ${intermediateImage.url}`
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(clearIntermediateImage());
|
||||||
|
}
|
||||||
|
dispatch(addLogEntry(`Processing canceled`));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// UPLOAD INITIAL IMAGE
|
||||||
|
case 'socketio/uploadInitialImage': {
|
||||||
|
const file = action.payload;
|
||||||
|
|
||||||
|
socketio.emit(
|
||||||
|
'uploadInitialImage',
|
||||||
|
file,
|
||||||
|
file.name,
|
||||||
|
(response: SocketIOResponse) => {
|
||||||
|
if (response.status === 'OK') {
|
||||||
|
dispatch(setInitialImagePath(response.data));
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(
|
||||||
|
`Initial image uploaded: ${response.data}`
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// UPLOAD MASK IMAGE
|
||||||
|
case 'socketio/uploadMaskImage': {
|
||||||
|
const file = action.payload;
|
||||||
|
|
||||||
|
socketio.emit(
|
||||||
|
'uploadMaskImage',
|
||||||
|
file,
|
||||||
|
file.name,
|
||||||
|
(response: SocketIOResponse) => {
|
||||||
|
if (response.status === 'OK') {
|
||||||
|
dispatch(setMaskPath(response.data));
|
||||||
|
dispatch(
|
||||||
|
addLogEntry(
|
||||||
|
`Mask image uploaded: ${response.data}`
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
next(action);
|
||||||
|
};
|
||||||
|
|
||||||
|
return middleware;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Actions to be used by app
|
||||||
|
|
||||||
|
export const generateImage = createAction<undefined>('socketio/generateImage');
|
||||||
|
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
|
||||||
|
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
|
||||||
|
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
|
||||||
|
export const requestAllImages = createAction<undefined>(
|
||||||
|
'socketio/requestAllImages'
|
||||||
|
);
|
||||||
|
export const cancelProcessing = createAction<undefined>(
|
||||||
|
'socketio/cancelProcessing'
|
||||||
|
);
|
||||||
|
export const uploadInitialImage = createAction<File>(
|
||||||
|
'socketio/uploadInitialImage'
|
||||||
|
);
|
||||||
|
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');
|
53
frontend/src/app/store.ts
Normal file
53
frontend/src/app/store.ts
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
||||||
|
import { persistReducer } from 'redux-persist';
|
||||||
|
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
||||||
|
|
||||||
|
import sdReducer from '../features/sd/sdSlice';
|
||||||
|
import galleryReducer from '../features/gallery/gallerySlice';
|
||||||
|
import systemReducer from '../features/system/systemSlice';
|
||||||
|
import { socketioMiddleware } from './socketio';
|
||||||
|
|
||||||
|
const reducers = combineReducers({
|
||||||
|
sd: sdReducer,
|
||||||
|
gallery: galleryReducer,
|
||||||
|
system: systemReducer,
|
||||||
|
});
|
||||||
|
|
||||||
|
const persistConfig = {
|
||||||
|
key: 'root',
|
||||||
|
storage,
|
||||||
|
};
|
||||||
|
|
||||||
|
const persistedReducer = persistReducer(persistConfig, reducers);
|
||||||
|
|
||||||
|
/*
|
||||||
|
The frontend needs to be distributed as a production build, so
|
||||||
|
we cannot reasonably ask users to edit the JS and specify the
|
||||||
|
host and port on which the socket.io server will run.
|
||||||
|
|
||||||
|
The solution is to allow server script to be run with arguments
|
||||||
|
(or just edited) providing the host and port. Then, the server
|
||||||
|
serves a route `/socketio_config` which responds with the host
|
||||||
|
and port.
|
||||||
|
|
||||||
|
When the frontend loads, it synchronously requests that route
|
||||||
|
and thus gets the host and port. This requires a suspicious
|
||||||
|
fetch somewhere, and the store setup seems like as good a place
|
||||||
|
as any to make this fetch request.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
// Continue with store setup
|
||||||
|
export const store = configureStore({
|
||||||
|
reducer: persistedReducer,
|
||||||
|
middleware: (getDefaultMiddleware) =>
|
||||||
|
getDefaultMiddleware({
|
||||||
|
// redux-persist sometimes needs to have a function in redux, need to disable this check
|
||||||
|
serializableCheck: false,
|
||||||
|
}).concat(socketioMiddleware()),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Infer the `RootState` and `AppDispatch` types from the store itself
|
||||||
|
export type RootState = ReturnType<typeof store.getState>;
|
||||||
|
// Inferred type: {posts: PostsState, comments: CommentsState, users: UsersState}
|
||||||
|
export type AppDispatch = typeof store.dispatch;
|
37
frontend/src/app/theme.ts
Normal file
37
frontend/src/app/theme.ts
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import { extendTheme } from '@chakra-ui/react';
|
||||||
|
import type { StyleFunctionProps } from '@chakra-ui/styled-system';
|
||||||
|
|
||||||
|
export const theme = extendTheme({
|
||||||
|
config: {
|
||||||
|
initialColorMode: 'dark',
|
||||||
|
useSystemColorMode: false,
|
||||||
|
},
|
||||||
|
components: {
|
||||||
|
Tooltip: {
|
||||||
|
baseStyle: (props: StyleFunctionProps) => ({
|
||||||
|
textColor: props.colorMode === 'dark' ? 'gray.800' : 'gray.100',
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
Accordion: {
|
||||||
|
baseStyle: (props: StyleFunctionProps) => ({
|
||||||
|
button: {
|
||||||
|
fontWeight: 'bold',
|
||||||
|
_hover: {
|
||||||
|
bgColor:
|
||||||
|
props.colorMode === 'dark'
|
||||||
|
? 'rgba(255,255,255,0.05)'
|
||||||
|
: 'rgba(0,0,0,0.05)',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
panel: {
|
||||||
|
paddingBottom: 2,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
FormLabel: {
|
||||||
|
baseStyle: {
|
||||||
|
fontWeight: 'light',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
16
frontend/src/components/SDButton.tsx
Normal file
16
frontend/src/components/SDButton.tsx
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import { Button, ButtonProps } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
interface Props extends ButtonProps {
|
||||||
|
label: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const SDButton = (props: Props) => {
|
||||||
|
const { label, size = 'sm', ...rest } = props;
|
||||||
|
return (
|
||||||
|
<Button size={size} {...rest}>
|
||||||
|
{label}
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SDButton;
|
56
frontend/src/components/SDNumberInput.tsx
Normal file
56
frontend/src/components/SDNumberInput.tsx
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import {
|
||||||
|
FormControl,
|
||||||
|
NumberInput,
|
||||||
|
NumberInputField,
|
||||||
|
NumberInputStepper,
|
||||||
|
NumberIncrementStepper,
|
||||||
|
NumberDecrementStepper,
|
||||||
|
Text,
|
||||||
|
FormLabel,
|
||||||
|
NumberInputProps,
|
||||||
|
Flex,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
interface Props extends NumberInputProps {
|
||||||
|
label?: string;
|
||||||
|
width?: string | number;
|
||||||
|
}
|
||||||
|
|
||||||
|
const SDNumberInput = (props: Props) => {
|
||||||
|
const {
|
||||||
|
label,
|
||||||
|
isDisabled = false,
|
||||||
|
fontSize = 'md',
|
||||||
|
size = 'sm',
|
||||||
|
width,
|
||||||
|
isInvalid,
|
||||||
|
...rest
|
||||||
|
} = props;
|
||||||
|
return (
|
||||||
|
<FormControl isDisabled={isDisabled} width={width} isInvalid={isInvalid}>
|
||||||
|
<Flex gap={2} justifyContent={'space-between'} alignItems={'center'}>
|
||||||
|
{label && (
|
||||||
|
<FormLabel marginBottom={1}>
|
||||||
|
<Text fontSize={fontSize} whiteSpace='nowrap'>
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
</FormLabel>
|
||||||
|
)}
|
||||||
|
<NumberInput
|
||||||
|
size={size}
|
||||||
|
{...rest}
|
||||||
|
keepWithinRange={false}
|
||||||
|
clampValueOnBlur={true}
|
||||||
|
>
|
||||||
|
<NumberInputField fontSize={'md'}/>
|
||||||
|
<NumberInputStepper>
|
||||||
|
<NumberIncrementStepper />
|
||||||
|
<NumberDecrementStepper />
|
||||||
|
</NumberInputStepper>
|
||||||
|
</NumberInput>
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SDNumberInput;
|
57
frontend/src/components/SDSelect.tsx
Normal file
57
frontend/src/components/SDSelect.tsx
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Select,
|
||||||
|
SelectProps,
|
||||||
|
Text,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
interface Props extends SelectProps {
|
||||||
|
label: string;
|
||||||
|
validValues:
|
||||||
|
| Array<number | string>
|
||||||
|
| Array<{ key: string; value: string | number }>;
|
||||||
|
}
|
||||||
|
|
||||||
|
const SDSelect = (props: Props) => {
|
||||||
|
const {
|
||||||
|
label,
|
||||||
|
isDisabled,
|
||||||
|
validValues,
|
||||||
|
size = 'sm',
|
||||||
|
fontSize = 'md',
|
||||||
|
marginBottom = 1,
|
||||||
|
whiteSpace = 'nowrap',
|
||||||
|
...rest
|
||||||
|
} = props;
|
||||||
|
return (
|
||||||
|
<FormControl isDisabled={isDisabled}>
|
||||||
|
<Flex justifyContent={'space-between'} alignItems={'center'}>
|
||||||
|
<FormLabel
|
||||||
|
marginBottom={marginBottom}
|
||||||
|
>
|
||||||
|
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
</FormLabel>
|
||||||
|
<Select fontSize={fontSize} size={size} {...rest}>
|
||||||
|
{validValues.map((opt) => {
|
||||||
|
return typeof opt === 'string' ||
|
||||||
|
typeof opt === 'number' ? (
|
||||||
|
<option key={opt} value={opt}>
|
||||||
|
{opt}
|
||||||
|
</option>
|
||||||
|
) : (
|
||||||
|
<option key={opt.value} value={opt.value}>
|
||||||
|
{opt.key}
|
||||||
|
</option>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</Select>
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SDSelect;
|
42
frontend/src/components/SDSwitch.tsx
Normal file
42
frontend/src/components/SDSwitch.tsx
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Switch,
|
||||||
|
SwitchProps,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
interface Props extends SwitchProps {
|
||||||
|
label?: string;
|
||||||
|
width?: string | number;
|
||||||
|
}
|
||||||
|
|
||||||
|
const SDSwitch = (props: Props) => {
|
||||||
|
const {
|
||||||
|
label,
|
||||||
|
isDisabled = false,
|
||||||
|
fontSize = 'md',
|
||||||
|
size = 'md',
|
||||||
|
width,
|
||||||
|
...rest
|
||||||
|
} = props;
|
||||||
|
return (
|
||||||
|
<FormControl isDisabled={isDisabled} width={width}>
|
||||||
|
<Flex justifyContent={'space-between'} alignItems={'center'}>
|
||||||
|
{label && (
|
||||||
|
<FormLabel
|
||||||
|
fontSize={fontSize}
|
||||||
|
marginBottom={1}
|
||||||
|
flexGrow={2}
|
||||||
|
whiteSpace='nowrap'
|
||||||
|
>
|
||||||
|
{label}
|
||||||
|
</FormLabel>
|
||||||
|
)}
|
||||||
|
<Switch size={size} {...rest} />
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SDSwitch;
|
161
frontend/src/features/gallery/CurrentImage.tsx
Normal file
161
frontend/src/features/gallery/CurrentImage.tsx
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
import { Center, Flex, Image, useColorModeValue } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
|
||||||
|
import { useState } from 'react';
|
||||||
|
import ImageMetadataViewer from './ImageMetadataViewer';
|
||||||
|
import DeleteImageModalButton from './DeleteImageModalButton';
|
||||||
|
import SDButton from '../../components/SDButton';
|
||||||
|
import { runESRGAN, runGFPGAN } from '../../app/socketio';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
|
const height = 'calc(100vh - 238px)';
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isProcessing: system.isProcessing,
|
||||||
|
isConnected: system.isConnected,
|
||||||
|
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||||
|
isESRGANAvailable: system.isESRGANAvailable,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const CurrentImage = () => {
|
||||||
|
const { currentImage, intermediateImage } = useAppSelector(
|
||||||
|
(state: RootState) => state.gallery
|
||||||
|
);
|
||||||
|
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
|
||||||
|
useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const bgColor = useColorModeValue(
|
||||||
|
'rgba(255, 255, 255, 0.85)',
|
||||||
|
'rgba(0, 0, 0, 0.8)'
|
||||||
|
);
|
||||||
|
|
||||||
|
const [shouldShowImageDetails, setShouldShowImageDetails] =
|
||||||
|
useState<boolean>(false);
|
||||||
|
|
||||||
|
const imageToDisplay = intermediateImage || currentImage;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex direction={'column'} rounded={'md'} borderWidth={1} p={2} gap={2}>
|
||||||
|
{imageToDisplay && (
|
||||||
|
<Flex gap={2}>
|
||||||
|
<SDButton
|
||||||
|
label='Use as initial image'
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch(setInitialImagePath(imageToDisplay.url))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SDButton
|
||||||
|
label='Use all'
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch(setAllParameters(imageToDisplay.metadata))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SDButton
|
||||||
|
label='Use seed'
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
isDisabled={!imageToDisplay.metadata.seed}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch(setSeed(imageToDisplay.metadata.seed!))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SDButton
|
||||||
|
label='Upscale'
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
isDisabled={
|
||||||
|
!isESRGANAvailable ||
|
||||||
|
Boolean(intermediateImage) ||
|
||||||
|
!(isConnected && !isProcessing)
|
||||||
|
}
|
||||||
|
onClick={() => dispatch(runESRGAN(imageToDisplay))}
|
||||||
|
/>
|
||||||
|
<SDButton
|
||||||
|
label='Fix faces'
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
isDisabled={
|
||||||
|
!isGFPGANAvailable ||
|
||||||
|
Boolean(intermediateImage) ||
|
||||||
|
!(isConnected && !isProcessing)
|
||||||
|
}
|
||||||
|
onClick={() => dispatch(runGFPGAN(imageToDisplay))}
|
||||||
|
/>
|
||||||
|
<SDButton
|
||||||
|
label='Details'
|
||||||
|
colorScheme={'gray'}
|
||||||
|
variant={shouldShowImageDetails ? 'solid' : 'outline'}
|
||||||
|
borderWidth={1}
|
||||||
|
flexGrow={1}
|
||||||
|
onClick={() =>
|
||||||
|
setShouldShowImageDetails(!shouldShowImageDetails)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<DeleteImageModalButton image={imageToDisplay}>
|
||||||
|
<SDButton
|
||||||
|
label='Delete'
|
||||||
|
colorScheme={'red'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
isDisabled={Boolean(intermediateImage)}
|
||||||
|
/>
|
||||||
|
</DeleteImageModalButton>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
<Center height={height} position={'relative'}>
|
||||||
|
{imageToDisplay && (
|
||||||
|
<Image
|
||||||
|
src={imageToDisplay.url}
|
||||||
|
fit='contain'
|
||||||
|
maxWidth={'100%'}
|
||||||
|
maxHeight={'100%'}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{imageToDisplay && shouldShowImageDetails && (
|
||||||
|
<Flex
|
||||||
|
width={'100%'}
|
||||||
|
height={'100%'}
|
||||||
|
position={'absolute'}
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
p={3}
|
||||||
|
boxSizing='border-box'
|
||||||
|
backgroundColor={bgColor}
|
||||||
|
overflow='scroll'
|
||||||
|
>
|
||||||
|
<ImageMetadataViewer image={imageToDisplay} />
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Center>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default CurrentImage;
|
94
frontend/src/features/gallery/DeleteImageModalButton.tsx
Normal file
94
frontend/src/features/gallery/DeleteImageModalButton.tsx
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import {
|
||||||
|
IconButtonProps,
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalFooter,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
Text,
|
||||||
|
useDisclosure,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import {
|
||||||
|
cloneElement,
|
||||||
|
ReactElement,
|
||||||
|
SyntheticEvent,
|
||||||
|
} from 'react';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { deleteImage } from '../../app/socketio';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SDButton from '../../components/SDButton';
|
||||||
|
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
|
||||||
|
import { SDImage } from './gallerySlice';
|
||||||
|
|
||||||
|
interface Props extends IconButtonProps {
|
||||||
|
image: SDImage;
|
||||||
|
'aria-label': string;
|
||||||
|
children: ReactElement;
|
||||||
|
}
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => system.shouldConfirmOnDelete
|
||||||
|
);
|
||||||
|
|
||||||
|
/*
|
||||||
|
TODO: The modal and button to open it should be two different components,
|
||||||
|
but their state is closely related and I'm not sure how best to accomplish it.
|
||||||
|
*/
|
||||||
|
const DeleteImageModalButton = (props: Omit<Props, 'aria-label'>) => {
|
||||||
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const shouldConfirmOnDelete = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const handleClickDelete = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
shouldConfirmOnDelete ? onOpen() : handleDelete();
|
||||||
|
};
|
||||||
|
|
||||||
|
const { image, children } = props;
|
||||||
|
|
||||||
|
const handleDelete = () => {
|
||||||
|
dispatch(deleteImage(image));
|
||||||
|
onClose();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDeleteAndDontAsk = () => {
|
||||||
|
dispatch(deleteImage(image));
|
||||||
|
dispatch(setShouldConfirmOnDelete(false));
|
||||||
|
onClose();
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{cloneElement(children, {
|
||||||
|
onClick: handleClickDelete,
|
||||||
|
})}
|
||||||
|
|
||||||
|
<Modal isOpen={isOpen} onClose={onClose}>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent>
|
||||||
|
<ModalHeader>Are you sure you want to delete this image?</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody>
|
||||||
|
<Text>It will be deleted forever!</Text>
|
||||||
|
</ModalBody>
|
||||||
|
|
||||||
|
<ModalFooter justifyContent={'space-between'}>
|
||||||
|
<SDButton label={'Yes'} colorScheme='red' onClick={handleDelete} />
|
||||||
|
<SDButton
|
||||||
|
label={"Yes, and don't ask me again"}
|
||||||
|
colorScheme='red'
|
||||||
|
onClick={handleDeleteAndDontAsk}
|
||||||
|
/>
|
||||||
|
<SDButton label='Cancel' colorScheme='blue' onClick={onClose} />
|
||||||
|
</ModalFooter>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default DeleteImageModalButton;
|
124
frontend/src/features/gallery/ImageMetadataViewer.tsx
Normal file
124
frontend/src/features/gallery/ImageMetadataViewer.tsx
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
import {
|
||||||
|
Center,
|
||||||
|
Flex,
|
||||||
|
IconButton,
|
||||||
|
Link,
|
||||||
|
List,
|
||||||
|
ListItem,
|
||||||
|
Text,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { FaPlus } from 'react-icons/fa';
|
||||||
|
import { PARAMETERS } from '../../app/constants';
|
||||||
|
import { useAppDispatch } from '../../app/hooks';
|
||||||
|
import SDButton from '../../components/SDButton';
|
||||||
|
import { setAllParameters, setParameter } from '../sd/sdSlice';
|
||||||
|
import { SDImage, SDMetadata } from './gallerySlice';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
image: SDImage;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ImageMetadataViewer = ({ image }: Props) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const keys = Object.keys(PARAMETERS);
|
||||||
|
|
||||||
|
const metadata: Array<{
|
||||||
|
label: string;
|
||||||
|
key: string;
|
||||||
|
value: string | number | boolean;
|
||||||
|
}> = [];
|
||||||
|
|
||||||
|
keys.forEach((key) => {
|
||||||
|
const value = image.metadata[key as keyof SDMetadata];
|
||||||
|
if (value !== undefined) {
|
||||||
|
metadata.push({ label: PARAMETERS[key], key, value });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}>
|
||||||
|
<SDButton
|
||||||
|
label='Use all parameters'
|
||||||
|
colorScheme={'gray'}
|
||||||
|
padding={2}
|
||||||
|
isDisabled={metadata.length === 0}
|
||||||
|
onClick={() => dispatch(setAllParameters(image.metadata))}
|
||||||
|
/>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<Text fontWeight={'semibold'}>File:</Text>
|
||||||
|
<Link href={image.url} isExternal>
|
||||||
|
<Text>{image.url}</Text>
|
||||||
|
</Link>
|
||||||
|
</Flex>
|
||||||
|
{metadata.length ? (
|
||||||
|
<>
|
||||||
|
<List>
|
||||||
|
{metadata.map((parameter, i) => {
|
||||||
|
const { label, key, value } = parameter;
|
||||||
|
return (
|
||||||
|
<ListItem key={i} pb={1}>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<IconButton
|
||||||
|
aria-label='Use this parameter'
|
||||||
|
icon={<FaPlus />}
|
||||||
|
size={'xs'}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch(
|
||||||
|
setParameter({
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
})
|
||||||
|
)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<Text fontWeight={'semibold'}>
|
||||||
|
{label}:
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
{value === undefined ||
|
||||||
|
value === null ||
|
||||||
|
value === '' ||
|
||||||
|
value === 0 ? (
|
||||||
|
<Text
|
||||||
|
maxHeight={100}
|
||||||
|
fontStyle={'italic'}
|
||||||
|
>
|
||||||
|
None
|
||||||
|
</Text>
|
||||||
|
) : (
|
||||||
|
<Text
|
||||||
|
maxHeight={100}
|
||||||
|
overflowY={'scroll'}
|
||||||
|
>
|
||||||
|
{value.toString()}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</ListItem>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</List>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<Text fontWeight={'semibold'}>Raw:</Text>
|
||||||
|
<Text
|
||||||
|
maxHeight={100}
|
||||||
|
overflowY={'scroll'}
|
||||||
|
wordBreak={'break-all'}
|
||||||
|
>
|
||||||
|
{JSON.stringify(image.metadata)}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<Center width={'100%'} pt={10}>
|
||||||
|
<Text fontSize={'lg'} fontWeight='semibold'>
|
||||||
|
No metadata available
|
||||||
|
</Text>
|
||||||
|
</Center>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ImageMetadataViewer;
|
150
frontend/src/features/gallery/ImageRoll.tsx
Normal file
150
frontend/src/features/gallery/ImageRoll.tsx
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
import {
|
||||||
|
Box,
|
||||||
|
Flex,
|
||||||
|
Icon,
|
||||||
|
IconButton,
|
||||||
|
Image,
|
||||||
|
useColorModeValue,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { SDImage, setCurrentImage } from './gallerySlice';
|
||||||
|
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
|
||||||
|
import DeleteImageModalButton from './DeleteImageModalButton';
|
||||||
|
import { memo, SyntheticEvent, useState } from 'react';
|
||||||
|
import { setAllParameters, setSeed } from '../sd/sdSlice';
|
||||||
|
|
||||||
|
interface HoverableImageProps {
|
||||||
|
image: SDImage;
|
||||||
|
isSelected: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const HoverableImage = memo(
|
||||||
|
(props: HoverableImageProps) => {
|
||||||
|
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const checkColor = useColorModeValue('green.600', 'green.300');
|
||||||
|
const bgColor = useColorModeValue('gray.200', 'gray.700');
|
||||||
|
const bgGradient = useColorModeValue(
|
||||||
|
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
|
||||||
|
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
|
||||||
|
);
|
||||||
|
|
||||||
|
const { image, isSelected } = props;
|
||||||
|
const { url, uuid, metadata } = image;
|
||||||
|
|
||||||
|
const handleMouseOver = () => setIsHovered(true);
|
||||||
|
const handleMouseOut = () => setIsHovered(false);
|
||||||
|
const handleClickSetAllParameters = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
dispatch(setAllParameters(metadata));
|
||||||
|
};
|
||||||
|
const handleClickSetSeed = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
dispatch(setSeed(image.metadata.seed!)); // component not rendered unless this exists
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box position={'relative'} key={uuid}>
|
||||||
|
<Image
|
||||||
|
width={120}
|
||||||
|
height={120}
|
||||||
|
objectFit='cover'
|
||||||
|
rounded={'md'}
|
||||||
|
src={url}
|
||||||
|
loading={'lazy'}
|
||||||
|
backgroundColor={bgColor}
|
||||||
|
/>
|
||||||
|
<Flex
|
||||||
|
cursor={'pointer'}
|
||||||
|
position={'absolute'}
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
rounded={'md'}
|
||||||
|
width='100%'
|
||||||
|
height='100%'
|
||||||
|
alignItems={'center'}
|
||||||
|
justifyContent={'center'}
|
||||||
|
background={isSelected ? bgGradient : undefined}
|
||||||
|
onClick={() => dispatch(setCurrentImage(image))}
|
||||||
|
onMouseOver={handleMouseOver}
|
||||||
|
onMouseOut={handleMouseOut}
|
||||||
|
>
|
||||||
|
{isSelected && (
|
||||||
|
<Icon
|
||||||
|
fill={checkColor}
|
||||||
|
width={'50%'}
|
||||||
|
height={'50%'}
|
||||||
|
as={FaCheck}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{isHovered && (
|
||||||
|
<Flex
|
||||||
|
direction={'column'}
|
||||||
|
gap={1}
|
||||||
|
position={'absolute'}
|
||||||
|
top={1}
|
||||||
|
right={1}
|
||||||
|
>
|
||||||
|
<DeleteImageModalButton image={image}>
|
||||||
|
<IconButton
|
||||||
|
colorScheme='red'
|
||||||
|
aria-label='Delete image'
|
||||||
|
icon={<FaTrash />}
|
||||||
|
size='xs'
|
||||||
|
fontSize={15}
|
||||||
|
/>
|
||||||
|
</DeleteImageModalButton>
|
||||||
|
<IconButton
|
||||||
|
aria-label='Use all parameters'
|
||||||
|
colorScheme={'blue'}
|
||||||
|
icon={<FaCopy />}
|
||||||
|
size='xs'
|
||||||
|
fontSize={15}
|
||||||
|
onClickCapture={handleClickSetAllParameters}
|
||||||
|
/>
|
||||||
|
{image.metadata.seed && (
|
||||||
|
<IconButton
|
||||||
|
aria-label='Use seed'
|
||||||
|
colorScheme={'blue'}
|
||||||
|
icon={<FaSeedling />}
|
||||||
|
size='xs'
|
||||||
|
fontSize={16}
|
||||||
|
onClickCapture={handleClickSetSeed}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
(prev, next) =>
|
||||||
|
prev.image.uuid === next.image.uuid &&
|
||||||
|
prev.isSelected === next.isSelected
|
||||||
|
);
|
||||||
|
|
||||||
|
const ImageRoll = () => {
|
||||||
|
const { images, currentImageUuid } = useAppSelector(
|
||||||
|
(state: RootState) => state.gallery
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} wrap='wrap' pb={2}>
|
||||||
|
{[...images].reverse().map((image) => {
|
||||||
|
const { uuid } = image;
|
||||||
|
const isSelected = currentImageUuid === uuid;
|
||||||
|
return (
|
||||||
|
<HoverableImage
|
||||||
|
key={uuid}
|
||||||
|
image={image}
|
||||||
|
isSelected={isSelected}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ImageRoll;
|
144
frontend/src/features/gallery/gallerySlice.ts
Normal file
144
frontend/src/features/gallery/gallerySlice.ts
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { UpscalingLevel } from '../sd/sdSlice';
|
||||||
|
import { backendToFrontendParameters } from '../../app/parameterTranslation';
|
||||||
|
|
||||||
|
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
|
||||||
|
export interface SDMetadata {
|
||||||
|
prompt?: string;
|
||||||
|
steps?: number;
|
||||||
|
cfgScale?: number;
|
||||||
|
height?: number;
|
||||||
|
width?: number;
|
||||||
|
sampler?: string;
|
||||||
|
seed?: number;
|
||||||
|
img2imgStrength?: number;
|
||||||
|
gfpganStrength?: number;
|
||||||
|
upscalingLevel?: UpscalingLevel;
|
||||||
|
upscalingStrength?: number;
|
||||||
|
initialImagePath?: string;
|
||||||
|
maskPath?: string;
|
||||||
|
seamless?: boolean;
|
||||||
|
shouldFitToWidthHeight?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SDImage {
|
||||||
|
// TODO: I have installed @types/uuid but cannot figure out how to use them here.
|
||||||
|
uuid: string;
|
||||||
|
url: string;
|
||||||
|
metadata: SDMetadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GalleryState {
|
||||||
|
currentImageUuid: string;
|
||||||
|
images: Array<SDImage>;
|
||||||
|
intermediateImage?: SDImage;
|
||||||
|
currentImage?: SDImage;
|
||||||
|
}
|
||||||
|
|
||||||
|
const initialState: GalleryState = {
|
||||||
|
currentImageUuid: '',
|
||||||
|
images: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
export const gallerySlice = createSlice({
|
||||||
|
name: 'gallery',
|
||||||
|
initialState,
|
||||||
|
reducers: {
|
||||||
|
setCurrentImage: (state, action: PayloadAction<SDImage>) => {
|
||||||
|
state.currentImage = action.payload;
|
||||||
|
state.currentImageUuid = action.payload.uuid;
|
||||||
|
},
|
||||||
|
removeImage: (state, action: PayloadAction<SDImage>) => {
|
||||||
|
const { uuid } = action.payload;
|
||||||
|
|
||||||
|
const newImages = state.images.filter((image) => image.uuid !== uuid);
|
||||||
|
|
||||||
|
const imageToDeleteIndex = state.images.findIndex(
|
||||||
|
(image) => image.uuid === uuid
|
||||||
|
);
|
||||||
|
|
||||||
|
const newCurrentImageIndex = Math.min(
|
||||||
|
Math.max(imageToDeleteIndex, 0),
|
||||||
|
newImages.length - 1
|
||||||
|
);
|
||||||
|
|
||||||
|
state.images = newImages;
|
||||||
|
|
||||||
|
state.currentImage = newImages.length
|
||||||
|
? newImages[newCurrentImageIndex]
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
state.currentImageUuid = newImages.length
|
||||||
|
? newImages[newCurrentImageIndex].uuid
|
||||||
|
: '';
|
||||||
|
},
|
||||||
|
addImage: (state, action: PayloadAction<SDImage>) => {
|
||||||
|
state.images.push(action.payload);
|
||||||
|
state.currentImageUuid = action.payload.uuid;
|
||||||
|
state.intermediateImage = undefined;
|
||||||
|
state.currentImage = action.payload;
|
||||||
|
},
|
||||||
|
setIntermediateImage: (state, action: PayloadAction<SDImage>) => {
|
||||||
|
state.intermediateImage = action.payload;
|
||||||
|
},
|
||||||
|
clearIntermediateImage: (state) => {
|
||||||
|
state.intermediateImage = undefined;
|
||||||
|
},
|
||||||
|
setGalleryImages: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<
|
||||||
|
Array<{
|
||||||
|
path: string;
|
||||||
|
metadata: { [key: string]: string | number | boolean };
|
||||||
|
}>
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
|
||||||
|
const images = action.payload;
|
||||||
|
|
||||||
|
if (images.length === 0) {
|
||||||
|
// there are no images on disk, clear the gallery
|
||||||
|
state.images = [];
|
||||||
|
state.currentImageUuid = '';
|
||||||
|
state.currentImage = undefined;
|
||||||
|
} else {
|
||||||
|
// Filter image urls that are already in the rehydrated state
|
||||||
|
const filteredImages = action.payload.filter(
|
||||||
|
(image) => !state.images.find((i) => i.url === image.path)
|
||||||
|
);
|
||||||
|
|
||||||
|
const preparedImages = filteredImages.map((image): SDImage => {
|
||||||
|
return {
|
||||||
|
uuid: uuidv4(),
|
||||||
|
url: image.path,
|
||||||
|
metadata: backendToFrontendParameters(image.metadata),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
const newImages = [...state.images].concat(preparedImages);
|
||||||
|
|
||||||
|
// if previous currentimage no longer exists, set a new one
|
||||||
|
if (!newImages.find((image) => image.uuid === state.currentImageUuid)) {
|
||||||
|
const newCurrentImage = newImages[newImages.length - 1];
|
||||||
|
state.currentImage = newCurrentImage;
|
||||||
|
state.currentImageUuid = newCurrentImage.uuid;
|
||||||
|
}
|
||||||
|
|
||||||
|
state.images = newImages;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const {
|
||||||
|
setCurrentImage,
|
||||||
|
removeImage,
|
||||||
|
addImage,
|
||||||
|
setGalleryImages,
|
||||||
|
setIntermediateImage,
|
||||||
|
clearIntermediateImage,
|
||||||
|
} = gallerySlice.actions;
|
||||||
|
|
||||||
|
export default gallerySlice.reducer;
|
35
frontend/src/features/header/ProgressBar.tsx
Normal file
35
frontend/src/features/header/ProgressBar.tsx
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import { Progress } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { SDState } from '../sd/sdSlice';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
realSteps: sd.realSteps,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const ProgressBar = () => {
|
||||||
|
const { realSteps } = useAppSelector(sdSelector);
|
||||||
|
const { currentStep } = useAppSelector((state: RootState) => state.system);
|
||||||
|
const progress = Math.round((currentStep * 100) / realSteps);
|
||||||
|
return (
|
||||||
|
<Progress
|
||||||
|
height='10px'
|
||||||
|
value={progress}
|
||||||
|
isIndeterminate={progress < 0 || currentStep === realSteps}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ProgressBar;
|
93
frontend/src/features/header/SiteHeader.tsx
Normal file
93
frontend/src/features/header/SiteHeader.tsx
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
Heading,
|
||||||
|
IconButton,
|
||||||
|
Link,
|
||||||
|
Spacer,
|
||||||
|
Text,
|
||||||
|
useColorMode,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
|
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
|
||||||
|
import { MdHelp, MdSettings } from 'react-icons/md';
|
||||||
|
import { useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SettingsModal from '../system/SettingsModal';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return { isConnected: system.isConnected };
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: { resultEqualityCheck: isEqual },
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const SiteHeader = () => {
|
||||||
|
const { colorMode, toggleColorMode } = useColorMode();
|
||||||
|
const { isConnected } = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex minWidth='max-content' alignItems='center' gap='1' pl={2} pr={1}>
|
||||||
|
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading>
|
||||||
|
|
||||||
|
<Spacer />
|
||||||
|
|
||||||
|
<Text textColor={isConnected ? 'green.500' : 'red.500'}>
|
||||||
|
{isConnected ? `Connected to server` : 'No connection to server'}
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<SettingsModal>
|
||||||
|
<IconButton
|
||||||
|
aria-label='Settings'
|
||||||
|
variant='link'
|
||||||
|
fontSize={24}
|
||||||
|
size={'sm'}
|
||||||
|
icon={<MdSettings />}
|
||||||
|
/>
|
||||||
|
</SettingsModal>
|
||||||
|
|
||||||
|
<IconButton
|
||||||
|
aria-label='Link to Github Issues'
|
||||||
|
variant='link'
|
||||||
|
fontSize={23}
|
||||||
|
size={'sm'}
|
||||||
|
icon={
|
||||||
|
<Link
|
||||||
|
isExternal
|
||||||
|
href='http://github.com/lstein/stable-diffusion/issues'
|
||||||
|
>
|
||||||
|
<MdHelp />
|
||||||
|
</Link>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<IconButton
|
||||||
|
aria-label='Link to Github Repo'
|
||||||
|
variant='link'
|
||||||
|
fontSize={20}
|
||||||
|
size={'sm'}
|
||||||
|
icon={
|
||||||
|
<Link isExternal href='http://github.com/lstein/stable-diffusion'>
|
||||||
|
<FaGithub />
|
||||||
|
</Link>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<IconButton
|
||||||
|
aria-label='Toggle Dark Mode'
|
||||||
|
onClick={toggleColorMode}
|
||||||
|
variant='link'
|
||||||
|
size={'sm'}
|
||||||
|
fontSize={colorMode == 'light' ? 18 : 20}
|
||||||
|
icon={colorMode == 'light' ? <FaMoon /> : <FaSun />}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SiteHeader;
|
84
frontend/src/features/sd/ESRGANOptions.tsx
Normal file
84
frontend/src/features/sd/ESRGANOptions.tsx
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
|
||||||
|
import {
|
||||||
|
setUpscalingLevel,
|
||||||
|
setUpscalingStrength,
|
||||||
|
UpscalingLevel,
|
||||||
|
SDState,
|
||||||
|
} from '../sd/sdSlice';
|
||||||
|
|
||||||
|
import SDNumberInput from '../../components/SDNumberInput';
|
||||||
|
import SDSelect from '../../components/SDSelect';
|
||||||
|
|
||||||
|
import { UPSCALING_LEVELS } from '../../app/constants';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
upscalingLevel: sd.upscalingLevel,
|
||||||
|
upscalingStrength: sd.upscalingStrength,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isESRGANAvailable: system.isESRGANAvailable,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
const ESRGANOptions = () => {
|
||||||
|
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const { isESRGANAvailable } = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex direction={'column'} gap={2}>
|
||||||
|
<SDSelect
|
||||||
|
isDisabled={!isESRGANAvailable}
|
||||||
|
label='Scale'
|
||||||
|
value={upscalingLevel}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(
|
||||||
|
setUpscalingLevel(
|
||||||
|
Number(e.target.value) as UpscalingLevel
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
validValues={UPSCALING_LEVELS}
|
||||||
|
/>
|
||||||
|
<SDNumberInput
|
||||||
|
isDisabled={!isESRGANAvailable}
|
||||||
|
label='Strength'
|
||||||
|
step={0.05}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
onChange={(v) => dispatch(setUpscalingStrength(Number(v)))}
|
||||||
|
value={upscalingStrength}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ESRGANOptions;
|
63
frontend/src/features/sd/GFPGANOptions.tsx
Normal file
63
frontend/src/features/sd/GFPGANOptions.tsx
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
|
||||||
|
import { SDState, setGfpganStrength } from '../sd/sdSlice';
|
||||||
|
|
||||||
|
import SDNumberInput from '../../components/SDNumberInput';
|
||||||
|
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
gfpganStrength: sd.gfpganStrength,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
const GFPGANOptions = () => {
|
||||||
|
const { gfpganStrength } = useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const { isGFPGANAvailable } = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex direction={'column'} gap={2}>
|
||||||
|
<SDNumberInput
|
||||||
|
isDisabled={!isGFPGANAvailable}
|
||||||
|
label='Strength'
|
||||||
|
step={0.05}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
onChange={(v) => dispatch(setGfpganStrength(Number(v)))}
|
||||||
|
value={gfpganStrength}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default GFPGANOptions;
|
54
frontend/src/features/sd/ImageToImageOptions.tsx
Normal file
54
frontend/src/features/sd/ImageToImageOptions.tsx
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SDNumberInput from '../../components/SDNumberInput';
|
||||||
|
import SDSwitch from '../../components/SDSwitch';
|
||||||
|
import InitImage from './InitImage';
|
||||||
|
import {
|
||||||
|
SDState,
|
||||||
|
setImg2imgStrength,
|
||||||
|
setShouldFitToWidthHeight,
|
||||||
|
} from './sdSlice';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
initialImagePath: sd.initialImagePath,
|
||||||
|
img2imgStrength: sd.img2imgStrength,
|
||||||
|
shouldFitToWidthHeight: sd.shouldFitToWidthHeight,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const ImageToImageOptions = () => {
|
||||||
|
const { initialImagePath, img2imgStrength, shouldFitToWidthHeight } =
|
||||||
|
useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
return (
|
||||||
|
<Flex direction={'column'} gap={2}>
|
||||||
|
<SDNumberInput
|
||||||
|
isDisabled={!initialImagePath}
|
||||||
|
label='Strength'
|
||||||
|
step={0.01}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
onChange={(v) => dispatch(setImg2imgStrength(Number(v)))}
|
||||||
|
value={img2imgStrength}
|
||||||
|
/>
|
||||||
|
<SDSwitch
|
||||||
|
isDisabled={!initialImagePath}
|
||||||
|
label='Fit initial image to output size'
|
||||||
|
isChecked={shouldFitToWidthHeight}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(setShouldFitToWidthHeight(e.target.checked))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<InitImage />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ImageToImageOptions;
|
20
frontend/src/features/sd/InitImage.css
Normal file
20
frontend/src/features/sd/InitImage.css
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
.checkerboard {
|
||||||
|
background-position: 0px 0px, 10px 10px;
|
||||||
|
background-size: 20px 20px;
|
||||||
|
background-image: linear-gradient(
|
||||||
|
45deg,
|
||||||
|
#eee 25%,
|
||||||
|
transparent 25%,
|
||||||
|
transparent 75%,
|
||||||
|
#eee 75%,
|
||||||
|
#eee 100%
|
||||||
|
),
|
||||||
|
linear-gradient(
|
||||||
|
45deg,
|
||||||
|
#eee 25%,
|
||||||
|
white 25%,
|
||||||
|
white 75%,
|
||||||
|
#eee 75%,
|
||||||
|
#eee 100%
|
||||||
|
);
|
||||||
|
}
|
155
frontend/src/features/sd/InitImage.tsx
Normal file
155
frontend/src/features/sd/InitImage.tsx
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Flex,
|
||||||
|
IconButton,
|
||||||
|
Image,
|
||||||
|
useToast,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { SyntheticEvent, useCallback, useState } from 'react';
|
||||||
|
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||||
|
import { FaTrash } from 'react-icons/fa';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import {
|
||||||
|
SDState,
|
||||||
|
setInitialImagePath,
|
||||||
|
setMaskPath,
|
||||||
|
} from '../../features/sd/sdSlice';
|
||||||
|
import MaskUploader from './MaskUploader';
|
||||||
|
import './InitImage.css';
|
||||||
|
import { uploadInitialImage } from '../../app/socketio';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
initialImagePath: sd.initialImagePath,
|
||||||
|
maskPath: sd.maskPath,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
||||||
|
);
|
||||||
|
|
||||||
|
const InitImage = () => {
|
||||||
|
const toast = useToast();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { initialImagePath, maskPath } = useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const onDrop = useCallback(
|
||||||
|
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
||||||
|
fileRejections.forEach((rejection: FileRejection) => {
|
||||||
|
const msg = rejection.errors.reduce(
|
||||||
|
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
|
||||||
|
''
|
||||||
|
);
|
||||||
|
|
||||||
|
toast({
|
||||||
|
title: 'Upload failed',
|
||||||
|
description: msg,
|
||||||
|
status: 'error',
|
||||||
|
isClosable: true,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
acceptedFiles.forEach((file: File) => {
|
||||||
|
dispatch(uploadInitialImage(file));
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[dispatch, toast]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { getRootProps, getInputProps, open } = useDropzone({
|
||||||
|
onDrop,
|
||||||
|
accept: {
|
||||||
|
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
|
||||||
|
const handleClickUploadIcon = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
open();
|
||||||
|
};
|
||||||
|
const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
dispatch(setInitialImagePath(''));
|
||||||
|
dispatch(setMaskPath(''));
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleMouseOverInitialImageUploadButton = () =>
|
||||||
|
setShouldShowMask(false);
|
||||||
|
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
|
||||||
|
|
||||||
|
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
|
||||||
|
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
{...getRootProps({
|
||||||
|
onClick: initialImagePath ? (e) => e.stopPropagation() : undefined,
|
||||||
|
})}
|
||||||
|
direction={'column'}
|
||||||
|
alignItems={'center'}
|
||||||
|
gap={2}
|
||||||
|
>
|
||||||
|
<input {...getInputProps({ multiple: false })} />
|
||||||
|
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
|
||||||
|
<Button
|
||||||
|
size={'sm'}
|
||||||
|
fontSize={'md'}
|
||||||
|
fontWeight={'normal'}
|
||||||
|
onClick={handleClickUploadIcon}
|
||||||
|
onMouseOver={handleMouseOverInitialImageUploadButton}
|
||||||
|
onMouseOut={handleMouseOutInitialImageUploadButton}
|
||||||
|
>
|
||||||
|
Upload Image
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
<MaskUploader>
|
||||||
|
<Button
|
||||||
|
size={'sm'}
|
||||||
|
fontSize={'md'}
|
||||||
|
fontWeight={'normal'}
|
||||||
|
onClick={handleClickUploadIcon}
|
||||||
|
onMouseOver={handleMouseOverMaskUploadButton}
|
||||||
|
onMouseOut={handleMouseOutMaskUploadButton}
|
||||||
|
>
|
||||||
|
Upload Mask
|
||||||
|
</Button>
|
||||||
|
</MaskUploader>
|
||||||
|
<IconButton
|
||||||
|
size={'sm'}
|
||||||
|
aria-label={'Reset initial image and mask'}
|
||||||
|
onClick={handleClickResetInitialImageAndMask}
|
||||||
|
icon={<FaTrash />}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
{initialImagePath && (
|
||||||
|
<Flex position={'relative'} width={'100%'}>
|
||||||
|
<Image
|
||||||
|
fit={'contain'}
|
||||||
|
src={initialImagePath}
|
||||||
|
rounded={'md'}
|
||||||
|
className={'checkerboard'}
|
||||||
|
/>
|
||||||
|
{shouldShowMask && maskPath && (
|
||||||
|
<Image
|
||||||
|
position={'absolute'}
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
fit={'contain'}
|
||||||
|
src={maskPath}
|
||||||
|
rounded={'md'}
|
||||||
|
zIndex={1}
|
||||||
|
className={'checkerboard'}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default InitImage;
|
61
frontend/src/features/sd/MaskUploader.tsx
Normal file
61
frontend/src/features/sd/MaskUploader.tsx
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import { useToast } from '@chakra-ui/react';
|
||||||
|
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
|
||||||
|
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||||
|
import { useAppDispatch } from '../../app/hooks';
|
||||||
|
import { uploadMaskImage } from '../../app/socketio';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
children: ReactElement;
|
||||||
|
};
|
||||||
|
|
||||||
|
const MaskUploader = ({ children }: Props) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const toast = useToast();
|
||||||
|
|
||||||
|
const onDrop = useCallback(
|
||||||
|
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
||||||
|
fileRejections.forEach((rejection: FileRejection) => {
|
||||||
|
const msg = rejection.errors.reduce(
|
||||||
|
(acc: string, cur: { message: string }) =>
|
||||||
|
acc + '\n' + cur.message,
|
||||||
|
''
|
||||||
|
);
|
||||||
|
|
||||||
|
toast({
|
||||||
|
title: 'Upload failed',
|
||||||
|
description: msg,
|
||||||
|
status: 'error',
|
||||||
|
isClosable: true,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
acceptedFiles.forEach((file: File) => {
|
||||||
|
dispatch(uploadMaskImage(file));
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[dispatch, toast]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { getRootProps, getInputProps, open } = useDropzone({
|
||||||
|
onDrop,
|
||||||
|
accept: {
|
||||||
|
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const handleClickUploadIcon = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
open();
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div {...getRootProps()}>
|
||||||
|
<input {...getInputProps({ multiple: false })} />
|
||||||
|
{cloneElement(children, {
|
||||||
|
onClick: handleClickUploadIcon,
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default MaskUploader;
|
211
frontend/src/features/sd/OptionsAccordion.tsx
Normal file
211
frontend/src/features/sd/OptionsAccordion.tsx
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
Box,
|
||||||
|
Text,
|
||||||
|
Accordion,
|
||||||
|
AccordionItem,
|
||||||
|
AccordionButton,
|
||||||
|
AccordionIcon,
|
||||||
|
AccordionPanel,
|
||||||
|
Switch,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
|
||||||
|
import {
|
||||||
|
setShouldRunGFPGAN,
|
||||||
|
setShouldRunESRGAN,
|
||||||
|
SDState,
|
||||||
|
setShouldUseInitImage,
|
||||||
|
} from '../sd/sdSlice';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { setOpenAccordions, SystemState } from '../system/systemSlice';
|
||||||
|
import SeedVariationOptions from './SeedVariationOptions';
|
||||||
|
import SamplerOptions from './SamplerOptions';
|
||||||
|
import ESRGANOptions from './ESRGANOptions';
|
||||||
|
import GFPGANOptions from './GFPGANOptions';
|
||||||
|
import OutputOptions from './OutputOptions';
|
||||||
|
import ImageToImageOptions from './ImageToImageOptions';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
initialImagePath: sd.initialImagePath,
|
||||||
|
shouldUseInitImage: sd.shouldUseInitImage,
|
||||||
|
shouldRunESRGAN: sd.shouldRunESRGAN,
|
||||||
|
shouldRunGFPGAN: sd.shouldRunGFPGAN,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||||
|
isESRGANAvailable: system.isESRGANAvailable,
|
||||||
|
openAccordions: system.openAccordions,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const OptionsAccordion = () => {
|
||||||
|
const {
|
||||||
|
shouldRunESRGAN,
|
||||||
|
shouldRunGFPGAN,
|
||||||
|
shouldUseInitImage,
|
||||||
|
initialImagePath,
|
||||||
|
} = useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
|
||||||
|
useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Accordion
|
||||||
|
defaultIndex={openAccordions}
|
||||||
|
allowMultiple
|
||||||
|
reduceMotion
|
||||||
|
onChange={(openAccordions) =>
|
||||||
|
dispatch(setOpenAccordions(openAccordions))
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Box flex='1' textAlign='left'>
|
||||||
|
Seed & Variation
|
||||||
|
</Box>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<SeedVariationOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Box flex='1' textAlign='left'>
|
||||||
|
Sampler
|
||||||
|
</Box>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<SamplerOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Flex
|
||||||
|
justifyContent={'space-between'}
|
||||||
|
alignItems={'center'}
|
||||||
|
width={'100%'}
|
||||||
|
mr={2}
|
||||||
|
>
|
||||||
|
<Text>Upscale (ESRGAN)</Text>
|
||||||
|
<Switch
|
||||||
|
isDisabled={!isESRGANAvailable}
|
||||||
|
isChecked={shouldRunESRGAN}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(
|
||||||
|
setShouldRunESRGAN(e.target.checked)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<ESRGANOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Flex
|
||||||
|
justifyContent={'space-between'}
|
||||||
|
alignItems={'center'}
|
||||||
|
width={'100%'}
|
||||||
|
mr={2}
|
||||||
|
>
|
||||||
|
<Text>Fix Faces (GFPGAN)</Text>
|
||||||
|
<Switch
|
||||||
|
isDisabled={!isGFPGANAvailable}
|
||||||
|
isChecked={shouldRunGFPGAN}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(
|
||||||
|
setShouldRunGFPGAN(e.target.checked)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<GFPGANOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Flex
|
||||||
|
justifyContent={'space-between'}
|
||||||
|
alignItems={'center'}
|
||||||
|
width={'100%'}
|
||||||
|
mr={2}
|
||||||
|
>
|
||||||
|
<Text>Image to Image</Text>
|
||||||
|
<Switch
|
||||||
|
isDisabled={!initialImagePath}
|
||||||
|
isChecked={shouldUseInitImage}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(
|
||||||
|
setShouldUseInitImage(e.target.checked)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<ImageToImageOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Box flex='1' textAlign='left'>
|
||||||
|
Output
|
||||||
|
</Box>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<OutputOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
</Accordion>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default OptionsAccordion;
|
66
frontend/src/features/sd/OutputOptions.tsx
Normal file
66
frontend/src/features/sd/OutputOptions.tsx
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
|
||||||
|
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice';
|
||||||
|
|
||||||
|
import SDSelect from '../../components/SDSelect';
|
||||||
|
|
||||||
|
import { HEIGHTS, WIDTHS } from '../../app/constants';
|
||||||
|
import SDSwitch from '../../components/SDSwitch';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
height: sd.height,
|
||||||
|
width: sd.width,
|
||||||
|
seamless: sd.seamless,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const OutputOptions = () => {
|
||||||
|
const { height, width, seamless } = useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} direction={'column'}>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<SDSelect
|
||||||
|
label='Width'
|
||||||
|
value={width}
|
||||||
|
flexGrow={1}
|
||||||
|
onChange={(e) => dispatch(setWidth(Number(e.target.value)))}
|
||||||
|
validValues={WIDTHS}
|
||||||
|
/>
|
||||||
|
<SDSelect
|
||||||
|
label='Height'
|
||||||
|
value={height}
|
||||||
|
flexGrow={1}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(setHeight(Number(e.target.value)))
|
||||||
|
}
|
||||||
|
validValues={HEIGHTS}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<SDSwitch
|
||||||
|
label='Seamless tiling'
|
||||||
|
fontSize={'md'}
|
||||||
|
isChecked={seamless}
|
||||||
|
onChange={(e) => dispatch(setSeamless(e.target.checked))}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default OutputOptions;
|
58
frontend/src/features/sd/ProcessButtons.tsx
Normal file
58
frontend/src/features/sd/ProcessButtons.tsx
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { cancelProcessing, generateImage } from '../../app/socketio';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SDButton from '../../components/SDButton';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
import useCheckParameters from '../system/useCheckParameters';
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isProcessing: system.isProcessing,
|
||||||
|
isConnected: system.isConnected,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const ProcessButtons = () => {
|
||||||
|
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const isReady = useCheckParameters();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} direction={'column'} alignItems={'space-between'} height={'100%'}>
|
||||||
|
<SDButton
|
||||||
|
label='Generate'
|
||||||
|
type='submit'
|
||||||
|
colorScheme='green'
|
||||||
|
flexGrow={1}
|
||||||
|
isDisabled={!isReady}
|
||||||
|
fontSize={'md'}
|
||||||
|
size={'md'}
|
||||||
|
onClick={() => dispatch(generateImage())}
|
||||||
|
/>
|
||||||
|
<SDButton
|
||||||
|
label='Cancel'
|
||||||
|
colorScheme='red'
|
||||||
|
flexGrow={1}
|
||||||
|
fontSize={'md'}
|
||||||
|
size={'md'}
|
||||||
|
isDisabled={!isConnected || !isProcessing}
|
||||||
|
onClick={() => dispatch(cancelProcessing())}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ProcessButtons;
|
25
frontend/src/features/sd/PromptInput.tsx
Normal file
25
frontend/src/features/sd/PromptInput.tsx
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import { Textarea } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { setPrompt } from '../sd/sdSlice';
|
||||||
|
|
||||||
|
const PromptInput = () => {
|
||||||
|
const { prompt } = useAppSelector((state: RootState) => state.sd);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Textarea
|
||||||
|
id='prompt'
|
||||||
|
name='prompt'
|
||||||
|
resize='none'
|
||||||
|
size={'lg'}
|
||||||
|
height={'100%'}
|
||||||
|
isInvalid={!prompt.length}
|
||||||
|
onChange={(e) => dispatch(setPrompt(e.target.value))}
|
||||||
|
value={prompt}
|
||||||
|
placeholder="I'm dreaming of..."
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default PromptInput;
|
51
frontend/src/features/sd/SDSlider.tsx
Normal file
51
frontend/src/features/sd/SDSlider.tsx
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import {
|
||||||
|
Slider,
|
||||||
|
SliderTrack,
|
||||||
|
SliderFilledTrack,
|
||||||
|
SliderThumb,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Text,
|
||||||
|
Flex,
|
||||||
|
SliderProps,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
interface Props extends SliderProps {
|
||||||
|
label: string;
|
||||||
|
value: number;
|
||||||
|
fontSize?: number | string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const SDSlider = ({
|
||||||
|
label,
|
||||||
|
value,
|
||||||
|
fontSize = 'sm',
|
||||||
|
onChange,
|
||||||
|
...rest
|
||||||
|
}: Props) => {
|
||||||
|
return (
|
||||||
|
<FormControl>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
||||||
|
<Text fontSize={fontSize} whiteSpace='nowrap'>
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
</FormLabel>
|
||||||
|
<Slider
|
||||||
|
aria-label={label}
|
||||||
|
focusThumbOnChange={true}
|
||||||
|
value={value}
|
||||||
|
onChange={onChange}
|
||||||
|
{...rest}
|
||||||
|
>
|
||||||
|
<SliderTrack>
|
||||||
|
<SliderFilledTrack />
|
||||||
|
</SliderTrack>
|
||||||
|
<SliderThumb />
|
||||||
|
</Slider>
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SDSlider;
|
62
frontend/src/features/sd/SamplerOptions.tsx
Normal file
62
frontend/src/features/sd/SamplerOptions.tsx
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
|
||||||
|
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice';
|
||||||
|
|
||||||
|
import SDNumberInput from '../../components/SDNumberInput';
|
||||||
|
import SDSelect from '../../components/SDSelect';
|
||||||
|
|
||||||
|
import { SAMPLERS } from '../../app/constants';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
steps: sd.steps,
|
||||||
|
cfgScale: sd.cfgScale,
|
||||||
|
sampler: sd.sampler,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const SamplerOptions = () => {
|
||||||
|
const { steps, cfgScale, sampler } = useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} direction={'column'}>
|
||||||
|
<SDNumberInput
|
||||||
|
label='Steps'
|
||||||
|
min={1}
|
||||||
|
step={1}
|
||||||
|
precision={0}
|
||||||
|
onChange={(v) => dispatch(setSteps(Number(v)))}
|
||||||
|
value={steps}
|
||||||
|
/>
|
||||||
|
<SDNumberInput
|
||||||
|
label='CFG scale'
|
||||||
|
step={0.5}
|
||||||
|
onChange={(v) => dispatch(setCfgScale(Number(v)))}
|
||||||
|
value={cfgScale}
|
||||||
|
/>
|
||||||
|
<SDSelect
|
||||||
|
label='Sampler'
|
||||||
|
value={sampler}
|
||||||
|
onChange={(e) => dispatch(setSampler(e.target.value))}
|
||||||
|
validValues={SAMPLERS}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SamplerOptions;
|
144
frontend/src/features/sd/SeedVariationOptions.tsx
Normal file
144
frontend/src/features/sd/SeedVariationOptions.tsx
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
Input,
|
||||||
|
HStack,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Text,
|
||||||
|
Button,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SDNumberInput from '../../components/SDNumberInput';
|
||||||
|
import SDSwitch from '../../components/SDSwitch';
|
||||||
|
import {
|
||||||
|
randomizeSeed,
|
||||||
|
SDState,
|
||||||
|
setIterations,
|
||||||
|
setSeed,
|
||||||
|
setSeedWeights,
|
||||||
|
setShouldGenerateVariations,
|
||||||
|
setShouldRandomizeSeed,
|
||||||
|
setVariantAmount,
|
||||||
|
} from './sdSlice';
|
||||||
|
import { validateSeedWeights } from './util/seedWeightPairs';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
variantAmount: sd.variantAmount,
|
||||||
|
seedWeights: sd.seedWeights,
|
||||||
|
shouldGenerateVariations: sd.shouldGenerateVariations,
|
||||||
|
shouldRandomizeSeed: sd.shouldRandomizeSeed,
|
||||||
|
seed: sd.seed,
|
||||||
|
iterations: sd.iterations,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const SeedVariationOptions = () => {
|
||||||
|
const {
|
||||||
|
shouldGenerateVariations,
|
||||||
|
variantAmount,
|
||||||
|
seedWeights,
|
||||||
|
shouldRandomizeSeed,
|
||||||
|
seed,
|
||||||
|
iterations,
|
||||||
|
} = useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} direction={'column'}>
|
||||||
|
<SDNumberInput
|
||||||
|
label='Images to generate'
|
||||||
|
step={1}
|
||||||
|
min={1}
|
||||||
|
precision={0}
|
||||||
|
onChange={(v) => dispatch(setIterations(Number(v)))}
|
||||||
|
value={iterations}
|
||||||
|
/>
|
||||||
|
<SDSwitch
|
||||||
|
label='Randomize seed on generation'
|
||||||
|
isChecked={shouldRandomizeSeed}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(setShouldRandomizeSeed(e.target.checked))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<SDNumberInput
|
||||||
|
label='Seed'
|
||||||
|
step={1}
|
||||||
|
precision={0}
|
||||||
|
flexGrow={1}
|
||||||
|
min={NUMPY_RAND_MIN}
|
||||||
|
max={NUMPY_RAND_MAX}
|
||||||
|
isDisabled={shouldRandomizeSeed}
|
||||||
|
isInvalid={seed < 0 && shouldGenerateVariations}
|
||||||
|
onChange={(v) => dispatch(setSeed(Number(v)))}
|
||||||
|
value={seed}
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
size={'sm'}
|
||||||
|
isDisabled={shouldRandomizeSeed}
|
||||||
|
onClick={() => dispatch(randomizeSeed())}
|
||||||
|
>
|
||||||
|
<Text pl={2} pr={2}>
|
||||||
|
Shuffle
|
||||||
|
</Text>
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
|
<SDSwitch
|
||||||
|
label='Generate variations'
|
||||||
|
isChecked={shouldGenerateVariations}
|
||||||
|
width={'auto'}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(setShouldGenerateVariations(e.target.checked))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<SDNumberInput
|
||||||
|
label='Variation amount'
|
||||||
|
value={variantAmount}
|
||||||
|
step={0.01}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
isDisabled={!shouldGenerateVariations}
|
||||||
|
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
|
||||||
|
/>
|
||||||
|
<FormControl
|
||||||
|
isInvalid={
|
||||||
|
shouldGenerateVariations &&
|
||||||
|
!(validateSeedWeights(seedWeights) || seedWeights === '')
|
||||||
|
}
|
||||||
|
flexGrow={1}
|
||||||
|
isDisabled={!shouldGenerateVariations}
|
||||||
|
>
|
||||||
|
<HStack>
|
||||||
|
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
||||||
|
<Text whiteSpace='nowrap'>
|
||||||
|
Seed Weights
|
||||||
|
</Text>
|
||||||
|
</FormLabel>
|
||||||
|
<Input
|
||||||
|
size={'sm'}
|
||||||
|
value={seedWeights}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(setSeedWeights(e.target.value))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SeedVariationOptions;
|
92
frontend/src/features/sd/Variant.tsx
Normal file
92
frontend/src/features/sd/Variant.tsx
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
HStack,
|
||||||
|
Input,
|
||||||
|
Text,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SDNumberInput from '../../components/SDNumberInput';
|
||||||
|
import SDSwitch from '../../components/SDSwitch';
|
||||||
|
import {
|
||||||
|
SDState,
|
||||||
|
setSeedWeights,
|
||||||
|
setShouldGenerateVariations,
|
||||||
|
setVariantAmount,
|
||||||
|
} from './sdSlice';
|
||||||
|
import { validateSeedWeights } from './util/seedWeightPairs';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
variantAmount: sd.variantAmount,
|
||||||
|
seedWeights: sd.seedWeights,
|
||||||
|
shouldGenerateVariations: sd.shouldGenerateVariations,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const Variant = () => {
|
||||||
|
const { shouldGenerateVariations, variantAmount, seedWeights } =
|
||||||
|
useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} alignItems={'center'} pl={1}>
|
||||||
|
<SDSwitch
|
||||||
|
label='Generate variations'
|
||||||
|
isChecked={shouldGenerateVariations}
|
||||||
|
width={'auto'}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(setShouldGenerateVariations(e.target.checked))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<SDNumberInput
|
||||||
|
label='Amount'
|
||||||
|
value={variantAmount}
|
||||||
|
step={0.01}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
width={240}
|
||||||
|
isDisabled={!shouldGenerateVariations}
|
||||||
|
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
|
||||||
|
/>
|
||||||
|
<FormControl
|
||||||
|
isInvalid={
|
||||||
|
shouldGenerateVariations &&
|
||||||
|
!(validateSeedWeights(seedWeights) || seedWeights === '')
|
||||||
|
}
|
||||||
|
flexGrow={1}
|
||||||
|
isDisabled={!shouldGenerateVariations}
|
||||||
|
>
|
||||||
|
<HStack>
|
||||||
|
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
||||||
|
<Text fontSize={'sm'} whiteSpace='nowrap'>
|
||||||
|
Seed Weights
|
||||||
|
</Text>
|
||||||
|
</FormLabel>
|
||||||
|
<Input
|
||||||
|
size={'sm'}
|
||||||
|
value={seedWeights}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(setSeedWeights(e.target.value))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default Variant;
|
283
frontend/src/features/sd/sdSlice.ts
Normal file
283
frontend/src/features/sd/sdSlice.ts
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
|
import { SDMetadata } from '../gallery/gallerySlice';
|
||||||
|
import randomInt from './util/randomInt';
|
||||||
|
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
|
||||||
|
|
||||||
|
const calculateRealSteps = (
|
||||||
|
steps: number,
|
||||||
|
strength: number,
|
||||||
|
hasInitImage: boolean
|
||||||
|
): number => {
|
||||||
|
return hasInitImage ? Math.floor(strength * steps) : steps;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type UpscalingLevel = 0 | 2 | 3 | 4;
|
||||||
|
|
||||||
|
export interface SDState {
|
||||||
|
prompt: string;
|
||||||
|
iterations: number;
|
||||||
|
steps: number;
|
||||||
|
realSteps: number;
|
||||||
|
cfgScale: number;
|
||||||
|
height: number;
|
||||||
|
width: number;
|
||||||
|
sampler: string;
|
||||||
|
seed: number;
|
||||||
|
img2imgStrength: number;
|
||||||
|
gfpganStrength: number;
|
||||||
|
upscalingLevel: UpscalingLevel;
|
||||||
|
upscalingStrength: number;
|
||||||
|
shouldUseInitImage: boolean;
|
||||||
|
initialImagePath: string;
|
||||||
|
maskPath: string;
|
||||||
|
seamless: boolean;
|
||||||
|
shouldFitToWidthHeight: boolean;
|
||||||
|
shouldGenerateVariations: boolean;
|
||||||
|
variantAmount: number;
|
||||||
|
seedWeights: string;
|
||||||
|
shouldRunESRGAN: boolean;
|
||||||
|
shouldRunGFPGAN: boolean;
|
||||||
|
shouldRandomizeSeed: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const initialSDState: SDState = {
|
||||||
|
prompt: '',
|
||||||
|
iterations: 1,
|
||||||
|
steps: 50,
|
||||||
|
realSteps: 50,
|
||||||
|
cfgScale: 7.5,
|
||||||
|
height: 512,
|
||||||
|
width: 512,
|
||||||
|
sampler: 'k_lms',
|
||||||
|
seed: 0,
|
||||||
|
seamless: false,
|
||||||
|
shouldUseInitImage: false,
|
||||||
|
img2imgStrength: 0.75,
|
||||||
|
initialImagePath: '',
|
||||||
|
maskPath: '',
|
||||||
|
shouldFitToWidthHeight: true,
|
||||||
|
shouldGenerateVariations: false,
|
||||||
|
variantAmount: 0.1,
|
||||||
|
seedWeights: '',
|
||||||
|
shouldRunESRGAN: false,
|
||||||
|
upscalingLevel: 4,
|
||||||
|
upscalingStrength: 0.75,
|
||||||
|
shouldRunGFPGAN: false,
|
||||||
|
gfpganStrength: 0.8,
|
||||||
|
shouldRandomizeSeed: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
const initialState: SDState = initialSDState;
|
||||||
|
|
||||||
|
export const sdSlice = createSlice({
|
||||||
|
name: 'sd',
|
||||||
|
initialState,
|
||||||
|
reducers: {
|
||||||
|
setPrompt: (state, action: PayloadAction<string>) => {
|
||||||
|
state.prompt = action.payload;
|
||||||
|
},
|
||||||
|
setIterations: (state, action: PayloadAction<number>) => {
|
||||||
|
state.iterations = action.payload;
|
||||||
|
},
|
||||||
|
setSteps: (state, action: PayloadAction<number>) => {
|
||||||
|
const { img2imgStrength, initialImagePath } = state;
|
||||||
|
const steps = action.payload;
|
||||||
|
state.steps = steps;
|
||||||
|
state.realSteps = calculateRealSteps(
|
||||||
|
steps,
|
||||||
|
img2imgStrength,
|
||||||
|
Boolean(initialImagePath)
|
||||||
|
);
|
||||||
|
},
|
||||||
|
setCfgScale: (state, action: PayloadAction<number>) => {
|
||||||
|
state.cfgScale = action.payload;
|
||||||
|
},
|
||||||
|
setHeight: (state, action: PayloadAction<number>) => {
|
||||||
|
state.height = action.payload;
|
||||||
|
},
|
||||||
|
setWidth: (state, action: PayloadAction<number>) => {
|
||||||
|
state.width = action.payload;
|
||||||
|
},
|
||||||
|
setSampler: (state, action: PayloadAction<string>) => {
|
||||||
|
state.sampler = action.payload;
|
||||||
|
},
|
||||||
|
setSeed: (state, action: PayloadAction<number>) => {
|
||||||
|
state.seed = action.payload;
|
||||||
|
state.shouldRandomizeSeed = false;
|
||||||
|
},
|
||||||
|
setImg2imgStrength: (state, action: PayloadAction<number>) => {
|
||||||
|
const img2imgStrength = action.payload;
|
||||||
|
const { steps, initialImagePath } = state;
|
||||||
|
state.img2imgStrength = img2imgStrength;
|
||||||
|
state.realSteps = calculateRealSteps(
|
||||||
|
steps,
|
||||||
|
img2imgStrength,
|
||||||
|
Boolean(initialImagePath)
|
||||||
|
);
|
||||||
|
},
|
||||||
|
setGfpganStrength: (state, action: PayloadAction<number>) => {
|
||||||
|
state.gfpganStrength = action.payload;
|
||||||
|
},
|
||||||
|
setUpscalingLevel: (state, action: PayloadAction<UpscalingLevel>) => {
|
||||||
|
state.upscalingLevel = action.payload;
|
||||||
|
},
|
||||||
|
setUpscalingStrength: (state, action: PayloadAction<number>) => {
|
||||||
|
state.upscalingStrength = action.payload;
|
||||||
|
},
|
||||||
|
setShouldUseInitImage: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldUseInitImage = action.payload;
|
||||||
|
},
|
||||||
|
setInitialImagePath: (state, action: PayloadAction<string>) => {
|
||||||
|
const initialImagePath = action.payload;
|
||||||
|
const { steps, img2imgStrength } = state;
|
||||||
|
state.shouldUseInitImage = initialImagePath ? true : false;
|
||||||
|
state.initialImagePath = initialImagePath;
|
||||||
|
state.realSteps = calculateRealSteps(
|
||||||
|
steps,
|
||||||
|
img2imgStrength,
|
||||||
|
Boolean(initialImagePath)
|
||||||
|
);
|
||||||
|
},
|
||||||
|
setMaskPath: (state, action: PayloadAction<string>) => {
|
||||||
|
state.maskPath = action.payload;
|
||||||
|
},
|
||||||
|
setSeamless: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.seamless = action.payload;
|
||||||
|
},
|
||||||
|
setShouldFitToWidthHeight: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldFitToWidthHeight = action.payload;
|
||||||
|
},
|
||||||
|
resetSeed: (state) => {
|
||||||
|
state.seed = -1;
|
||||||
|
},
|
||||||
|
randomizeSeed: (state) => {
|
||||||
|
state.seed = randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX);
|
||||||
|
},
|
||||||
|
setParameter: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{ key: string; value: string | number | boolean }>
|
||||||
|
) => {
|
||||||
|
const { key, value } = action.payload;
|
||||||
|
const temp = { ...state, [key]: value };
|
||||||
|
if (key === 'seed') {
|
||||||
|
temp.shouldRandomizeSeed = false;
|
||||||
|
}
|
||||||
|
if (key === 'initialImagePath' && value === '') {
|
||||||
|
temp.shouldUseInitImage = false;
|
||||||
|
}
|
||||||
|
return temp;
|
||||||
|
},
|
||||||
|
setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldGenerateVariations = action.payload;
|
||||||
|
},
|
||||||
|
setVariantAmount: (state, action: PayloadAction<number>) => {
|
||||||
|
state.variantAmount = action.payload;
|
||||||
|
},
|
||||||
|
setSeedWeights: (state, action: PayloadAction<string>) => {
|
||||||
|
state.seedWeights = action.payload;
|
||||||
|
},
|
||||||
|
setAllParameters: (state, action: PayloadAction<SDMetadata>) => {
|
||||||
|
const {
|
||||||
|
prompt,
|
||||||
|
steps,
|
||||||
|
cfgScale,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
sampler,
|
||||||
|
seed,
|
||||||
|
img2imgStrength,
|
||||||
|
gfpganStrength,
|
||||||
|
upscalingLevel,
|
||||||
|
upscalingStrength,
|
||||||
|
initialImagePath,
|
||||||
|
maskPath,
|
||||||
|
seamless,
|
||||||
|
shouldFitToWidthHeight,
|
||||||
|
} = action.payload;
|
||||||
|
|
||||||
|
// ?? = falsy values ('', 0, etc) are used
|
||||||
|
// || = falsy values not used
|
||||||
|
state.prompt = prompt ?? state.prompt;
|
||||||
|
state.steps = steps || state.steps;
|
||||||
|
state.cfgScale = cfgScale || state.cfgScale;
|
||||||
|
state.width = width || state.width;
|
||||||
|
state.height = height || state.height;
|
||||||
|
state.sampler = sampler || state.sampler;
|
||||||
|
state.seed = seed ?? state.seed;
|
||||||
|
state.seamless = seamless ?? state.seamless;
|
||||||
|
state.shouldFitToWidthHeight =
|
||||||
|
shouldFitToWidthHeight ?? state.shouldFitToWidthHeight;
|
||||||
|
state.img2imgStrength = img2imgStrength ?? state.img2imgStrength;
|
||||||
|
state.gfpganStrength = gfpganStrength ?? state.gfpganStrength;
|
||||||
|
state.upscalingLevel = upscalingLevel ?? state.upscalingLevel;
|
||||||
|
state.upscalingStrength = upscalingStrength ?? state.upscalingStrength;
|
||||||
|
state.initialImagePath = initialImagePath ?? state.initialImagePath;
|
||||||
|
state.maskPath = maskPath ?? state.maskPath;
|
||||||
|
|
||||||
|
// If the image whose parameters we are using has a seed, disable randomizing the seed
|
||||||
|
if (seed) {
|
||||||
|
state.shouldRandomizeSeed = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we have a gfpgan strength, enable it
|
||||||
|
state.shouldRunGFPGAN = gfpganStrength ? true : false;
|
||||||
|
|
||||||
|
// if we have a esrgan strength, enable it
|
||||||
|
state.shouldRunESRGAN = upscalingLevel ? true : false;
|
||||||
|
|
||||||
|
// if we want to recreate an image exactly, we disable variations
|
||||||
|
state.shouldGenerateVariations = false;
|
||||||
|
|
||||||
|
state.shouldUseInitImage = initialImagePath ? true : false;
|
||||||
|
},
|
||||||
|
resetSDState: (state) => {
|
||||||
|
return {
|
||||||
|
...state,
|
||||||
|
...initialSDState,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
setShouldRunGFPGAN: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldRunGFPGAN = action.payload;
|
||||||
|
},
|
||||||
|
setShouldRunESRGAN: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldRunESRGAN = action.payload;
|
||||||
|
},
|
||||||
|
setShouldRandomizeSeed: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldRandomizeSeed = action.payload;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const {
|
||||||
|
setPrompt,
|
||||||
|
setIterations,
|
||||||
|
setSteps,
|
||||||
|
setCfgScale,
|
||||||
|
setHeight,
|
||||||
|
setWidth,
|
||||||
|
setSampler,
|
||||||
|
setSeed,
|
||||||
|
setSeamless,
|
||||||
|
setImg2imgStrength,
|
||||||
|
setGfpganStrength,
|
||||||
|
setUpscalingLevel,
|
||||||
|
setUpscalingStrength,
|
||||||
|
setShouldUseInitImage,
|
||||||
|
setInitialImagePath,
|
||||||
|
setMaskPath,
|
||||||
|
resetSeed,
|
||||||
|
randomizeSeed,
|
||||||
|
resetSDState,
|
||||||
|
setShouldFitToWidthHeight,
|
||||||
|
setParameter,
|
||||||
|
setShouldGenerateVariations,
|
||||||
|
setSeedWeights,
|
||||||
|
setVariantAmount,
|
||||||
|
setAllParameters,
|
||||||
|
setShouldRunGFPGAN,
|
||||||
|
setShouldRunESRGAN,
|
||||||
|
setShouldRandomizeSeed,
|
||||||
|
} = sdSlice.actions;
|
||||||
|
|
||||||
|
export default sdSlice.reducer;
|
5
frontend/src/features/sd/util/randomInt.ts
Normal file
5
frontend/src/features/sd/util/randomInt.ts
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
const randomInt = (min: number, max: number): number => {
|
||||||
|
return Math.floor(Math.random() * (max - min + 1) + min);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default randomInt;
|
56
frontend/src/features/sd/util/seedWeightPairs.ts
Normal file
56
frontend/src/features/sd/util/seedWeightPairs.ts
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
export interface SeedWeightPair {
|
||||||
|
seed: number;
|
||||||
|
weight: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type SeedWeights = Array<Array<number>>;
|
||||||
|
|
||||||
|
export const stringToSeedWeights = (string: string): SeedWeights | boolean => {
|
||||||
|
const stringPairs = string.split(',');
|
||||||
|
const arrPairs = stringPairs.map((p) => p.split(':'));
|
||||||
|
const pairs = arrPairs.map((p) => {
|
||||||
|
return [parseInt(p[0]), parseFloat(p[1])];
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!validateSeedWeights(pairs)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return pairs;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const validateSeedWeights = (
|
||||||
|
seedWeights: SeedWeights | string
|
||||||
|
): boolean => {
|
||||||
|
return typeof seedWeights === 'string'
|
||||||
|
? Boolean(stringToSeedWeights(seedWeights))
|
||||||
|
: Boolean(
|
||||||
|
seedWeights.length &&
|
||||||
|
!seedWeights.some((pair) => {
|
||||||
|
const [seed, weight] = pair;
|
||||||
|
const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
|
||||||
|
const isWeightValid =
|
||||||
|
!isNaN(parseInt(weight.toString(), 10)) &&
|
||||||
|
weight >= 0 &&
|
||||||
|
weight <= 1;
|
||||||
|
return !(isSeedValid && isWeightValid);
|
||||||
|
})
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const seedWeightsToString = (
|
||||||
|
seedWeights: SeedWeights
|
||||||
|
): string | boolean => {
|
||||||
|
if (!validateSeedWeights(seedWeights)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return seedWeights.reduce((acc, pair, i, arr) => {
|
||||||
|
const [seed, weight] = pair;
|
||||||
|
acc += `${seed}:${weight}`;
|
||||||
|
if (i !== arr.length - 1) {
|
||||||
|
acc += ',';
|
||||||
|
}
|
||||||
|
return acc;
|
||||||
|
}, '');
|
||||||
|
};
|
125
frontend/src/features/system/LogViewer.tsx
Normal file
125
frontend/src/features/system/LogViewer.tsx
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
import {
|
||||||
|
IconButton,
|
||||||
|
useColorModeValue,
|
||||||
|
Flex,
|
||||||
|
Text,
|
||||||
|
Tooltip,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { setShouldShowLogViewer, SystemState } from './systemSlice';
|
||||||
|
import { useLayoutEffect, useRef, useState } from 'react';
|
||||||
|
import { FaAngleDoubleDown, FaCode, FaMinus } from 'react-icons/fa';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
|
const logSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => system.log,
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: (a, b) => a.length === b.length,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return { shouldShowLogViewer: system.shouldShowLogViewer };
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const LogViewer = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const bg = useColorModeValue('gray.50', 'gray.900');
|
||||||
|
const borderColor = useColorModeValue('gray.500', 'gray.500');
|
||||||
|
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
|
||||||
|
|
||||||
|
const log = useAppSelector(logSelector);
|
||||||
|
const { shouldShowLogViewer } = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const viewerRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
if (viewerRef.current !== null && shouldAutoscroll) {
|
||||||
|
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{shouldShowLogViewer && (
|
||||||
|
<Flex
|
||||||
|
position={'fixed'}
|
||||||
|
left={0}
|
||||||
|
bottom={0}
|
||||||
|
height='200px'
|
||||||
|
width='100vw'
|
||||||
|
overflow='auto'
|
||||||
|
direction='column'
|
||||||
|
fontFamily='monospace'
|
||||||
|
fontSize='sm'
|
||||||
|
pl={12}
|
||||||
|
pr={2}
|
||||||
|
pb={2}
|
||||||
|
borderTopWidth='4px'
|
||||||
|
borderColor={borderColor}
|
||||||
|
background={bg}
|
||||||
|
ref={viewerRef}
|
||||||
|
>
|
||||||
|
{log.map((entry, i) => (
|
||||||
|
<Flex gap={2} key={i}>
|
||||||
|
<Text fontSize='sm' fontWeight={'semibold'}>
|
||||||
|
{entry.timestamp}:
|
||||||
|
</Text>
|
||||||
|
<Text fontSize='sm' wordBreak={'break-all'}>
|
||||||
|
{entry.message}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
{shouldShowLogViewer && (
|
||||||
|
<Tooltip
|
||||||
|
label={
|
||||||
|
shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<IconButton
|
||||||
|
size='sm'
|
||||||
|
position={'fixed'}
|
||||||
|
left={2}
|
||||||
|
bottom={12}
|
||||||
|
aria-label='Toggle autoscroll'
|
||||||
|
variant={'solid'}
|
||||||
|
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
|
||||||
|
icon={<FaAngleDoubleDown />}
|
||||||
|
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
|
||||||
|
<IconButton
|
||||||
|
size='sm'
|
||||||
|
position={'fixed'}
|
||||||
|
left={2}
|
||||||
|
bottom={2}
|
||||||
|
variant={'solid'}
|
||||||
|
aria-label='Toggle Log Viewer'
|
||||||
|
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch(setShouldShowLogViewer(!shouldShowLogViewer))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default LogViewer;
|
170
frontend/src/features/system/SettingsModal.tsx
Normal file
170
frontend/src/features/system/SettingsModal.tsx
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Heading,
|
||||||
|
HStack,
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalFooter,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
Switch,
|
||||||
|
Text,
|
||||||
|
useDisclosure,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
||||||
|
import {
|
||||||
|
setShouldConfirmOnDelete,
|
||||||
|
setShouldDisplayInProgress,
|
||||||
|
SystemState,
|
||||||
|
} from './systemSlice';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SDButton from '../../components/SDButton';
|
||||||
|
import { persistor } from '../../main';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { cloneElement, ReactElement } from 'react';
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
|
||||||
|
return { shouldDisplayInProgress, shouldConfirmOnDelete };
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: { resultEqualityCheck: isEqual },
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
children: ReactElement;
|
||||||
|
};
|
||||||
|
|
||||||
|
const SettingsModal = ({ children }: Props) => {
|
||||||
|
const {
|
||||||
|
isOpen: isSettingsModalOpen,
|
||||||
|
onOpen: onSettingsModalOpen,
|
||||||
|
onClose: onSettingsModalClose,
|
||||||
|
} = useDisclosure();
|
||||||
|
|
||||||
|
const {
|
||||||
|
isOpen: isRefreshModalOpen,
|
||||||
|
onOpen: onRefreshModalOpen,
|
||||||
|
onClose: onRefreshModalClose,
|
||||||
|
} = useDisclosure();
|
||||||
|
|
||||||
|
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
|
||||||
|
useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleClickResetWebUI = () => {
|
||||||
|
persistor.purge().then(() => {
|
||||||
|
onSettingsModalClose();
|
||||||
|
onRefreshModalOpen();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{cloneElement(children, {
|
||||||
|
onClick: onSettingsModalOpen,
|
||||||
|
})}
|
||||||
|
|
||||||
|
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent>
|
||||||
|
<ModalHeader>Settings</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody>
|
||||||
|
<Flex gap={5} direction='column'>
|
||||||
|
<FormControl>
|
||||||
|
<HStack>
|
||||||
|
<FormLabel marginBottom={1}>
|
||||||
|
Display in-progress images (slower)
|
||||||
|
</FormLabel>
|
||||||
|
<Switch
|
||||||
|
isChecked={shouldDisplayInProgress}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(
|
||||||
|
setShouldDisplayInProgress(
|
||||||
|
e.target.checked
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
</FormControl>
|
||||||
|
<FormControl>
|
||||||
|
<HStack>
|
||||||
|
<FormLabel marginBottom={1}>
|
||||||
|
Confirm on delete
|
||||||
|
</FormLabel>
|
||||||
|
<Switch
|
||||||
|
isChecked={shouldConfirmOnDelete}
|
||||||
|
onChange={(e) =>
|
||||||
|
dispatch(
|
||||||
|
setShouldConfirmOnDelete(
|
||||||
|
e.target.checked
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
<Heading size={'md'}>Reset Web UI</Heading>
|
||||||
|
<Text>
|
||||||
|
Resetting the web UI only resets the browser's
|
||||||
|
local cache of your images and remembered
|
||||||
|
settings. It does not delete any images from
|
||||||
|
disk.
|
||||||
|
</Text>
|
||||||
|
<Text>
|
||||||
|
If images aren't showing up in the gallery or
|
||||||
|
something else isn't working, please try
|
||||||
|
resetting before submitting an issue on GitHub.
|
||||||
|
</Text>
|
||||||
|
<SDButton
|
||||||
|
label='Reset Web UI'
|
||||||
|
colorScheme='red'
|
||||||
|
onClick={handleClickResetWebUI}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</ModalBody>
|
||||||
|
|
||||||
|
<ModalFooter>
|
||||||
|
<SDButton
|
||||||
|
label='Close'
|
||||||
|
onClick={onSettingsModalClose}
|
||||||
|
/>
|
||||||
|
</ModalFooter>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
|
||||||
|
<Modal
|
||||||
|
closeOnOverlayClick={false}
|
||||||
|
isOpen={isRefreshModalOpen}
|
||||||
|
onClose={onRefreshModalClose}
|
||||||
|
isCentered
|
||||||
|
>
|
||||||
|
<ModalOverlay bg='blackAlpha.300' backdropFilter='blur(40px)' />
|
||||||
|
<ModalContent>
|
||||||
|
<ModalBody pb={6} pt={6}>
|
||||||
|
<Flex justifyContent={'center'}>
|
||||||
|
<Text fontSize={'lg'}>
|
||||||
|
Web UI has been reset. Refresh the page to
|
||||||
|
reload.
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
</ModalBody>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SettingsModal;
|
98
frontend/src/features/system/systemSlice.ts
Normal file
98
frontend/src/features/system/systemSlice.ts
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
|
import dateFormat from 'dateformat';
|
||||||
|
import { ExpandedIndex } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
export interface LogEntry {
|
||||||
|
timestamp: string;
|
||||||
|
message: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface Log {
|
||||||
|
[index: number]: LogEntry;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SystemState {
|
||||||
|
shouldDisplayInProgress: boolean;
|
||||||
|
isProcessing: boolean;
|
||||||
|
currentStep: number;
|
||||||
|
log: Array<LogEntry>;
|
||||||
|
shouldShowLogViewer: boolean;
|
||||||
|
isGFPGANAvailable: boolean;
|
||||||
|
isESRGANAvailable: boolean;
|
||||||
|
isConnected: boolean;
|
||||||
|
socketId: string;
|
||||||
|
shouldConfirmOnDelete: boolean;
|
||||||
|
openAccordions: ExpandedIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
const initialSystemState = {
|
||||||
|
isConnected: false,
|
||||||
|
isProcessing: false,
|
||||||
|
currentStep: 0,
|
||||||
|
log: [],
|
||||||
|
shouldShowLogViewer: false,
|
||||||
|
shouldDisplayInProgress: false,
|
||||||
|
isGFPGANAvailable: true,
|
||||||
|
isESRGANAvailable: true,
|
||||||
|
socketId: '',
|
||||||
|
shouldConfirmOnDelete: true,
|
||||||
|
openAccordions: [0],
|
||||||
|
};
|
||||||
|
|
||||||
|
const initialState: SystemState = initialSystemState;
|
||||||
|
|
||||||
|
export const systemSlice = createSlice({
|
||||||
|
name: 'system',
|
||||||
|
initialState,
|
||||||
|
reducers: {
|
||||||
|
setShouldDisplayInProgress: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldDisplayInProgress = action.payload;
|
||||||
|
},
|
||||||
|
setIsProcessing: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.isProcessing = action.payload;
|
||||||
|
if (action.payload === false) {
|
||||||
|
state.currentStep = 0;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
setCurrentStep: (state, action: PayloadAction<number>) => {
|
||||||
|
state.currentStep = action.payload;
|
||||||
|
},
|
||||||
|
addLogEntry: (state, action: PayloadAction<string>) => {
|
||||||
|
const entry: LogEntry = {
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: action.payload,
|
||||||
|
};
|
||||||
|
state.log.push(entry);
|
||||||
|
},
|
||||||
|
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldShowLogViewer = action.payload;
|
||||||
|
},
|
||||||
|
setIsConnected: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.isConnected = action.payload;
|
||||||
|
},
|
||||||
|
setSocketId: (state, action: PayloadAction<string>) => {
|
||||||
|
state.socketId = action.payload;
|
||||||
|
},
|
||||||
|
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldConfirmOnDelete = action.payload;
|
||||||
|
},
|
||||||
|
setOpenAccordions: (state, action: PayloadAction<ExpandedIndex>) => {
|
||||||
|
state.openAccordions = action.payload;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const {
|
||||||
|
setShouldDisplayInProgress,
|
||||||
|
setIsProcessing,
|
||||||
|
setCurrentStep,
|
||||||
|
addLogEntry,
|
||||||
|
setShouldShowLogViewer,
|
||||||
|
setIsConnected,
|
||||||
|
setSocketId,
|
||||||
|
setShouldConfirmOnDelete,
|
||||||
|
setOpenAccordions,
|
||||||
|
} = systemSlice.actions;
|
||||||
|
|
||||||
|
export default systemSlice.reducer;
|
108
frontend/src/features/system/useCheckParameters.ts
Normal file
108
frontend/src/features/system/useCheckParameters.ts
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { useAppSelector } from '../../app/hooks';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { SDState } from '../sd/sdSlice';
|
||||||
|
import { validateSeedWeights } from '../sd/util/seedWeightPairs';
|
||||||
|
import { SystemState } from './systemSlice';
|
||||||
|
|
||||||
|
const sdSelector = createSelector(
|
||||||
|
(state: RootState) => state.sd,
|
||||||
|
(sd: SDState) => {
|
||||||
|
return {
|
||||||
|
prompt: sd.prompt,
|
||||||
|
shouldGenerateVariations: sd.shouldGenerateVariations,
|
||||||
|
seedWeights: sd.seedWeights,
|
||||||
|
maskPath: sd.maskPath,
|
||||||
|
initialImagePath: sd.initialImagePath,
|
||||||
|
seed: sd.seed,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isProcessing: system.isProcessing,
|
||||||
|
isConnected: system.isConnected,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/*
|
||||||
|
Checks relevant pieces of state to confirm generation will not deterministically fail.
|
||||||
|
|
||||||
|
This is used to prevent the 'Generate' button from being clicked.
|
||||||
|
|
||||||
|
Other parameter values may cause failure but we rely on input validation for those.
|
||||||
|
*/
|
||||||
|
const useCheckParameters = () => {
|
||||||
|
const {
|
||||||
|
prompt,
|
||||||
|
shouldGenerateVariations,
|
||||||
|
seedWeights,
|
||||||
|
maskPath,
|
||||||
|
initialImagePath,
|
||||||
|
seed,
|
||||||
|
} = useAppSelector(sdSelector);
|
||||||
|
|
||||||
|
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
return useMemo(() => {
|
||||||
|
// Cannot generate without a prompt
|
||||||
|
if (!prompt) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot generate with a mask without img2img
|
||||||
|
if (maskPath && !initialImagePath) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: job queue
|
||||||
|
// Cannot generate if already processing an image
|
||||||
|
if (isProcessing) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot generate if not connected
|
||||||
|
if (!isConnected) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot generate variations without valid seed weights
|
||||||
|
if (
|
||||||
|
shouldGenerateVariations &&
|
||||||
|
(!(validateSeedWeights(seedWeights) || seedWeights === '') ||
|
||||||
|
seed === -1)
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// All good
|
||||||
|
return true;
|
||||||
|
}, [
|
||||||
|
prompt,
|
||||||
|
maskPath,
|
||||||
|
initialImagePath,
|
||||||
|
isProcessing,
|
||||||
|
isConnected,
|
||||||
|
shouldGenerateVariations,
|
||||||
|
seedWeights,
|
||||||
|
seed,
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default useCheckParameters;
|
26
frontend/src/main.tsx
Normal file
26
frontend/src/main.tsx
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import ReactDOM from 'react-dom/client';
|
||||||
|
import { ChakraProvider, ColorModeScript } from '@chakra-ui/react';
|
||||||
|
import { store } from './app/store';
|
||||||
|
import { Provider } from 'react-redux';
|
||||||
|
import { PersistGate } from 'redux-persist/integration/react';
|
||||||
|
import { persistStore } from 'redux-persist';
|
||||||
|
|
||||||
|
export const persistor = persistStore(store);
|
||||||
|
|
||||||
|
import App from './App';
|
||||||
|
import { theme } from './app/theme';
|
||||||
|
import Loading from './Loading';
|
||||||
|
|
||||||
|
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
|
||||||
|
<React.StrictMode>
|
||||||
|
<Provider store={store}>
|
||||||
|
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||||
|
<ChakraProvider theme={theme}>
|
||||||
|
<ColorModeScript initialColorMode={theme.config.initialColorMode} />
|
||||||
|
<App />
|
||||||
|
</ChakraProvider>
|
||||||
|
</PersistGate>
|
||||||
|
</Provider>
|
||||||
|
</React.StrictMode>
|
||||||
|
);
|
1
frontend/src/vite-env.d.ts
vendored
Normal file
1
frontend/src/vite-env.d.ts
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/// <reference types="vite/client" />
|
21
frontend/tsconfig.json
Normal file
21
frontend/tsconfig.json
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "ESNext",
|
||||||
|
"useDefineForClassFields": true,
|
||||||
|
"lib": ["DOM", "DOM.Iterable", "ESNext"],
|
||||||
|
"allowJs": false,
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"esModuleInterop": false,
|
||||||
|
"allowSyntheticDefaultImports": true,
|
||||||
|
"strict": true,
|
||||||
|
"forceConsistentCasingInFileNames": true,
|
||||||
|
"module": "ESNext",
|
||||||
|
"moduleResolution": "Node",
|
||||||
|
"resolveJsonModule": true,
|
||||||
|
"isolatedModules": true,
|
||||||
|
"noEmit": true,
|
||||||
|
"jsx": "react-jsx"
|
||||||
|
},
|
||||||
|
"include": ["src", "index.d.ts"],
|
||||||
|
"references": [{ "path": "./tsconfig.node.json" }]
|
||||||
|
}
|
9
frontend/tsconfig.node.json
Normal file
9
frontend/tsconfig.node.json
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"composite": true,
|
||||||
|
"module": "ESNext",
|
||||||
|
"moduleResolution": "Node",
|
||||||
|
"allowSyntheticDefaultImports": true
|
||||||
|
},
|
||||||
|
"include": ["vite.config.ts"]
|
||||||
|
}
|
36
frontend/vite.config.ts
Normal file
36
frontend/vite.config.ts
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import { defineConfig } from 'vite';
|
||||||
|
import react from '@vitejs/plugin-react';
|
||||||
|
import eslint from 'vite-plugin-eslint';
|
||||||
|
|
||||||
|
// https://vitejs.dev/config/
|
||||||
|
export default defineConfig(({ mode }) => {
|
||||||
|
const common = {
|
||||||
|
plugins: [react(), eslint()],
|
||||||
|
server: {
|
||||||
|
proxy: {
|
||||||
|
'/outputs': {
|
||||||
|
target: 'http://localhost:9090/outputs',
|
||||||
|
changeOrigin: true,
|
||||||
|
rewrite: (path) => path.replace(/^\/outputs/, ''),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
build: {
|
||||||
|
target: 'esnext',
|
||||||
|
chunkSizeWarningLimit: 1500, // we don't really care about chunk size
|
||||||
|
},
|
||||||
|
};
|
||||||
|
if (mode == 'development') {
|
||||||
|
return {
|
||||||
|
...common,
|
||||||
|
build: {
|
||||||
|
...common.build,
|
||||||
|
// sourcemap: true, // this can be enabled if needed, it adds ovwer 15MB to the commit
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
...common,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
3149
frontend/yarn.lock
Normal file
3149
frontend/yarn.lock
Normal file
File diff suppressed because it is too large
Load Diff
619
ldm/dream/args.py
Normal file
619
ldm/dream/args.py
Normal file
@ -0,0 +1,619 @@
|
|||||||
|
"""Helper class for dealing with image generation arguments.
|
||||||
|
|
||||||
|
The Args class parses both the command line (shell) arguments, as well as the
|
||||||
|
command string passed at the dream> prompt. It serves as the definitive repository
|
||||||
|
of all the arguments used by Generate and their default values.
|
||||||
|
|
||||||
|
To use:
|
||||||
|
opt = Args()
|
||||||
|
|
||||||
|
# Read in the command line options:
|
||||||
|
# this returns a namespace object like the underlying argparse library)
|
||||||
|
# You do not have to use the return value, but you can check it against None
|
||||||
|
# to detect illegal arguments on the command line.
|
||||||
|
args = opt.parse_args()
|
||||||
|
if not args:
|
||||||
|
print('oops')
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
# read in a command passed to the dream> prompt:
|
||||||
|
opts = opt.parse_cmd('do androids dream of electric sheep? -H256 -W1024 -n4')
|
||||||
|
|
||||||
|
# The Args object acts like a namespace object
|
||||||
|
print(opt.model)
|
||||||
|
|
||||||
|
You can set attributes in the usual way, use vars(), etc.:
|
||||||
|
|
||||||
|
opt.model = 'something-else'
|
||||||
|
do_something(**vars(a))
|
||||||
|
|
||||||
|
It is helpful in saving metadata:
|
||||||
|
|
||||||
|
# To get a json representation of all the values, allowing
|
||||||
|
# you to override any values dynamically
|
||||||
|
j = opt.json(seed=42)
|
||||||
|
|
||||||
|
# To get the prompt string with the switches, allowing you
|
||||||
|
# to override any values dynamically
|
||||||
|
j = opt.dream_prompt_str(seed=42)
|
||||||
|
|
||||||
|
If you want to access the namespace objects from the shell args or the
|
||||||
|
parsed command directly, you may use the values returned from the
|
||||||
|
original calls to parse_args() and parse_cmd(), or get them later
|
||||||
|
using the _arg_switches and _cmd_switches attributes. This can be
|
||||||
|
useful if both the args and the command contain the same attribute and
|
||||||
|
you wish to apply logic as to which one to use. For example:
|
||||||
|
|
||||||
|
a = Args()
|
||||||
|
args = a.parse_args()
|
||||||
|
opts = a.parse_cmd(string)
|
||||||
|
do_grid = args.grid or opts.grid
|
||||||
|
|
||||||
|
To add new attributes, edit the _create_arg_parser() and
|
||||||
|
_create_dream_cmd_parser() methods.
|
||||||
|
|
||||||
|
We also export the function build_metadata
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import shlex
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import copy
|
||||||
|
from ldm.dream.conditioning import split_weighted_subprompts
|
||||||
|
|
||||||
|
SAMPLER_CHOICES = [
|
||||||
|
'ddim',
|
||||||
|
'k_dpm_2_a',
|
||||||
|
'k_dpm_2',
|
||||||
|
'k_euler_a',
|
||||||
|
'k_euler',
|
||||||
|
'k_heun',
|
||||||
|
'k_lms',
|
||||||
|
'plms',
|
||||||
|
]
|
||||||
|
|
||||||
|
# is there a way to pick this up during git commits?
|
||||||
|
APP_ID = 'lstein/stable-diffusion'
|
||||||
|
APP_VERSION = 'v1.15'
|
||||||
|
|
||||||
|
class Args(object):
|
||||||
|
def __init__(self,arg_parser=None,cmd_parser=None):
|
||||||
|
'''
|
||||||
|
Initialize new Args class. It takes two optional arguments, an argparse
|
||||||
|
parser for switches given on the shell command line, and an argparse
|
||||||
|
parser for switches given on the dream> CLI line. If one or both are
|
||||||
|
missing, it creates appropriate parsers internally.
|
||||||
|
'''
|
||||||
|
self._arg_parser = arg_parser or self._create_arg_parser()
|
||||||
|
self._cmd_parser = cmd_parser or self._create_dream_cmd_parser()
|
||||||
|
self._arg_switches = self.parse_cmd('') # fill in defaults
|
||||||
|
self._cmd_switches = self.parse_cmd('') # fill in defaults
|
||||||
|
|
||||||
|
def parse_args(self):
|
||||||
|
'''Parse the shell switches and store.'''
|
||||||
|
try:
|
||||||
|
self._arg_switches = self._arg_parser.parse_args()
|
||||||
|
return self._arg_switches
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def parse_cmd(self,cmd_string):
|
||||||
|
'''Parse a dream>-style command string '''
|
||||||
|
command = cmd_string.replace("'", "\\'")
|
||||||
|
try:
|
||||||
|
elements = shlex.split(command)
|
||||||
|
except ValueError:
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
return
|
||||||
|
switches = ['']
|
||||||
|
switches_started = False
|
||||||
|
|
||||||
|
for element in elements:
|
||||||
|
if element[0] == '-' and not switches_started:
|
||||||
|
switches_started = True
|
||||||
|
if switches_started:
|
||||||
|
switches.append(element)
|
||||||
|
else:
|
||||||
|
switches[0] += element
|
||||||
|
switches[0] += ' '
|
||||||
|
switches[0] = switches[0][: len(switches[0]) - 1]
|
||||||
|
try:
|
||||||
|
self._cmd_switches = self._cmd_parser.parse_args(switches)
|
||||||
|
return self._cmd_switches
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def json(self,**kwargs):
|
||||||
|
return json.dumps(self.to_dict(**kwargs))
|
||||||
|
|
||||||
|
def to_dict(self,**kwargs):
|
||||||
|
a = vars(self)
|
||||||
|
a.update(kwargs)
|
||||||
|
return a
|
||||||
|
|
||||||
|
# Isn't there a more automated way of doing this?
|
||||||
|
# Ideally we get the switch strings out of the argparse objects,
|
||||||
|
# but I don't see a documented API for this.
|
||||||
|
def dream_prompt_str(self,**kwargs):
|
||||||
|
"""Normalized dream_prompt."""
|
||||||
|
a = vars(self)
|
||||||
|
a.update(kwargs)
|
||||||
|
switches = list()
|
||||||
|
switches.append(f'"{a["prompt"]}')
|
||||||
|
switches.append(f'-s {a["steps"]}')
|
||||||
|
switches.append(f'-W {a["width"]}')
|
||||||
|
switches.append(f'-H {a["height"]}')
|
||||||
|
switches.append(f'-C {a["cfg_scale"]}')
|
||||||
|
switches.append(f'-A {a["sampler_name"]}')
|
||||||
|
switches.append(f'-S {a["seed"]}')
|
||||||
|
if a['grid']:
|
||||||
|
switches.append('--grid')
|
||||||
|
if a['iterations'] and a['iterations']>0:
|
||||||
|
switches.append(f'-n {a["iterations"]}')
|
||||||
|
if a['seamless']:
|
||||||
|
switches.append('--seamless')
|
||||||
|
if a['init_img'] and len(a['init_img'])>0:
|
||||||
|
switches.append(f'-I {a["init_img"]}')
|
||||||
|
if a['fit']:
|
||||||
|
switches.append(f'--fit')
|
||||||
|
if a['strength'] and a['strength']>0:
|
||||||
|
switches.append(f'-f {a["strength"]}')
|
||||||
|
if a['gfpgan_strength']:
|
||||||
|
switches.append(f'-G {a["gfpgan_strength"]}')
|
||||||
|
if a['upscale']:
|
||||||
|
switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}')
|
||||||
|
if a['embiggen']:
|
||||||
|
switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}')
|
||||||
|
if a['embiggen_tiles']:
|
||||||
|
switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}')
|
||||||
|
if a['variation_amount'] > 0:
|
||||||
|
switches.append(f'-v {a["variation_amount"]}')
|
||||||
|
if a['with_variations']:
|
||||||
|
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"]))
|
||||||
|
switches.append(f'-V {formatted_variations}')
|
||||||
|
return ' '.join(switches)
|
||||||
|
|
||||||
|
def __getattribute__(self,name):
|
||||||
|
'''
|
||||||
|
Returns union of command-line arguments and dream_prompt arguments,
|
||||||
|
with the latter superseding the former.
|
||||||
|
'''
|
||||||
|
cmd_switches = None
|
||||||
|
arg_switches = None
|
||||||
|
try:
|
||||||
|
cmd_switches = object.__getattribute__(self,'_cmd_switches')
|
||||||
|
arg_switches = object.__getattribute__(self,'_arg_switches')
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if cmd_switches and arg_switches and name=='__dict__':
|
||||||
|
a = arg_switches.__dict__
|
||||||
|
a.update(cmd_switches.__dict__)
|
||||||
|
return a
|
||||||
|
|
||||||
|
try:
|
||||||
|
return object.__getattribute__(self,name)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not hasattr(cmd_switches,name) and not hasattr(arg_switches,name):
|
||||||
|
raise AttributeError
|
||||||
|
|
||||||
|
value_arg,value_cmd = (None,None)
|
||||||
|
try:
|
||||||
|
value_cmd = getattr(cmd_switches,name)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
value_arg = getattr(arg_switches,name)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# here is where we can pick and choose which to use
|
||||||
|
# default behavior is to choose the dream_command value over
|
||||||
|
# the arg value. For example, the --grid and --individual options are a little
|
||||||
|
# funny because of their push/pull relationship. This is how to handle it.
|
||||||
|
if name=='grid':
|
||||||
|
return value_arg or value_cmd # arg supersedes cmd
|
||||||
|
if name=='individual':
|
||||||
|
return value_cmd or value_arg # cmd supersedes arg
|
||||||
|
if value_cmd is not None:
|
||||||
|
return value_cmd
|
||||||
|
else:
|
||||||
|
return value_arg
|
||||||
|
|
||||||
|
def __setattr__(self,name,value):
|
||||||
|
if name.startswith('_'):
|
||||||
|
object.__setattr__(self,name,value)
|
||||||
|
else:
|
||||||
|
self._cmd_switches.__dict__[name] = value
|
||||||
|
|
||||||
|
def _create_arg_parser(self):
|
||||||
|
'''
|
||||||
|
This defines all the arguments used on the command line when you launch
|
||||||
|
the CLI or web backend.
|
||||||
|
'''
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=
|
||||||
|
"""
|
||||||
|
Generate images using Stable Diffusion.
|
||||||
|
Use --web to launch the web interface.
|
||||||
|
Use --from_file to load prompts from a file path or standard input ("-").
|
||||||
|
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
|
||||||
|
Other command-line arguments are defaults that can usually be overridden
|
||||||
|
prompt the command prompt.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
model_group = parser.add_argument_group('Model selection')
|
||||||
|
file_group = parser.add_argument_group('Input/output')
|
||||||
|
web_server_group = parser.add_argument_group('Web server')
|
||||||
|
render_group = parser.add_argument_group('Rendering')
|
||||||
|
postprocessing_group = parser.add_argument_group('Postprocessing')
|
||||||
|
deprecated_group = parser.add_argument_group('Deprecated options')
|
||||||
|
|
||||||
|
deprecated_group.add_argument('--laion400m')
|
||||||
|
deprecated_group.add_argument('--weights') # deprecated
|
||||||
|
model_group.add_argument(
|
||||||
|
'--conf',
|
||||||
|
'-c',
|
||||||
|
'-conf',
|
||||||
|
dest='conf',
|
||||||
|
default='./configs/models.yaml',
|
||||||
|
help='Path to configuration file for alternate models.',
|
||||||
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
'--model',
|
||||||
|
default='stable-diffusion-1.4',
|
||||||
|
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
|
||||||
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
'-F',
|
||||||
|
'--full_precision',
|
||||||
|
dest='full_precision',
|
||||||
|
action='store_true',
|
||||||
|
help='Use more memory-intensive full precision math for calculations',
|
||||||
|
)
|
||||||
|
file_group.add_argument(
|
||||||
|
'--from_file',
|
||||||
|
dest='infile',
|
||||||
|
type=str,
|
||||||
|
help='If specified, load prompts from this file',
|
||||||
|
)
|
||||||
|
file_group.add_argument(
|
||||||
|
'--outdir',
|
||||||
|
'-o',
|
||||||
|
type=str,
|
||||||
|
help='Directory to save generated images and a log of prompts and seeds. Default: outputs/img-samples',
|
||||||
|
default='outputs/img-samples',
|
||||||
|
)
|
||||||
|
file_group.add_argument(
|
||||||
|
'--prompt_as_dir',
|
||||||
|
'-p',
|
||||||
|
action='store_true',
|
||||||
|
help='Place images in subdirectories named after the prompt.',
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--seamless',
|
||||||
|
action='store_true',
|
||||||
|
help='Change the model to seamless tiling (circular) mode',
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--grid',
|
||||||
|
'-g',
|
||||||
|
action='store_true',
|
||||||
|
help='generate a grid'
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--embedding_path',
|
||||||
|
type=str,
|
||||||
|
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
|
||||||
|
)
|
||||||
|
# GFPGAN related args
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'--gfpgan_bg_upsampler',
|
||||||
|
type=str,
|
||||||
|
default='realesrgan',
|
||||||
|
help='Background upsampler. Default: realesrgan. Options: realesrgan, none.',
|
||||||
|
|
||||||
|
)
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'--gfpgan_bg_tile',
|
||||||
|
type=int,
|
||||||
|
default=400,
|
||||||
|
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
||||||
|
)
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'--gfpgan_model_path',
|
||||||
|
type=str,
|
||||||
|
default='experiments/pretrained_models/GFPGANv1.3.pth',
|
||||||
|
help='Indicates the path to the GFPGAN model, relative to --gfpgan_dir.',
|
||||||
|
)
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'--gfpgan_dir',
|
||||||
|
type=str,
|
||||||
|
default='./src/gfpgan',
|
||||||
|
help='Indicates the directory containing the GFPGAN code.',
|
||||||
|
)
|
||||||
|
web_server_group.add_argument(
|
||||||
|
'--web',
|
||||||
|
dest='web',
|
||||||
|
action='store_true',
|
||||||
|
help='Start in web server mode.',
|
||||||
|
)
|
||||||
|
web_server_group.add_argument(
|
||||||
|
'--host',
|
||||||
|
type=str,
|
||||||
|
default='127.0.0.1',
|
||||||
|
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
|
||||||
|
)
|
||||||
|
web_server_group.add_argument(
|
||||||
|
'--port',
|
||||||
|
type=int,
|
||||||
|
default='9090',
|
||||||
|
help='Web server: Port to listen on'
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
# This creates the parser that processes commands on the dream> command line
|
||||||
|
def _create_dream_cmd_parser(self):
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
|
||||||
|
)
|
||||||
|
render_group = parser.add_argument_group('General rendering')
|
||||||
|
img2img_group = parser.add_argument_group('Image-to-image and inpainting')
|
||||||
|
variation_group = parser.add_argument_group('Creating and combining variations')
|
||||||
|
postprocessing_group = parser.add_argument_group('Post-processing')
|
||||||
|
special_effects_group = parser.add_argument_group('Special effects')
|
||||||
|
render_group.add_argument('prompt')
|
||||||
|
render_group.add_argument(
|
||||||
|
'-s',
|
||||||
|
'--steps',
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help='Number of steps'
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'-S',
|
||||||
|
'--seed',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'-n',
|
||||||
|
'--iterations',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'-W',
|
||||||
|
'--width',
|
||||||
|
type=int,
|
||||||
|
help='Image width, multiple of 64',
|
||||||
|
default=512
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'-H',
|
||||||
|
'--height',
|
||||||
|
type=int,
|
||||||
|
help='Image height, multiple of 64',
|
||||||
|
default=512,
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'-C',
|
||||||
|
'--cfg_scale',
|
||||||
|
default=7.5,
|
||||||
|
type=float,
|
||||||
|
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--grid',
|
||||||
|
'-g',
|
||||||
|
action='store_true',
|
||||||
|
help='generate a grid'
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--individual',
|
||||||
|
'-i',
|
||||||
|
action='store_true',
|
||||||
|
help='override command-line --grid setting and generate individual images'
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'-x',
|
||||||
|
'--skip_normalize',
|
||||||
|
action='store_true',
|
||||||
|
help='Skip subprompt weight normalization',
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'-A',
|
||||||
|
'-m',
|
||||||
|
'--sampler',
|
||||||
|
dest='sampler_name',
|
||||||
|
type=str,
|
||||||
|
choices=SAMPLER_CHOICES,
|
||||||
|
metavar='SAMPLER_NAME',
|
||||||
|
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||||
|
default='k_lms',
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'-t',
|
||||||
|
'--log_tokenization',
|
||||||
|
action='store_true',
|
||||||
|
help='shows how the prompt is split into tokens'
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--outdir',
|
||||||
|
'-o',
|
||||||
|
type=str,
|
||||||
|
default='outputs/img-samples',
|
||||||
|
help='Directory to save generated images and a log of prompts and seeds',
|
||||||
|
)
|
||||||
|
img2img_group.add_argument(
|
||||||
|
'-I',
|
||||||
|
'--init_img',
|
||||||
|
type=str,
|
||||||
|
help='Path to input image for img2img mode (supersedes width and height)',
|
||||||
|
)
|
||||||
|
img2img_group.add_argument(
|
||||||
|
'-M',
|
||||||
|
'--init_mask',
|
||||||
|
type=str,
|
||||||
|
help='Path to input mask for inpainting mode (supersedes width and height)',
|
||||||
|
)
|
||||||
|
img2img_group.add_argument(
|
||||||
|
'-T',
|
||||||
|
'-fit',
|
||||||
|
'--fit',
|
||||||
|
action='store_true',
|
||||||
|
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
||||||
|
)
|
||||||
|
img2img_group.add_argument(
|
||||||
|
'-f',
|
||||||
|
'--strength',
|
||||||
|
type=float,
|
||||||
|
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||||
|
default=0.75,
|
||||||
|
)
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'-G',
|
||||||
|
'--gfpgan_strength',
|
||||||
|
type=float,
|
||||||
|
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'-U',
|
||||||
|
'--upscale',
|
||||||
|
nargs='+',
|
||||||
|
type=float,
|
||||||
|
help='Scale factor (2, 4) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'--save_original',
|
||||||
|
'-save_orig',
|
||||||
|
action='store_true',
|
||||||
|
help='Save original. Use it when upscaling to save both versions.',
|
||||||
|
)
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'--embiggen',
|
||||||
|
'-embiggen',
|
||||||
|
nargs='+',
|
||||||
|
type=float,
|
||||||
|
help='Embiggen tiled img2img for higher resolution and detail without extra VRAM usage. Takes scale factor relative to the size of the --init_img (-I), followed by ESRGAN upscaling strength (0-1.0), followed by minimum amount of overlap between tiles as a decimal ratio (0 - 1.0) or number of pixels. ESRGAN strength defaults to 0.75, and overlap defaults to 0.25 . ESRGAN is used to upscale the init prior to cutting it into tiles/pieces to run through img2img and then stitch back togeather.',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'--embiggen_tiles',
|
||||||
|
'-embiggen_tiles',
|
||||||
|
nargs='+',
|
||||||
|
type=int,
|
||||||
|
help='If while doing Embiggen we are altering only parts of the image, takes a list of tiles by number to process and replace onto the image e.g. `1 3 5`, useful for redoing problematic spots from a prior Embiggen run',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
special_effects_group.add_argument(
|
||||||
|
'--seamless',
|
||||||
|
action='store_true',
|
||||||
|
help='Change the model to seamless tiling (circular) mode',
|
||||||
|
)
|
||||||
|
variation_group.add_argument(
|
||||||
|
'-v',
|
||||||
|
'--variation_amount',
|
||||||
|
default=0.0,
|
||||||
|
type=float,
|
||||||
|
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
|
||||||
|
)
|
||||||
|
variation_group.add_argument(
|
||||||
|
'-V',
|
||||||
|
'--with_variations',
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
|
||||||
|
# it does not write all the required top-level metadata, writes too much image
|
||||||
|
# data, and doesn't support grids yet. But you gotta start somewhere, no?
|
||||||
|
def format_metadata(opt,
|
||||||
|
seeds=[],
|
||||||
|
weights=None,
|
||||||
|
model_hash=None,
|
||||||
|
postprocessing=None):
|
||||||
|
'''
|
||||||
|
Given an Args object, returns a partial implementation of
|
||||||
|
the stable diffusion metadata standard
|
||||||
|
'''
|
||||||
|
# add some RFC266 fields that are generated internally, and not as
|
||||||
|
# user args
|
||||||
|
image_dict = opt.to_dict(
|
||||||
|
postprocessing=postprocessing
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: This is just a hack until postprocessing pipeline work completed
|
||||||
|
image_dict['postprocessing'] = []
|
||||||
|
if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0:
|
||||||
|
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)')
|
||||||
|
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
|
||||||
|
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)')
|
||||||
|
|
||||||
|
# remove any image keys not mentioned in RFC #266
|
||||||
|
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
||||||
|
'cfg_scale','step_number','width','height','extra','strength']
|
||||||
|
|
||||||
|
rfc_dict ={}
|
||||||
|
for item in image_dict.items():
|
||||||
|
key,value = item
|
||||||
|
if key in rfc266_img_fields:
|
||||||
|
rfc_dict[key] = value
|
||||||
|
|
||||||
|
# semantic drift
|
||||||
|
rfc_dict['sampler'] = image_dict.get('sampler_name',None)
|
||||||
|
|
||||||
|
# display weighted subprompts (liable to change)
|
||||||
|
if opt.prompt:
|
||||||
|
subprompts = split_weighted_subprompts(opt.prompt)
|
||||||
|
subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts]
|
||||||
|
rfc_dict['prompt'] = subprompts
|
||||||
|
|
||||||
|
# variations
|
||||||
|
if opt.with_variations:
|
||||||
|
variations = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations]
|
||||||
|
rfc_dict['variations'] = variations
|
||||||
|
|
||||||
|
if opt.init_img:
|
||||||
|
rfc_dict['type'] = 'img2img'
|
||||||
|
rfc_dict['strength_steps'] = rfc_dict.pop('strength')
|
||||||
|
rfc_dict['orig_hash'] = sha256(image_dict['init_img'])
|
||||||
|
rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
|
||||||
|
else:
|
||||||
|
rfc_dict['type'] = 'txt2img'
|
||||||
|
|
||||||
|
images = []
|
||||||
|
for seed in seeds:
|
||||||
|
rfc_dict['seed'] = seed
|
||||||
|
images.append(copy.copy(rfc_dict))
|
||||||
|
|
||||||
|
return {
|
||||||
|
'model' : 'stable diffusion',
|
||||||
|
'model_id' : opt.model,
|
||||||
|
'model_hash' : model_hash,
|
||||||
|
'app_id' : APP_ID,
|
||||||
|
'app_version' : APP_VERSION,
|
||||||
|
'images' : images,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Bah. This should be moved somewhere else...
|
||||||
|
def sha256(path):
|
||||||
|
sha = hashlib.sha256()
|
||||||
|
with open(path,'rb') as f:
|
||||||
|
while True:
|
||||||
|
data = f.read(65536)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
sha.update(data)
|
||||||
|
return sha.hexdigest()
|
||||||
|
|
@ -3,12 +3,13 @@ Two helper classes for dealing with PNG images and their path names.
|
|||||||
PngWriter -- Converts Images generated by T2I into PNGs, finds
|
PngWriter -- Converts Images generated by T2I into PNGs, finds
|
||||||
appropriate names for them, and writes prompt metadata
|
appropriate names for them, and writes prompt metadata
|
||||||
into the PNG.
|
into the PNG.
|
||||||
PromptFormatter -- Utility for converting a Namespace of prompt parameters
|
|
||||||
back into a formatted prompt string with command-line switches.
|
Exports function retrieve_metadata(path)
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from PIL import PngImagePlugin
|
import json
|
||||||
|
from PIL import PngImagePlugin, Image
|
||||||
|
|
||||||
# -------------------image generation utils-----
|
# -------------------image generation utils-----
|
||||||
|
|
||||||
@ -32,54 +33,32 @@ class PngWriter:
|
|||||||
|
|
||||||
# saves image named _image_ to outdir/name, writing metadata from prompt
|
# saves image named _image_ to outdir/name, writing metadata from prompt
|
||||||
# returns full path of output
|
# returns full path of output
|
||||||
def save_image_and_prompt_to_png(self, image, prompt, name):
|
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None):
|
||||||
|
print(f'self.outdir={self.outdir}, name={name}')
|
||||||
path = os.path.join(self.outdir, name)
|
path = os.path.join(self.outdir, name)
|
||||||
info = PngImagePlugin.PngInfo()
|
info = PngImagePlugin.PngInfo()
|
||||||
info.add_text('Dream', prompt)
|
info.add_text('Dream', dream_prompt)
|
||||||
|
if metadata: # TODO: merge command line app's method of writing metadata and always just write metadata
|
||||||
|
info.add_text('sd-metadata', json.dumps(metadata))
|
||||||
image.save(path, 'PNG', pnginfo=info)
|
image.save(path, 'PNG', pnginfo=info)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
def retrieve_metadata(self,img_basename):
|
||||||
|
'''
|
||||||
|
Given a PNG filename stored in outdir, returns the "sd-metadata"
|
||||||
|
metadata stored there, as a dict
|
||||||
|
'''
|
||||||
|
path = os.path.join(self.outdir,img_basename)
|
||||||
|
all_metadata = retrieve_metadata(path)
|
||||||
|
return all_metadata['sd-metadata']
|
||||||
|
|
||||||
class PromptFormatter:
|
def retrieve_metadata(img_path):
|
||||||
def __init__(self, t2i, opt):
|
'''
|
||||||
self.t2i = t2i
|
Given a path to a PNG image, returns the "sd-metadata"
|
||||||
self.opt = opt
|
metadata stored there, as a dict
|
||||||
|
'''
|
||||||
|
im = Image.open(img_path)
|
||||||
|
md = im.text.get('sd-metadata', '{}')
|
||||||
|
dream_prompt = im.text.get('Dream', '')
|
||||||
|
return {'sd-metadata': json.loads(md), 'Dream': dream_prompt}
|
||||||
|
|
||||||
# note: the t2i object should provide all these values.
|
|
||||||
# there should be no need to or against opt values
|
|
||||||
def normalize_prompt(self):
|
|
||||||
"""Normalize the prompt and switches"""
|
|
||||||
t2i = self.t2i
|
|
||||||
opt = self.opt
|
|
||||||
|
|
||||||
switches = list()
|
|
||||||
switches.append(f'"{opt.prompt}"')
|
|
||||||
switches.append(f'-s{opt.steps or t2i.steps}')
|
|
||||||
switches.append(f'-W{opt.width or t2i.width}')
|
|
||||||
switches.append(f'-H{opt.height or t2i.height}')
|
|
||||||
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
|
||||||
switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
|
|
||||||
# to do: put model name into the t2i object
|
|
||||||
# switches.append(f'--model{t2i.model_name}')
|
|
||||||
if opt.seamless or t2i.seamless:
|
|
||||||
switches.append(f'--seamless')
|
|
||||||
if opt.init_img:
|
|
||||||
switches.append(f'-I{opt.init_img}')
|
|
||||||
if opt.fit:
|
|
||||||
switches.append(f'--fit')
|
|
||||||
if opt.strength and opt.init_img is not None:
|
|
||||||
switches.append(f'-f{opt.strength or t2i.strength}')
|
|
||||||
if opt.gfpgan_strength:
|
|
||||||
switches.append(f'-G{opt.gfpgan_strength}')
|
|
||||||
if opt.upscale:
|
|
||||||
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
|
|
||||||
if hasattr(opt, 'embiggen') and opt.embiggen:
|
|
||||||
switches.append(f'-embiggen {" ".join([str(u) for u in opt.embiggen])}')
|
|
||||||
if hasattr(opt, 'embiggen_tiles') and opt.embiggen_tiles:
|
|
||||||
switches.append(f'-embiggen_tiles {" ".join([str(u) for u in opt.embiggen_tiles])}')
|
|
||||||
if opt.variation_amount > 0:
|
|
||||||
switches.append(f'-v{opt.variation_amount}')
|
|
||||||
if opt.with_variations:
|
|
||||||
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in opt.with_variations)
|
|
||||||
switches.append(f'-V{formatted_variations}')
|
|
||||||
return ' '.join(switches)
|
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import copy
|
||||||
import base64
|
import base64
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
|
from ldm.dream.args import Args, format_metadata
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from ldm.dream.pngwriter import PngWriter, PromptFormatter
|
from ldm.dream.pngwriter import PngWriter
|
||||||
from threading import Event
|
from threading import Event
|
||||||
|
|
||||||
def build_opt(post_data, seed, gfpgan_model_exists):
|
def build_opt(post_data, seed, gfpgan_model_exists):
|
||||||
opt = argparse.Namespace()
|
opt = Args()
|
||||||
|
opt.parse_args() # initialize defaults
|
||||||
setattr(opt, 'prompt', post_data['prompt'])
|
setattr(opt, 'prompt', post_data['prompt'])
|
||||||
setattr(opt, 'init_img', post_data['initimg'])
|
setattr(opt, 'init_img', post_data['initimg'])
|
||||||
setattr(opt, 'strength', float(post_data['strength']))
|
setattr(opt, 'strength', float(post_data['strength']))
|
||||||
@ -42,7 +45,7 @@ def build_opt(post_data, seed, gfpgan_model_exists):
|
|||||||
for part in post_data['with_variations'].split(','):
|
for part in post_data['with_variations'].split(','):
|
||||||
seed_and_weight = part.split(':')
|
seed_and_weight = part.split(':')
|
||||||
if len(seed_and_weight) != 2:
|
if len(seed_and_weight) != 2:
|
||||||
print(f'could not parse with_variation part "{part}"')
|
print(f'could not parse WITH_variation part "{part}"')
|
||||||
broken = True
|
broken = True
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
@ -160,10 +163,10 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
# the images are first generated, and then again when after upscaling
|
# the images are first generated, and then again when after upscaling
|
||||||
# is complete. The upscaling replaces the original file, so the second
|
# is complete. The upscaling replaces the original file, so the second
|
||||||
# entry should not be inserted into the image list.
|
# entry should not be inserted into the image list.
|
||||||
|
# LS: This repeats code in dream.py
|
||||||
def image_done(image, seed, upscaled=False):
|
def image_done(image, seed, upscaled=False):
|
||||||
name = f'{prefix}.{seed}.png'
|
name = f'{prefix}.{seed}.png'
|
||||||
iter_opt = argparse.Namespace(**vars(opt)) # copy
|
iter_opt = copy.copy(opt)
|
||||||
print(f'iter_opt = {iter_opt}')
|
|
||||||
if opt.variation_amount > 0:
|
if opt.variation_amount > 0:
|
||||||
this_variation = [[seed, opt.variation_amount]]
|
this_variation = [[seed, opt.variation_amount]]
|
||||||
if opt.with_variations is None:
|
if opt.with_variations is None:
|
||||||
@ -171,10 +174,17 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
else:
|
else:
|
||||||
iter_opt.with_variations = opt.with_variations + this_variation
|
iter_opt.with_variations = opt.with_variations + this_variation
|
||||||
iter_opt.variation_amount = 0
|
iter_opt.variation_amount = 0
|
||||||
elif opt.with_variations is None:
|
formatted_prompt = opt.dream_prompt_str(seed=seed)
|
||||||
iter_opt.seed = seed
|
path = pngwriter.save_image_and_prompt_to_png(
|
||||||
normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt()
|
image,
|
||||||
path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name)
|
dream_prompt = formatted_prompt,
|
||||||
|
metadata = format_metadata(iter_opt,
|
||||||
|
seeds = [seed],
|
||||||
|
weights = self.model.weights,
|
||||||
|
model_hash = self.model.model_hash
|
||||||
|
),
|
||||||
|
name = name,
|
||||||
|
)
|
||||||
|
|
||||||
if int(config['seed']) == -1:
|
if int(config['seed']) == -1:
|
||||||
config['seed'] = seed
|
config['seed'] = seed
|
||||||
@ -220,9 +230,10 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
nonlocal step_index
|
nonlocal step_index
|
||||||
if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
|
if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
|
||||||
image = self.model.sample_to_image(sample)
|
image = self.model.sample_to_image(sample)
|
||||||
name = f'{prefix}.{opt.seed}.{step_index}.png'
|
step_index_padded = str(step_index).rjust(len(str(opt.steps)), '0')
|
||||||
|
name = f'{prefix}.{opt.seed}.{step_index_padded}.png'
|
||||||
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
|
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
|
||||||
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
|
path = step_writer.save_image_and_prompt_to_png(image, dream_prompt=metadata, name=name)
|
||||||
step_index += 1
|
step_index += 1
|
||||||
self.wfile.write(bytes(json.dumps(
|
self.wfile.write(bytes(json.dumps(
|
||||||
{'event': 'step', 'step': step + 1, 'url': path}
|
{'event': 'step', 'step': step + 1, 'url': path}
|
||||||
|
@ -13,6 +13,8 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import transformers
|
import transformers
|
||||||
|
import io
|
||||||
|
import hashlib
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
@ -179,7 +181,7 @@ class Generate:
|
|||||||
for image, seed in results:
|
for image, seed in results:
|
||||||
name = f'{prefix}.{seed}.png'
|
name = f'{prefix}.{seed}.png'
|
||||||
path = pngwriter.save_image_and_prompt_to_png(
|
path = pngwriter.save_image_and_prompt_to_png(
|
||||||
image, f'{prompt} -S{seed}', name)
|
image, dream_prompt=f'{prompt} -S{seed}', name=name)
|
||||||
outputs.append([path, seed])
|
outputs.append([path, seed])
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@ -577,7 +579,11 @@ class Generate:
|
|||||||
|
|
||||||
# this does the work
|
# this does the work
|
||||||
c = OmegaConf.load(config)
|
c = OmegaConf.load(config)
|
||||||
pl_sd = torch.load(weights, map_location='cpu')
|
with open(weights,'rb') as f:
|
||||||
|
weight_bytes = f.read()
|
||||||
|
self.model_hash = self._cached_sha256(weights,weight_bytes)
|
||||||
|
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||||
|
del weight_bytes
|
||||||
sd = pl_sd['state_dict']
|
sd = pl_sd['state_dict']
|
||||||
model = instantiate_from_config(c.model)
|
model = instantiate_from_config(c.model)
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
@ -738,3 +744,24 @@ class Generate:
|
|||||||
|
|
||||||
def _has_cuda(self):
|
def _has_cuda(self):
|
||||||
return self.device.type == 'cuda'
|
return self.device.type == 'cuda'
|
||||||
|
|
||||||
|
def _cached_sha256(self,path,data):
|
||||||
|
dirname = os.path.dirname(path)
|
||||||
|
basename = os.path.basename(path)
|
||||||
|
base, _ = os.path.splitext(basename)
|
||||||
|
hashpath = os.path.join(dirname,base+'.sha256')
|
||||||
|
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
||||||
|
with open(hashpath) as f:
|
||||||
|
hash = f.read()
|
||||||
|
return hash
|
||||||
|
print(f'>> Calculating sha256 hash of weights file')
|
||||||
|
tic = time.time()
|
||||||
|
sha = hashlib.sha256()
|
||||||
|
sha.update(data)
|
||||||
|
hash = sha.hexdigest()
|
||||||
|
toc = time.time()
|
||||||
|
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
||||||
|
with open(hashpath,'w') as f:
|
||||||
|
f.write(hash)
|
||||||
|
return hash
|
||||||
|
|
||||||
|
@ -5,10 +5,11 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from scripts.dream import create_argv_parser
|
#from scripts.dream import create_argv_parser
|
||||||
|
from ldm.dream.args import Args
|
||||||
|
|
||||||
arg_parser = create_argv_parser()
|
opt = Args()
|
||||||
opt = arg_parser.parse_args()
|
opt.parse_args()
|
||||||
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
|
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
|
||||||
gfpgan_model_exists = os.path.isfile(model_path)
|
gfpgan_model_exists = os.path.isfile(model_path)
|
||||||
|
|
||||||
|
@ -82,7 +82,9 @@ class EmbeddingManager(nn.Module):
|
|||||||
get_embedding_for_clip_token,
|
get_embedding_for_clip_token,
|
||||||
embedder.transformer.text_model.embeddings,
|
embedder.transformer.text_model.embeddings,
|
||||||
)
|
)
|
||||||
token_dim = 1280
|
# per bug report #572
|
||||||
|
#token_dim = 1280
|
||||||
|
token_dim = 768
|
||||||
else: # using LDM's BERT encoder
|
else: # using LDM's BERT encoder
|
||||||
self.is_clip = False
|
self.is_clip = False
|
||||||
get_token_for_string = partial(
|
get_token_for_string = partial(
|
||||||
|
@ -22,6 +22,11 @@ test-tube
|
|||||||
torch-fidelity
|
torch-fidelity
|
||||||
torchmetrics
|
torchmetrics
|
||||||
transformers
|
transformers
|
||||||
|
flask==2.1.3
|
||||||
|
flask_socketio==5.3.0
|
||||||
|
flask_cors==3.0.10
|
||||||
|
dependency_injector==4.40.0
|
||||||
|
eventlet
|
||||||
git+https://github.com/openai/CLIP.git@main#egg=clip
|
git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||||
git+https://github.com/lstein/GFPGAN@fix-dark-cast-images#egg=gfpgan
|
git+https://github.com/lstein/GFPGAN@fix-dark-cast-images#egg=gfpgan
|
||||||
|
438
scripts/dream.py
Executable file → Normal file
438
scripts/dream.py
Executable file → Normal file
@ -1,8 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||||
|
|
||||||
import argparse
|
|
||||||
import shlex
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
@ -10,7 +8,8 @@ import copy
|
|||||||
import warnings
|
import warnings
|
||||||
import time
|
import time
|
||||||
import ldm.dream.readline
|
import ldm.dream.readline
|
||||||
from ldm.dream.pngwriter import PngWriter, PromptFormatter
|
from ldm.dream.args import Args, format_metadata
|
||||||
|
from ldm.dream.pngwriter import PngWriter
|
||||||
from ldm.dream.server import DreamServer, ThreadingDreamServer
|
from ldm.dream.server import DreamServer, ThreadingDreamServer
|
||||||
from ldm.dream.image_util import make_grid
|
from ldm.dream.image_util import make_grid
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -22,14 +21,16 @@ output_cntr = 0
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Initialize command-line parsers and the diffusion model"""
|
"""Initialize command-line parsers and the diffusion model"""
|
||||||
arg_parser = create_argv_parser()
|
opt = Args()
|
||||||
opt = arg_parser.parse_args()
|
args = opt.parse_args()
|
||||||
|
if not args:
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
if opt.laion400m:
|
if args.laion400m:
|
||||||
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
|
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
if opt.weights != 'model':
|
if args.weights:
|
||||||
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
|
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
print('* Initializing, be patient...\n')
|
print('* Initializing, be patient...\n')
|
||||||
@ -47,7 +48,7 @@ def main():
|
|||||||
# the user input loop
|
# the user input loop
|
||||||
try:
|
try:
|
||||||
gen = Generate(
|
gen = Generate(
|
||||||
conf = opt.config,
|
conf = opt.conf,
|
||||||
model = opt.model,
|
model = opt.model,
|
||||||
sampler_name = opt.sampler_name,
|
sampler_name = opt.sampler_name,
|
||||||
embedding_path = opt.embedding_path,
|
embedding_path = opt.embedding_path,
|
||||||
@ -91,11 +92,10 @@ def main():
|
|||||||
dream_server_loop(gen, opt.host, opt.port, opt.outdir)
|
dream_server_loop(gen, opt.host, opt.port, opt.outdir)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
cmd_parser = create_cmd_parser()
|
main_loop(gen, opt, infile)
|
||||||
main_loop(gen, opt.outdir, opt.prompt_as_dir, cmd_parser, infile)
|
|
||||||
|
|
||||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||||
def main_loop(gen, outdir, prompt_as_dir, parser, infile):
|
def main_loop(gen, opt, infile):
|
||||||
"""prompt/read/execute loop"""
|
"""prompt/read/execute loop"""
|
||||||
done = False
|
done = False
|
||||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||||
@ -103,8 +103,8 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
|
|||||||
|
|
||||||
# os.pathconf is not available on Windows
|
# os.pathconf is not available on Windows
|
||||||
if hasattr(os, 'pathconf'):
|
if hasattr(os, 'pathconf'):
|
||||||
path_max = os.pathconf(outdir, 'PC_PATH_MAX')
|
path_max = os.pathconf(opt.outdir, 'PC_PATH_MAX')
|
||||||
name_max = os.pathconf(outdir, 'PC_NAME_MAX')
|
name_max = os.pathconf(opt.outdir, 'PC_NAME_MAX')
|
||||||
else:
|
else:
|
||||||
path_max = 260
|
path_max = 260
|
||||||
name_max = 255
|
name_max = 255
|
||||||
@ -123,41 +123,17 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
|
|||||||
if command.startswith(('#', '//')):
|
if command.startswith(('#', '//')):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# before splitting, escape single quotes so as not to mess
|
if command.startswith('q '):
|
||||||
# up the parser
|
|
||||||
command = command.replace("'", "\\'")
|
|
||||||
|
|
||||||
try:
|
|
||||||
elements = shlex.split(command)
|
|
||||||
except ValueError as e:
|
|
||||||
print(str(e))
|
|
||||||
continue
|
|
||||||
|
|
||||||
if elements[0] == 'q':
|
|
||||||
done = True
|
done = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if elements[0].startswith(
|
if command.startswith(
|
||||||
'!dream'
|
'!dream'
|
||||||
): # in case a stored prompt still contains the !dream command
|
): # in case a stored prompt still contains the !dream command
|
||||||
elements.pop(0)
|
command.replace('!dream','',1)
|
||||||
|
|
||||||
# rearrange the arguments to mimic how it works in the Dream bot.
|
|
||||||
switches = ['']
|
|
||||||
switches_started = False
|
|
||||||
|
|
||||||
for el in elements:
|
|
||||||
if el[0] == '-' and not switches_started:
|
|
||||||
switches_started = True
|
|
||||||
if switches_started:
|
|
||||||
switches.append(el)
|
|
||||||
else:
|
|
||||||
switches[0] += el
|
|
||||||
switches[0] += ' '
|
|
||||||
switches[0] = switches[0][: len(switches[0]) - 1]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
opt = parser.parse_args(switches)
|
parser = opt.parse_cmd(command)
|
||||||
except SystemExit:
|
except SystemExit:
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
continue
|
continue
|
||||||
@ -185,6 +161,7 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
|
|||||||
opt.seed = None
|
opt.seed = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# TODO - move this into a module
|
||||||
if opt.with_variations is not None:
|
if opt.with_variations is not None:
|
||||||
# shotgun parsing, woo
|
# shotgun parsing, woo
|
||||||
parts = []
|
parts = []
|
||||||
@ -220,7 +197,7 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
|
|||||||
|
|
||||||
# truncate path to maximum allowed length
|
# truncate path to maximum allowed length
|
||||||
# 27 is the length of '######.##########.##.png', plus two separators and a NUL
|
# 27 is the length of '######.##########.##.png', plus two separators and a NUL
|
||||||
subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))]
|
subdir = subdir[:(path_max - 27 - len(os.path.abspath(opt.outdir)))]
|
||||||
current_outdir = os.path.join(outdir, subdir)
|
current_outdir = os.path.join(outdir, subdir)
|
||||||
|
|
||||||
print('Writing files to directory: "' + current_outdir + '"')
|
print('Writing files to directory: "' + current_outdir + '"')
|
||||||
@ -248,31 +225,36 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
|
|||||||
filename = f'{prefix}.{seed}.postprocessed.png'
|
filename = f'{prefix}.{seed}.postprocessed.png'
|
||||||
else:
|
else:
|
||||||
filename = f'{prefix}.{seed}.png'
|
filename = f'{prefix}.{seed}.png'
|
||||||
|
# the handling of variations is probably broken
|
||||||
|
# Also, given the ability to add stuff to the dream_prompt_str, it isn't
|
||||||
|
# necessary to make a copy of the opt option just to change its attributes
|
||||||
if opt.variation_amount > 0:
|
if opt.variation_amount > 0:
|
||||||
iter_opt = argparse.Namespace(**vars(opt)) # copy
|
iter_opt = copy.copy(opt)
|
||||||
this_variation = [[seed, opt.variation_amount]]
|
this_variation = [[seed, opt.variation_amount]]
|
||||||
if opt.with_variations is None:
|
if opt.with_variations is None:
|
||||||
iter_opt.with_variations = this_variation
|
iter_opt.with_variations = this_variation
|
||||||
else:
|
else:
|
||||||
iter_opt.with_variations = opt.with_variations + this_variation
|
iter_opt.with_variations = opt.with_variations + this_variation
|
||||||
iter_opt.variation_amount = 0
|
iter_opt.variation_amount = 0
|
||||||
normalized_prompt = PromptFormatter(
|
formatted_dream_prompt = iter_opt.dream_prompt_str(seed=seed)
|
||||||
gen, iter_opt).normalize_prompt()
|
|
||||||
metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
|
|
||||||
elif opt.with_variations is not None:
|
elif opt.with_variations is not None:
|
||||||
normalized_prompt = PromptFormatter(
|
formatted_dream_prompt = opt.dream_prompt_str(seed=seed)
|
||||||
gen, opt).normalize_prompt()
|
|
||||||
# use the original seed - the per-iteration value is the last variation-seed
|
|
||||||
metadata_prompt = f'{normalized_prompt} -S{opt.seed}'
|
|
||||||
else:
|
else:
|
||||||
normalized_prompt = PromptFormatter(
|
formatted_dream_prompt = opt.dream_prompt_str(seed=seed)
|
||||||
gen, opt).normalize_prompt()
|
|
||||||
metadata_prompt = f'{normalized_prompt} -S{seed}'
|
|
||||||
path = file_writer.save_image_and_prompt_to_png(
|
path = file_writer.save_image_and_prompt_to_png(
|
||||||
image, metadata_prompt, filename)
|
image = image,
|
||||||
|
dream_prompt = formatted_dream_prompt,
|
||||||
|
metadata = format_metadata(
|
||||||
|
opt,
|
||||||
|
seeds = [seed],
|
||||||
|
weights = gen.weights,
|
||||||
|
model_hash = gen.model_hash,
|
||||||
|
),
|
||||||
|
name = filename,
|
||||||
|
)
|
||||||
if (not upscaled) or opt.save_original:
|
if (not upscaled) or opt.save_original:
|
||||||
# only append to results if we didn't overwrite an earlier output
|
# only append to results if we didn't overwrite an earlier output
|
||||||
results.append([path, metadata_prompt])
|
results.append([path, formatted_dream_prompt])
|
||||||
last_results.append([path, seed])
|
last_results.append([path, seed])
|
||||||
|
|
||||||
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
|
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
|
||||||
@ -286,15 +268,22 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
|
|||||||
grid_img = make_grid(list(grid_images.values()))
|
grid_img = make_grid(list(grid_images.values()))
|
||||||
grid_seeds = list(grid_images.keys())
|
grid_seeds = list(grid_images.keys())
|
||||||
first_seed = last_results[0][1]
|
first_seed = last_results[0][1]
|
||||||
filename = f'{prefix}.{first_seed}.png'
|
filename = f'{prefix}.{first_seed}.png'
|
||||||
# TODO better metadata for grid images
|
formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images))
|
||||||
normalized_prompt = PromptFormatter(
|
formatted_dream_prompt += f' # {grid_seeds}'
|
||||||
gen, opt).normalize_prompt()
|
metadata = format_metadata(
|
||||||
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -n{len(grid_images)} # {grid_seeds}'
|
opt,
|
||||||
|
seeds = grid_seeds,
|
||||||
|
weights = gen.weights,
|
||||||
|
model_hash = gen.model_hash
|
||||||
|
)
|
||||||
path = file_writer.save_image_and_prompt_to_png(
|
path = file_writer.save_image_and_prompt_to_png(
|
||||||
grid_img, metadata_prompt, filename
|
image = grid_img,
|
||||||
|
dream_prompt = formatted_dream_prompt,
|
||||||
|
metadata = metadata,
|
||||||
|
name = filename
|
||||||
)
|
)
|
||||||
results = [[path, metadata_prompt]]
|
results = [[path, formatted_dream_prompt]]
|
||||||
|
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(e)
|
print(e)
|
||||||
@ -325,7 +314,6 @@ def get_next_command(infile=None) -> str: # command string
|
|||||||
print(f'#{command}')
|
print(f'#{command}')
|
||||||
return command
|
return command
|
||||||
|
|
||||||
|
|
||||||
def dream_server_loop(gen, host, port, outdir):
|
def dream_server_loop(gen, host, port, outdir):
|
||||||
print('\n* --web was specified, starting web server...')
|
print('\n* --web was specified, starting web server...')
|
||||||
# Change working directory to the stable-diffusion directory
|
# Change working directory to the stable-diffusion directory
|
||||||
@ -365,327 +353,5 @@ def write_log_message(results, log_path):
|
|||||||
with open(log_path, 'a', encoding='utf-8') as file:
|
with open(log_path, 'a', encoding='utf-8') as file:
|
||||||
file.writelines(log_lines)
|
file.writelines(log_lines)
|
||||||
|
|
||||||
|
|
||||||
SAMPLER_CHOICES = [
|
|
||||||
'ddim',
|
|
||||||
'k_dpm_2_a',
|
|
||||||
'k_dpm_2',
|
|
||||||
'k_euler_a',
|
|
||||||
'k_euler',
|
|
||||||
'k_heun',
|
|
||||||
'k_lms',
|
|
||||||
'plms',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def create_argv_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="""Generate images using Stable Diffusion.
|
|
||||||
Use --web to launch the web interface.
|
|
||||||
Use --from_file to load prompts from a file path or standard input ("-").
|
|
||||||
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
|
|
||||||
Other command-line arguments are defaults that can usually be overridden
|
|
||||||
prompt the command prompt.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--laion400m',
|
|
||||||
'--latent_diffusion',
|
|
||||||
'-l',
|
|
||||||
dest='laion400m',
|
|
||||||
action='store_true',
|
|
||||||
help='Fallback to the latent diffusion (laion400m) weights and config',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--from_file',
|
|
||||||
dest='infile',
|
|
||||||
type=str,
|
|
||||||
help='If specified, load prompts from this file',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-n',
|
|
||||||
'--iterations',
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help='Number of images to generate',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-F',
|
|
||||||
'--full_precision',
|
|
||||||
dest='full_precision',
|
|
||||||
action='store_true',
|
|
||||||
help='Use more memory-intensive full precision math for calculations',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-g',
|
|
||||||
'--grid',
|
|
||||||
action='store_true',
|
|
||||||
help='Generate a grid instead of individual images',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-A',
|
|
||||||
'-m',
|
|
||||||
'--sampler',
|
|
||||||
dest='sampler_name',
|
|
||||||
choices=SAMPLER_CHOICES,
|
|
||||||
metavar='SAMPLER_NAME',
|
|
||||||
default='k_lms',
|
|
||||||
help=f'Set the initial sampler. Default: k_lms. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--outdir',
|
|
||||||
'-o',
|
|
||||||
type=str,
|
|
||||||
default='outputs/img-samples',
|
|
||||||
help='Directory to save generated images and a log of prompts and seeds. Default: outputs/img-samples',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--seamless',
|
|
||||||
action='store_true',
|
|
||||||
help='Change the model to seamless tiling (circular) mode',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--embedding_path',
|
|
||||||
type=str,
|
|
||||||
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--prompt_as_dir',
|
|
||||||
'-p',
|
|
||||||
action='store_true',
|
|
||||||
help='Place images in subdirectories named after the prompt.',
|
|
||||||
)
|
|
||||||
# GFPGAN related args
|
|
||||||
parser.add_argument(
|
|
||||||
'--gfpgan_bg_upsampler',
|
|
||||||
type=str,
|
|
||||||
default='realesrgan',
|
|
||||||
help='Background upsampler. Default: realesrgan. Options: realesrgan, none.',
|
|
||||||
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--gfpgan_bg_tile',
|
|
||||||
type=int,
|
|
||||||
default=400,
|
|
||||||
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--gfpgan_model_path',
|
|
||||||
type=str,
|
|
||||||
default='experiments/pretrained_models/GFPGANv1.3.pth',
|
|
||||||
help='Indicates the path to the GFPGAN model, relative to --gfpgan_dir.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--gfpgan_dir',
|
|
||||||
type=str,
|
|
||||||
default='./src/gfpgan',
|
|
||||||
help='Indicates the directory containing the GFPGAN code.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--web',
|
|
||||||
dest='web',
|
|
||||||
action='store_true',
|
|
||||||
help='Start in web server mode.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--host',
|
|
||||||
type=str,
|
|
||||||
default='127.0.0.1',
|
|
||||||
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--port',
|
|
||||||
type=int,
|
|
||||||
default='9090',
|
|
||||||
help='Web server: Port to listen on'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--weights',
|
|
||||||
default='model',
|
|
||||||
help='Indicates the Stable Diffusion model to use.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--model',
|
|
||||||
default='stable-diffusion-1.4',
|
|
||||||
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--config',
|
|
||||||
default='configs/models.yaml',
|
|
||||||
help='Path to configuration file for alternate models.',
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def create_cmd_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
|
|
||||||
)
|
|
||||||
parser.add_argument('prompt')
|
|
||||||
parser.add_argument('-s', '--steps', type=int, help='Number of steps')
|
|
||||||
parser.add_argument(
|
|
||||||
'-S',
|
|
||||||
'--seed',
|
|
||||||
type=int,
|
|
||||||
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-n',
|
|
||||||
'--iterations',
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-W', '--width', type=int, help='Image width, multiple of 64'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-H', '--height', type=int, help='Image height, multiple of 64'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-C',
|
|
||||||
'--cfg_scale',
|
|
||||||
default=7.5,
|
|
||||||
type=float,
|
|
||||||
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-g', '--grid', action='store_true', help='generate a grid'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--outdir',
|
|
||||||
'-o',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='Directory to save generated images and a log of prompts and seeds',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--seamless',
|
|
||||||
action='store_true',
|
|
||||||
help='Change the model to seamless tiling (circular) mode',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-i',
|
|
||||||
'--individual',
|
|
||||||
action='store_true',
|
|
||||||
help='Generate individual files (default)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-I',
|
|
||||||
'--init_img',
|
|
||||||
type=str,
|
|
||||||
help='Path to input image for img2img mode (supersedes width and height)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-M',
|
|
||||||
'--init_mask',
|
|
||||||
type=str,
|
|
||||||
help='Path to input mask for inpainting mode (supersedes width and height)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-T',
|
|
||||||
'-fit',
|
|
||||||
'--fit',
|
|
||||||
action='store_true',
|
|
||||||
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-f',
|
|
||||||
'--strength',
|
|
||||||
default=0.75,
|
|
||||||
type=float,
|
|
||||||
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-G',
|
|
||||||
'--gfpgan_strength',
|
|
||||||
default=0,
|
|
||||||
type=float,
|
|
||||||
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-U',
|
|
||||||
'--upscale',
|
|
||||||
nargs='+',
|
|
||||||
default=None,
|
|
||||||
type=float,
|
|
||||||
help='Scale factor (2, 4) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-save_orig',
|
|
||||||
'--save_original',
|
|
||||||
action='store_true',
|
|
||||||
help='Save original. Use it when upscaling to save both versions.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-embiggen',
|
|
||||||
'--embiggen',
|
|
||||||
nargs='+',
|
|
||||||
default=None,
|
|
||||||
type=float,
|
|
||||||
help='Embiggen tiled img2img for higher resolution and detail without extra VRAM usage. Takes scale factor relative to the size of the --init_img (-I), followed by ESRGAN upscaling strength (0-1.0), followed by minimum amount of overlap between tiles as a decimal ratio (0 - 1.0) or number of pixels. ESRGAN strength defaults to 0.75, and overlap defaults to 0.25 . ESRGAN is used to upscale the init prior to cutting it into tiles/pieces to run through img2img and then stitch back togeather.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-embiggen_tiles',
|
|
||||||
'--embiggen_tiles',
|
|
||||||
nargs='+',
|
|
||||||
default=None,
|
|
||||||
type=int,
|
|
||||||
help='If while doing Embiggen we are altering only parts of the image, takes a list of tiles by number to process and replace onto the image e.g. `1 3 5`, useful for redoing problematic spots from a prior Embiggen run',
|
|
||||||
)
|
|
||||||
# variants is going to be superseded by a generalized "prompt-morph" function
|
|
||||||
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
|
|
||||||
parser.add_argument(
|
|
||||||
'-x',
|
|
||||||
'--skip_normalize',
|
|
||||||
action='store_true',
|
|
||||||
help='Skip subprompt weight normalization',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-A',
|
|
||||||
'-m',
|
|
||||||
'--sampler',
|
|
||||||
dest='sampler_name',
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
choices=SAMPLER_CHOICES,
|
|
||||||
metavar='SAMPLER_NAME',
|
|
||||||
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-t',
|
|
||||||
'--log_tokenization',
|
|
||||||
action='store_true',
|
|
||||||
help='shows how the prompt is split into tokens'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--threshold',
|
|
||||||
default=0.0,
|
|
||||||
type=float,
|
|
||||||
help='Add threshold value aka perform clipping.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--perlin',
|
|
||||||
default=0.0,
|
|
||||||
type=float,
|
|
||||||
help='Add perlin noise.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-v',
|
|
||||||
'--variation_amount',
|
|
||||||
default=0.0,
|
|
||||||
type=float,
|
|
||||||
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-V',
|
|
||||||
'--with_variations',
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
22
scripts/sd-metadata.py
Normal file
22
scripts/sd-metadata.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
from ldm.dream.pngwriter import retrieve_metadata
|
||||||
|
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
|
||||||
|
print("This script opens up the indicated dream.py-generated PNG file(s) and prints out their metadata.")
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
filenames = sys.argv[1:]
|
||||||
|
for f in filenames:
|
||||||
|
try:
|
||||||
|
metadata = retrieve_metadata(f)
|
||||||
|
print(f'{f}:\n',json.dumps(metadata['sd-metadata'], indent=4))
|
||||||
|
except FileNotFoundError:
|
||||||
|
sys.stderr.write(f'{f} not found\n')
|
||||||
|
continue
|
||||||
|
except PermissionError:
|
||||||
|
sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
|
||||||
|
continue
|
0
server/__init__.py
Normal file
0
server/__init__.py
Normal file
152
server/application.py
Normal file
152
server/application.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
"""Application module."""
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from flask import Flask
|
||||||
|
from flask_cors import CORS
|
||||||
|
from flask_socketio import SocketIO
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from dependency_injector.wiring import inject, Provide
|
||||||
|
from ldm.dream.args import Args
|
||||||
|
from server import views
|
||||||
|
from server.containers import Container
|
||||||
|
from server.services import GeneratorService, SignalService
|
||||||
|
|
||||||
|
# The socketio_service is injected here (rather than created in run_app) to initialize it
|
||||||
|
@inject
|
||||||
|
def initialize_app(
|
||||||
|
app: Flask,
|
||||||
|
socketio: SocketIO = Provide[Container.socketio]
|
||||||
|
) -> SocketIO:
|
||||||
|
socketio.init_app(app)
|
||||||
|
|
||||||
|
return socketio
|
||||||
|
|
||||||
|
# The signal and generator services are injected to warm up the processing queues
|
||||||
|
# TODO: Initialize these a better way?
|
||||||
|
@inject
|
||||||
|
def initialize_generator(
|
||||||
|
signal_service: SignalService = Provide[Container.signal_service],
|
||||||
|
generator_service: GeneratorService = Provide[Container.generator_service]
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def run_app(config, host, port) -> Flask:
|
||||||
|
app = Flask(__name__, static_url_path='')
|
||||||
|
|
||||||
|
# Set up dependency injection container
|
||||||
|
container = Container()
|
||||||
|
container.config.from_dict(config)
|
||||||
|
container.wire(modules=[__name__])
|
||||||
|
app.container = container
|
||||||
|
|
||||||
|
# Set up CORS
|
||||||
|
CORS(app, resources={r'/api/*': {'origins': '*'}})
|
||||||
|
|
||||||
|
# Web Routes
|
||||||
|
app.add_url_rule('/', view_func=views.WebIndex.as_view('web_index', 'index.html'))
|
||||||
|
app.add_url_rule('/index.css', view_func=views.WebIndex.as_view('web_index_css', 'index.css'))
|
||||||
|
app.add_url_rule('/index.js', view_func=views.WebIndex.as_view('web_index_js', 'index.js'))
|
||||||
|
app.add_url_rule('/config.js', view_func=views.WebConfig.as_view('web_config'))
|
||||||
|
|
||||||
|
# API Routes
|
||||||
|
app.add_url_rule('/api/jobs', view_func=views.ApiJobs.as_view('api_jobs'))
|
||||||
|
app.add_url_rule('/api/cancel', view_func=views.ApiCancel.as_view('api_cancel'))
|
||||||
|
|
||||||
|
# TODO: Get storage root from config
|
||||||
|
app.add_url_rule('/api/images/<string:dreamId>', view_func=views.ApiImages.as_view('api_images', '../'))
|
||||||
|
app.add_url_rule('/api/images/<string:dreamId>/metadata', view_func=views.ApiImagesMetadata.as_view('api_images_metadata', '../'))
|
||||||
|
app.add_url_rule('/api/images', view_func=views.ApiImagesList.as_view('api_images_list'))
|
||||||
|
app.add_url_rule('/api/intermediates/<string:dreamId>/<string:step>', view_func=views.ApiIntermediates.as_view('api_intermediates', '../'))
|
||||||
|
|
||||||
|
app.static_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../static/dream_web/'))
|
||||||
|
|
||||||
|
# Initialize
|
||||||
|
socketio = initialize_app(app)
|
||||||
|
initialize_generator()
|
||||||
|
|
||||||
|
print(">> Started Stable Diffusion api server!")
|
||||||
|
if host == '0.0.0.0':
|
||||||
|
print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
|
||||||
|
else:
|
||||||
|
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
|
||||||
|
print(f">> Point your browser at http://{host}:{port}.")
|
||||||
|
|
||||||
|
# Run the app
|
||||||
|
socketio.run(app, host, port)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Initialize command-line parsers and the diffusion model"""
|
||||||
|
arg_parser = Args()
|
||||||
|
opt = arg_parser.parse_args()
|
||||||
|
|
||||||
|
if opt.laion400m:
|
||||||
|
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
|
||||||
|
sys.exit(-1)
|
||||||
|
if opt.weights:
|
||||||
|
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# models = OmegaConf.load(opt.config)
|
||||||
|
# width = models[opt.model].width
|
||||||
|
# height = models[opt.model].height
|
||||||
|
# config = models[opt.model].config
|
||||||
|
# weights = models[opt.model].weights
|
||||||
|
# except (FileNotFoundError, IOError, KeyError) as e:
|
||||||
|
# print(f'{e}. Aborting.')
|
||||||
|
# sys.exit(-1)
|
||||||
|
|
||||||
|
#print('* Initializing, be patient...\n')
|
||||||
|
sys.path.append('.')
|
||||||
|
|
||||||
|
# these two lines prevent a horrible warning message from appearing
|
||||||
|
# when the frozen CLIP tokenizer is imported
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
appConfig = opt.__dict__
|
||||||
|
|
||||||
|
# appConfig = {
|
||||||
|
# "model": {
|
||||||
|
# "width": width,
|
||||||
|
# "height": height,
|
||||||
|
# "sampler_name": opt.sampler_name,
|
||||||
|
# "weights": weights,
|
||||||
|
# "full_precision": opt.full_precision,
|
||||||
|
# "config": config,
|
||||||
|
# "grid": opt.grid,
|
||||||
|
# "latent_diffusion_weights": opt.laion400m,
|
||||||
|
# "embedding_path": opt.embedding_path
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
# make sure the output directory exists
|
||||||
|
if not os.path.exists(opt.outdir):
|
||||||
|
os.makedirs(opt.outdir)
|
||||||
|
|
||||||
|
# gets rid of annoying messages about random seed
|
||||||
|
from pytorch_lightning import logging
|
||||||
|
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
print('\n* starting api server...')
|
||||||
|
# Change working directory to the stable-diffusion directory
|
||||||
|
os.chdir(
|
||||||
|
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start server
|
||||||
|
try:
|
||||||
|
run_app(appConfig, opt.host, opt.port)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
81
server/containers.py
Normal file
81
server/containers.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
"""Containers module."""
|
||||||
|
|
||||||
|
from dependency_injector import containers, providers
|
||||||
|
from flask_socketio import SocketIO
|
||||||
|
from ldm.generate import Generate
|
||||||
|
from server import services
|
||||||
|
|
||||||
|
class Container(containers.DeclarativeContainer):
|
||||||
|
wiring_config = containers.WiringConfiguration(packages=['server'])
|
||||||
|
|
||||||
|
config = providers.Configuration()
|
||||||
|
|
||||||
|
socketio = providers.ThreadSafeSingleton(
|
||||||
|
SocketIO,
|
||||||
|
app = None
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Add a model provider service that provides model(s) dynamically
|
||||||
|
model_singleton = providers.ThreadSafeSingleton(
|
||||||
|
Generate,
|
||||||
|
model = config.model,
|
||||||
|
sampler_name = config.sampler_name,
|
||||||
|
embedding_path = config.embedding_path,
|
||||||
|
full_precision = config.full_precision
|
||||||
|
# config = config.model.config,
|
||||||
|
|
||||||
|
# width = config.model.width,
|
||||||
|
# height = config.model.height,
|
||||||
|
# sampler_name = config.model.sampler_name,
|
||||||
|
# weights = config.model.weights,
|
||||||
|
# full_precision = config.model.full_precision,
|
||||||
|
# grid = config.model.grid,
|
||||||
|
# seamless = config.model.seamless,
|
||||||
|
# embedding_path = config.model.embedding_path,
|
||||||
|
# device_type = config.model.device_type
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: get location from config
|
||||||
|
image_storage_service = providers.ThreadSafeSingleton(
|
||||||
|
services.ImageStorageService,
|
||||||
|
'./outputs/img-samples/'
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: get location from config
|
||||||
|
image_intermediates_storage_service = providers.ThreadSafeSingleton(
|
||||||
|
services.ImageStorageService,
|
||||||
|
'./outputs/intermediates/'
|
||||||
|
)
|
||||||
|
|
||||||
|
signal_queue_service = providers.ThreadSafeSingleton(
|
||||||
|
services.SignalQueueService
|
||||||
|
)
|
||||||
|
|
||||||
|
signal_service = providers.ThreadSafeSingleton(
|
||||||
|
services.SignalService,
|
||||||
|
socketio = socketio,
|
||||||
|
queue = signal_queue_service
|
||||||
|
)
|
||||||
|
|
||||||
|
generation_queue_service = providers.ThreadSafeSingleton(
|
||||||
|
services.JobQueueService
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: get locations from config
|
||||||
|
log_service = providers.ThreadSafeSingleton(
|
||||||
|
services.LogService,
|
||||||
|
'./outputs/img-samples/',
|
||||||
|
'dream_web_log.txt'
|
||||||
|
)
|
||||||
|
|
||||||
|
generator_service = providers.ThreadSafeSingleton(
|
||||||
|
services.GeneratorService,
|
||||||
|
model = model_singleton,
|
||||||
|
queue = generation_queue_service,
|
||||||
|
imageStorage = image_storage_service,
|
||||||
|
intermediateStorage = image_intermediates_storage_service,
|
||||||
|
log = log_service,
|
||||||
|
signal_service = signal_service
|
||||||
|
)
|
251
server/models.py
Normal file
251
server/models.py
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from base64 import urlsafe_b64encode
|
||||||
|
import json
|
||||||
|
import string
|
||||||
|
from copy import deepcopy
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
|
||||||
|
class DreamBase():
|
||||||
|
# Id
|
||||||
|
id: str
|
||||||
|
|
||||||
|
# Initial Image
|
||||||
|
enable_init_image: bool
|
||||||
|
initimg: string = None
|
||||||
|
|
||||||
|
# Img2Img
|
||||||
|
enable_img2img: bool # TODO: support this better
|
||||||
|
strength: float = 0 # TODO: name this something related to img2img to make it clearer?
|
||||||
|
fit = None # Fit initial image dimensions
|
||||||
|
|
||||||
|
# Generation
|
||||||
|
enable_generate: bool
|
||||||
|
prompt: string = ""
|
||||||
|
seed: int = 0 # 0 is random
|
||||||
|
steps: int = 10
|
||||||
|
width: int = 512
|
||||||
|
height: int = 512
|
||||||
|
cfg_scale: float = 7.5
|
||||||
|
sampler_name: string = 'klms'
|
||||||
|
seamless: bool = False
|
||||||
|
model: str = None # The model to use (currently unused)
|
||||||
|
embeddings = None # The embeddings to use (currently unused)
|
||||||
|
progress_images: bool = False
|
||||||
|
|
||||||
|
# GFPGAN
|
||||||
|
enable_gfpgan: bool
|
||||||
|
gfpgan_strength: float = 0
|
||||||
|
|
||||||
|
# Upscale
|
||||||
|
enable_upscale: bool
|
||||||
|
upscale: None
|
||||||
|
upscale_level: int = None
|
||||||
|
upscale_strength: float = 0.75
|
||||||
|
|
||||||
|
# Embiggen
|
||||||
|
enable_embiggen: bool
|
||||||
|
embiggen: Union[None, List[float]] = None
|
||||||
|
embiggen_tiles: Union[None, List[int]] = None
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
time: int
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.id = urlsafe_b64encode(uuid4().bytes).decode('ascii')
|
||||||
|
|
||||||
|
def parse_json(self, j, new_instance=False):
|
||||||
|
# Id
|
||||||
|
if 'id' in j and not new_instance:
|
||||||
|
self.id = j.get('id')
|
||||||
|
|
||||||
|
# Initial Image
|
||||||
|
self.enable_init_image = 'enable_init_image' in j and bool(j.get('enable_init_image'))
|
||||||
|
if self.enable_init_image:
|
||||||
|
self.initimg = j.get('initimg')
|
||||||
|
|
||||||
|
# Img2Img
|
||||||
|
self.enable_img2img = 'enable_img2img' in j and bool(j.get('enable_img2img'))
|
||||||
|
if self.enable_img2img:
|
||||||
|
self.strength = float(j.get('strength'))
|
||||||
|
self.fit = 'fit' in j
|
||||||
|
|
||||||
|
# Generation
|
||||||
|
self.enable_generate = 'enable_generate' in j and bool(j.get('enable_generate'))
|
||||||
|
if self.enable_generate:
|
||||||
|
self.prompt = j.get('prompt')
|
||||||
|
self.seed = int(j.get('seed'))
|
||||||
|
self.steps = int(j.get('steps'))
|
||||||
|
self.width = int(j.get('width'))
|
||||||
|
self.height = int(j.get('height'))
|
||||||
|
self.cfg_scale = float(j.get('cfgscale') or j.get('cfg_scale'))
|
||||||
|
self.sampler_name = j.get('sampler') or j.get('sampler_name')
|
||||||
|
# model: str = None # The model to use (currently unused)
|
||||||
|
# embeddings = None # The embeddings to use (currently unused)
|
||||||
|
self.seamless = 'seamless' in j
|
||||||
|
self.progress_images = 'progress_images' in j
|
||||||
|
|
||||||
|
# GFPGAN
|
||||||
|
self.enable_gfpgan = 'enable_gfpgan' in j and bool(j.get('enable_gfpgan'))
|
||||||
|
if self.enable_gfpgan:
|
||||||
|
self.gfpgan_strength = float(j.get('gfpgan_strength'))
|
||||||
|
|
||||||
|
# Upscale
|
||||||
|
self.enable_upscale = 'enable_upscale' in j and bool(j.get('enable_upscale'))
|
||||||
|
if self.enable_upscale:
|
||||||
|
self.upscale_level = j.get('upscale_level')
|
||||||
|
self.upscale_strength = j.get('upscale_strength')
|
||||||
|
self.upscale = None if self.upscale_level in {None,''} else [int(self.upscale_level),float(self.upscale_strength)]
|
||||||
|
|
||||||
|
# Embiggen
|
||||||
|
self.enable_embiggen = 'enable_embiggen' in j and bool(j.get('enable_embiggen'))
|
||||||
|
if self.enable_embiggen:
|
||||||
|
self.embiggen = j.get('embiggen')
|
||||||
|
self.embiggen_tiles = j.get('embiggen_tiles')
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
self.time = int(j.get('time')) if ('time' in j and not new_instance) else int(datetime.now(timezone.utc).timestamp())
|
||||||
|
|
||||||
|
|
||||||
|
class DreamResult(DreamBase):
|
||||||
|
# Result
|
||||||
|
has_upscaled: False
|
||||||
|
has_gfpgan: False
|
||||||
|
|
||||||
|
# TODO: use something else for state tracking
|
||||||
|
images_generated: int = 0
|
||||||
|
images_upscaled: int = 0
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def clone_without_img(self):
|
||||||
|
copy = deepcopy(self)
|
||||||
|
copy.initimg = None
|
||||||
|
return copy
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
copy = deepcopy(self)
|
||||||
|
copy.initimg = None
|
||||||
|
j = json.dumps(copy.__dict__)
|
||||||
|
return j
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(j, newTime: bool = False):
|
||||||
|
d = DreamResult()
|
||||||
|
d.parse_json(j)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: switch this to a pipelined request, with pluggable steps
|
||||||
|
# Will likely require generator code changes to accomplish
|
||||||
|
class JobRequest(DreamBase):
|
||||||
|
# Iteration
|
||||||
|
iterations: int = 1
|
||||||
|
variation_amount = None
|
||||||
|
with_variations = None
|
||||||
|
|
||||||
|
# Results
|
||||||
|
results: List[DreamResult] = []
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def newDreamResult(self) -> DreamResult:
|
||||||
|
result = DreamResult()
|
||||||
|
result.parse_json(self.__dict__, new_instance=True)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(j):
|
||||||
|
job = JobRequest()
|
||||||
|
job.parse_json(j)
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
job.time = int(j.get('time')) if ('time' in j) else int(datetime.now(timezone.utc).timestamp())
|
||||||
|
|
||||||
|
# Iteration
|
||||||
|
if job.enable_generate:
|
||||||
|
job.iterations = int(j.get('iterations'))
|
||||||
|
job.variation_amount = float(j.get('variation_amount'))
|
||||||
|
job.with_variations = j.get('with_variations')
|
||||||
|
|
||||||
|
return job
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressType(Enum):
|
||||||
|
GENERATION = 1
|
||||||
|
UPSCALING_STARTED = 2
|
||||||
|
UPSCALING_DONE = 3
|
||||||
|
|
||||||
|
class Signal():
|
||||||
|
event: str
|
||||||
|
data = None
|
||||||
|
room: str = None
|
||||||
|
broadcast: bool = False
|
||||||
|
|
||||||
|
def __init__(self, event: str, data, room: str = None, broadcast: bool = False):
|
||||||
|
self.event = event
|
||||||
|
self.data = data
|
||||||
|
self.room = room
|
||||||
|
self.broadcast = broadcast
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def image_progress(jobId: str, dreamId: str, step: int, totalSteps: int, progressType: ProgressType = ProgressType.GENERATION, hasProgressImage: bool = False):
|
||||||
|
return Signal('dream_progress', {
|
||||||
|
'jobId': jobId,
|
||||||
|
'dreamId': dreamId,
|
||||||
|
'step': step,
|
||||||
|
'totalSteps': totalSteps,
|
||||||
|
'hasProgressImage': hasProgressImage,
|
||||||
|
'progressType': progressType.name
|
||||||
|
}, room=jobId, broadcast=True)
|
||||||
|
|
||||||
|
# TODO: use a result id or something? Like a sub-job
|
||||||
|
@staticmethod
|
||||||
|
def image_result(jobId: str, dreamId: str, dreamResult: DreamResult):
|
||||||
|
return Signal('dream_result', {
|
||||||
|
'jobId': jobId,
|
||||||
|
'dreamId': dreamId,
|
||||||
|
'dreamRequest': dreamResult.clone_without_img().__dict__
|
||||||
|
}, room=jobId, broadcast=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def job_started(jobId: str):
|
||||||
|
return Signal('job_started', {
|
||||||
|
'jobId': jobId
|
||||||
|
}, room=jobId, broadcast=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def job_done(jobId: str):
|
||||||
|
return Signal('job_done', {
|
||||||
|
'jobId': jobId
|
||||||
|
}, room=jobId, broadcast=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def job_canceled(jobId: str):
|
||||||
|
return Signal('job_canceled', {
|
||||||
|
'jobId': jobId
|
||||||
|
}, room=jobId, broadcast=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedItems():
|
||||||
|
items: List[Any]
|
||||||
|
page: int # Current Page
|
||||||
|
pages: int # Total number of pages
|
||||||
|
per_page: int # Number of items per page
|
||||||
|
total: int # Total number of items in result
|
||||||
|
|
||||||
|
def __init__(self, items: List[Any], page: int, pages: int, per_page: int, total: int):
|
||||||
|
self.items = items
|
||||||
|
self.page = page
|
||||||
|
self.pages = pages
|
||||||
|
self.per_page = per_page
|
||||||
|
self.total = total
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
return json.dumps(self.__dict__)
|
389
server/services.py
Normal file
389
server/services.py
Normal file
@ -0,0 +1,389 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
import base64
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Empty, Queue
|
||||||
|
import shlex
|
||||||
|
from threading import Thread
|
||||||
|
import time
|
||||||
|
from flask_socketio import SocketIO, join_room, leave_room
|
||||||
|
from ldm.dream.args import Args
|
||||||
|
from ldm.dream.generator import embiggen
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ldm.dream.pngwriter import PngWriter
|
||||||
|
from ldm.dream.server import CanceledException
|
||||||
|
from ldm.generate import Generate
|
||||||
|
from server.models import DreamResult, JobRequest, PaginatedItems, ProgressType, Signal
|
||||||
|
|
||||||
|
class JobQueueService:
|
||||||
|
__queue: Queue = Queue()
|
||||||
|
|
||||||
|
def push(self, dreamRequest: DreamResult):
|
||||||
|
self.__queue.put(dreamRequest)
|
||||||
|
|
||||||
|
def get(self, timeout: float = None) -> DreamResult:
|
||||||
|
return self.__queue.get(timeout= timeout)
|
||||||
|
|
||||||
|
class SignalQueueService:
|
||||||
|
__queue: Queue = Queue()
|
||||||
|
|
||||||
|
def push(self, signal: Signal):
|
||||||
|
self.__queue.put(signal)
|
||||||
|
|
||||||
|
def get(self) -> Signal:
|
||||||
|
return self.__queue.get(block=False)
|
||||||
|
|
||||||
|
|
||||||
|
class SignalService:
|
||||||
|
__socketio: SocketIO
|
||||||
|
__queue: SignalQueueService
|
||||||
|
|
||||||
|
def __init__(self, socketio: SocketIO, queue: SignalQueueService):
|
||||||
|
self.__socketio = socketio
|
||||||
|
self.__queue = queue
|
||||||
|
|
||||||
|
def on_join(data):
|
||||||
|
room = data['room']
|
||||||
|
join_room(room)
|
||||||
|
self.__socketio.emit("test", "something", room=room)
|
||||||
|
|
||||||
|
def on_leave(data):
|
||||||
|
room = data['room']
|
||||||
|
leave_room(room)
|
||||||
|
|
||||||
|
self.__socketio.on_event('join_room', on_join)
|
||||||
|
self.__socketio.on_event('leave_room', on_leave)
|
||||||
|
|
||||||
|
self.__socketio.start_background_task(self.__process)
|
||||||
|
|
||||||
|
def __process(self):
|
||||||
|
# preload the model
|
||||||
|
print('Started signal queue processor')
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
signal = self.__queue.get()
|
||||||
|
self.__socketio.emit(signal.event, signal.data, room=signal.room, broadcast=signal.broadcast)
|
||||||
|
except Empty:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self.__socketio.sleep(0.001)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print('Signal queue processor stopped')
|
||||||
|
|
||||||
|
|
||||||
|
def emit(self, signal: Signal):
|
||||||
|
self.__queue.push(signal)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Name this better?
|
||||||
|
# TODO: Logging and signals should probably be event based (multiple listeners for an event)
|
||||||
|
class LogService:
|
||||||
|
__location: str
|
||||||
|
__logFile: str
|
||||||
|
|
||||||
|
def __init__(self, location:str, file:str):
|
||||||
|
self.__location = location
|
||||||
|
self.__logFile = file
|
||||||
|
|
||||||
|
def log(self, dreamResult: DreamResult, seed = None, upscaled = False):
|
||||||
|
with open(os.path.join(self.__location, self.__logFile), "a") as log:
|
||||||
|
log.write(f"{dreamResult.id}: {dreamResult.to_json()}\n")
|
||||||
|
|
||||||
|
|
||||||
|
class ImageStorageService:
|
||||||
|
__location: str
|
||||||
|
__pngWriter: PngWriter
|
||||||
|
__legacyParser: ArgumentParser
|
||||||
|
|
||||||
|
def __init__(self, location):
|
||||||
|
self.__location = location
|
||||||
|
self.__pngWriter = PngWriter(self.__location)
|
||||||
|
self.__legacyParser = Args() # TODO: inject this?
|
||||||
|
|
||||||
|
def __getName(self, dreamId: str, postfix: str = '') -> str:
|
||||||
|
return f'{dreamId}{postfix}.png'
|
||||||
|
|
||||||
|
def save(self, image, dreamResult: DreamResult, postfix: str = '') -> str:
|
||||||
|
name = self.__getName(dreamResult.id, postfix)
|
||||||
|
meta = dreamResult.to_json() # TODO: make all methods consistent with writing metadata. Standardize metadata.
|
||||||
|
path = self.__pngWriter.save_image_and_prompt_to_png(image, dream_prompt=meta, metadata=None, name=name)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def path(self, dreamId: str, postfix: str = '') -> str:
|
||||||
|
name = self.__getName(dreamId, postfix)
|
||||||
|
path = os.path.join(self.__location, name)
|
||||||
|
return path
|
||||||
|
|
||||||
|
# Returns true if found, false if not found or error
|
||||||
|
def delete(self, dreamId: str, postfix: str = '') -> bool:
|
||||||
|
path = self.path(dreamId, postfix)
|
||||||
|
if (os.path.exists(path)):
|
||||||
|
os.remove(path)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def getMetadata(self, dreamId: str, postfix: str = '') -> DreamResult:
|
||||||
|
path = self.path(dreamId, postfix)
|
||||||
|
image = Image.open(path)
|
||||||
|
text = image.text
|
||||||
|
if text.__contains__('Dream'):
|
||||||
|
dreamMeta = text.get('Dream')
|
||||||
|
try:
|
||||||
|
j = json.loads(dreamMeta)
|
||||||
|
return DreamResult.from_json(j)
|
||||||
|
except ValueError:
|
||||||
|
# Try to parse command-line format (legacy metadata format)
|
||||||
|
try:
|
||||||
|
opt = self.__parseLegacyMetadata(dreamMeta)
|
||||||
|
optd = opt.__dict__
|
||||||
|
if (not 'width' in optd) or (optd.get('width') is None):
|
||||||
|
optd['width'] = image.width
|
||||||
|
if (not 'height' in optd) or (optd.get('height') is None):
|
||||||
|
optd['height'] = image.height
|
||||||
|
if (not 'steps' in optd) or (optd.get('steps') is None):
|
||||||
|
optd['steps'] = 10 # No way around this unfortunately - seems like it wasn't storing this previously
|
||||||
|
|
||||||
|
optd['time'] = os.path.getmtime(path) # Set timestamp manually (won't be exactly correct though)
|
||||||
|
|
||||||
|
return DreamResult.from_json(optd)
|
||||||
|
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __parseLegacyMetadata(self, command: str) -> DreamResult:
|
||||||
|
# before splitting, escape single quotes so as not to mess
|
||||||
|
# up the parser
|
||||||
|
command = command.replace("'", "\\'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
elements = shlex.split(command)
|
||||||
|
except ValueError as e:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# rearrange the arguments to mimic how it works in the Dream bot.
|
||||||
|
switches = ['']
|
||||||
|
switches_started = False
|
||||||
|
|
||||||
|
for el in elements:
|
||||||
|
if el[0] == '-' and not switches_started:
|
||||||
|
switches_started = True
|
||||||
|
if switches_started:
|
||||||
|
switches.append(el)
|
||||||
|
else:
|
||||||
|
switches[0] += el
|
||||||
|
switches[0] += ' '
|
||||||
|
switches[0] = switches[0][: len(switches[0]) - 1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
opt = self.__legacyParser.parse_cmd(switches)
|
||||||
|
return opt
|
||||||
|
except SystemExit:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def list_files(self, page: int, perPage: int) -> PaginatedItems:
|
||||||
|
files = sorted(glob.glob(os.path.join(self.__location,'*.png')), key=os.path.getmtime, reverse=True)
|
||||||
|
count = len(files)
|
||||||
|
|
||||||
|
startId = page * perPage
|
||||||
|
pageCount = int(count / perPage) + 1
|
||||||
|
endId = min(startId + perPage, count)
|
||||||
|
items = [] if startId >= count else files[startId:endId]
|
||||||
|
|
||||||
|
items = list(map(lambda f: Path(f).stem, items))
|
||||||
|
|
||||||
|
return PaginatedItems(items, page, pageCount, perPage, count)
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorService:
|
||||||
|
__model: Generate
|
||||||
|
__queue: JobQueueService
|
||||||
|
__imageStorage: ImageStorageService
|
||||||
|
__intermediateStorage: ImageStorageService
|
||||||
|
__log: LogService
|
||||||
|
__thread: Thread
|
||||||
|
__cancellationRequested: bool = False
|
||||||
|
__signal_service: SignalService
|
||||||
|
|
||||||
|
def __init__(self, model: Generate, queue: JobQueueService, imageStorage: ImageStorageService, intermediateStorage: ImageStorageService, log: LogService, signal_service: SignalService):
|
||||||
|
self.__model = model
|
||||||
|
self.__queue = queue
|
||||||
|
self.__imageStorage = imageStorage
|
||||||
|
self.__intermediateStorage = intermediateStorage
|
||||||
|
self.__log = log
|
||||||
|
self.__signal_service = signal_service
|
||||||
|
|
||||||
|
# Create the background thread
|
||||||
|
self.__thread = Thread(target=self.__process, name = "GeneratorService")
|
||||||
|
self.__thread.daemon = True
|
||||||
|
self.__thread.start()
|
||||||
|
|
||||||
|
|
||||||
|
# Request cancellation of the current job
|
||||||
|
def cancel(self):
|
||||||
|
self.__cancellationRequested = True
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Consider moving this to its own service if there's benefit in separating the generator
|
||||||
|
def __process(self):
|
||||||
|
# preload the model
|
||||||
|
# TODO: support multiple models
|
||||||
|
print('Preloading model')
|
||||||
|
tic = time.time()
|
||||||
|
self.__model.load_model()
|
||||||
|
print(f'>> model loaded in', '%4.2fs' % (time.time() - tic))
|
||||||
|
|
||||||
|
print('Started generation queue processor')
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
dreamRequest = self.__queue.get()
|
||||||
|
self.__generate(dreamRequest)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print('Generation queue processor stopped')
|
||||||
|
|
||||||
|
|
||||||
|
def __on_start(self, jobRequest: JobRequest):
|
||||||
|
self.__signal_service.emit(Signal.job_started(jobRequest.id))
|
||||||
|
|
||||||
|
|
||||||
|
def __on_image_result(self, jobRequest: JobRequest, image, seed, upscaled=False):
|
||||||
|
dreamResult = jobRequest.newDreamResult()
|
||||||
|
dreamResult.seed = seed
|
||||||
|
dreamResult.has_upscaled = upscaled
|
||||||
|
dreamResult.iterations = 1
|
||||||
|
jobRequest.results.append(dreamResult)
|
||||||
|
# TODO: Separate status of GFPGAN?
|
||||||
|
|
||||||
|
self.__imageStorage.save(image, dreamResult)
|
||||||
|
|
||||||
|
# TODO: handle upscaling logic better (this is appending data to log, but only on first generation)
|
||||||
|
if not upscaled:
|
||||||
|
self.__log.log(dreamResult)
|
||||||
|
|
||||||
|
# Send result signal
|
||||||
|
self.__signal_service.emit(Signal.image_result(jobRequest.id, dreamResult.id, dreamResult))
|
||||||
|
|
||||||
|
upscaling_requested = dreamResult.enable_upscale or dreamResult.enable_gfpgan
|
||||||
|
|
||||||
|
# Report upscaling status
|
||||||
|
# TODO: this is very coupled to logic inside the generator. Fix that.
|
||||||
|
if upscaling_requested and any(result.has_upscaled for result in jobRequest.results):
|
||||||
|
progressType = ProgressType.UPSCALING_STARTED if len(jobRequest.results) < 2 * jobRequest.iterations else ProgressType.UPSCALING_DONE
|
||||||
|
upscale_count = sum(1 for i in jobRequest.results if i.has_upscaled)
|
||||||
|
self.__signal_service.emit(Signal.image_progress(jobRequest.id, dreamResult.id, upscale_count, jobRequest.iterations, progressType))
|
||||||
|
|
||||||
|
|
||||||
|
def __on_progress(self, jobRequest: JobRequest, sample, step):
|
||||||
|
if self.__cancellationRequested:
|
||||||
|
self.__cancellationRequested = False
|
||||||
|
raise CanceledException
|
||||||
|
|
||||||
|
# TODO: Progress per request will be easier once the seeds (and ids) can all be pre-generated
|
||||||
|
hasProgressImage = False
|
||||||
|
s = str(len(jobRequest.results))
|
||||||
|
if jobRequest.progress_images and step % 5 == 0 and step < jobRequest.steps - 1:
|
||||||
|
image = self.__model._sample_to_image(sample)
|
||||||
|
|
||||||
|
# TODO: clean this up, use a pre-defined dream result
|
||||||
|
result = DreamResult()
|
||||||
|
result.parse_json(jobRequest.__dict__, new_instance=False)
|
||||||
|
self.__intermediateStorage.save(image, result, postfix=f'.{s}.{step}')
|
||||||
|
hasProgressImage = True
|
||||||
|
|
||||||
|
self.__signal_service.emit(Signal.image_progress(jobRequest.id, f'{jobRequest.id}.{s}', step, jobRequest.steps, ProgressType.GENERATION, hasProgressImage))
|
||||||
|
|
||||||
|
|
||||||
|
def __generate(self, jobRequest: JobRequest):
|
||||||
|
try:
|
||||||
|
# TODO: handle this file a file service for init images
|
||||||
|
initimgfile = None # TODO: support this on the model directly?
|
||||||
|
if (jobRequest.enable_init_image):
|
||||||
|
if jobRequest.initimg is not None:
|
||||||
|
with open("./img2img-tmp.png", "wb") as f:
|
||||||
|
initimg = jobRequest.initimg.split(",")[1] # Ignore mime type
|
||||||
|
f.write(base64.b64decode(initimg))
|
||||||
|
initimgfile = "./img2img-tmp.png"
|
||||||
|
|
||||||
|
# Use previous seed if set to -1
|
||||||
|
initSeed = jobRequest.seed
|
||||||
|
if initSeed == -1:
|
||||||
|
initSeed = self.__model.seed
|
||||||
|
|
||||||
|
# Zero gfpgan strength if the model doesn't exist
|
||||||
|
# TODO: determine if this could be at the top now? Used to cause circular import
|
||||||
|
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
|
||||||
|
if not gfpgan_model_exists:
|
||||||
|
jobRequest.enable_gfpgan = False
|
||||||
|
|
||||||
|
# Signal start
|
||||||
|
self.__on_start(jobRequest)
|
||||||
|
|
||||||
|
# Generate in model
|
||||||
|
# TODO: Split job generation requests instead of fitting all parameters here
|
||||||
|
# TODO: Support no generation (just upscaling/gfpgan)
|
||||||
|
|
||||||
|
upscale = None if not jobRequest.enable_upscale else jobRequest.upscale
|
||||||
|
gfpgan_strength = 0 if not jobRequest.enable_gfpgan else jobRequest.gfpgan_strength
|
||||||
|
|
||||||
|
if not jobRequest.enable_generate:
|
||||||
|
# If not generating, check if we're upscaling or running gfpgan
|
||||||
|
if not upscale and not gfpgan_strength:
|
||||||
|
# Invalid settings (TODO: Add message to help user)
|
||||||
|
raise CanceledException()
|
||||||
|
|
||||||
|
image = Image.open(initimgfile)
|
||||||
|
# TODO: support progress for upscale?
|
||||||
|
self.__model.upscale_and_reconstruct(
|
||||||
|
image_list = [[image,0]],
|
||||||
|
upscale = upscale,
|
||||||
|
strength = gfpgan_strength,
|
||||||
|
save_original = False,
|
||||||
|
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Generating - run the generation
|
||||||
|
init_img = None if (not jobRequest.enable_img2img or jobRequest.strength == 0) else initimgfile
|
||||||
|
|
||||||
|
|
||||||
|
self.__model.prompt2image(
|
||||||
|
prompt = jobRequest.prompt,
|
||||||
|
init_img = init_img, # TODO: ensure this works
|
||||||
|
strength = None if init_img is None else jobRequest.strength,
|
||||||
|
fit = None if init_img is None else jobRequest.fit,
|
||||||
|
iterations = jobRequest.iterations,
|
||||||
|
cfg_scale = jobRequest.cfg_scale,
|
||||||
|
width = jobRequest.width,
|
||||||
|
height = jobRequest.height,
|
||||||
|
seed = jobRequest.seed,
|
||||||
|
steps = jobRequest.steps,
|
||||||
|
variation_amount = jobRequest.variation_amount,
|
||||||
|
with_variations = jobRequest.with_variations,
|
||||||
|
gfpgan_strength = gfpgan_strength,
|
||||||
|
upscale = upscale,
|
||||||
|
sampler_name = jobRequest.sampler_name,
|
||||||
|
seamless = jobRequest.seamless,
|
||||||
|
embiggen = jobRequest.embiggen,
|
||||||
|
embiggen_tiles = jobRequest.embiggen_tiles,
|
||||||
|
step_callback = lambda sample, step: self.__on_progress(jobRequest, sample, step),
|
||||||
|
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
|
||||||
|
|
||||||
|
except CanceledException:
|
||||||
|
self.__signal_service.emit(Signal.job_canceled(jobRequest.id))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.__signal_service.emit(Signal.job_done(jobRequest.id))
|
||||||
|
|
||||||
|
# Remove the temp file
|
||||||
|
if (initimgfile is not None):
|
||||||
|
os.remove("./img2img-tmp.png")
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user