Improves state/API code structure, formatting, etc

This commit is contained in:
psychedelicious 2022-09-18 17:33:09 +10:00
parent ef6609abcb
commit 6e927acd58
49 changed files with 2073 additions and 1401 deletions

View File

@ -7,6 +7,8 @@ import eventlet
import glob import glob
import shlex import shlex
import argparse import argparse
import math
import shutil
from flask_socketio import SocketIO from flask_socketio import SocketIO
from flask import Flask, send_from_directory, url_for, jsonify from flask import Flask, send_from_directory, url_for, jsonify
@ -119,15 +121,15 @@ result_path = os.path.join(output_dir, 'img-samples/')
intermediate_path = os.path.join(result_path, 'intermediates/') intermediate_path = os.path.join(result_path, 'intermediates/')
# path for user-uploaded init images and masks # path for user-uploaded init images and masks
init_path = os.path.join(result_path, 'init-images/') init_image_path = os.path.join(result_path, 'init-images/')
mask_path = os.path.join(result_path, 'mask-images/') mask_image_path = os.path.join(result_path, 'mask-images/')
# txt log # txt log
log_path = os.path.join(result_path, 'dream_log.txt') log_path = os.path.join(result_path, 'dream_log.txt')
# make all output paths # make all output paths
[os.makedirs(path, exist_ok=True) [os.makedirs(path, exist_ok=True)
for path in [result_path, intermediate_path, init_path, mask_path]] for path in [result_path, intermediate_path, init_image_path, mask_image_path]]
""" """
@ -155,7 +157,8 @@ def handle_request_all_images():
else: else:
metadata = all_metadata['sd-metadata'] metadata = all_metadata['sd-metadata']
image_array.append({'path': path, 'metadata': metadata}) image_array.append({'path': path, 'metadata': metadata})
return make_response("OK", data=image_array) socketio.emit('galleryImages', {'images': image_array})
eventlet.sleep(0)
@socketio.on('generateImage') @socketio.on('generateImage')
@ -166,16 +169,32 @@ def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan
esrgan_parameters, esrgan_parameters,
gfpgan_parameters gfpgan_parameters
) )
return make_response("OK")
@socketio.on('runESRGAN') @socketio.on('runESRGAN')
def handle_run_esrgan_event(original_image, esrgan_parameters): def handle_run_esrgan_event(original_image, esrgan_parameters):
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}') print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}')
progress = {
'currentStep': 1,
'totalSteps': 1,
'currentIteration': 1,
'totalIterations': 1,
'currentStatus': 'Preparing',
'isProcessing': True,
'currentStatusHasSteps': False
}
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = Image.open(original_image["url"]) image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed' seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
progress['currentStatus'] = 'Upscaling'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = real_esrgan_upscale( image = real_esrgan_upscale(
image=image, image=image,
upsampler_scale=esrgan_parameters['upscale'][0], upsampler_scale=esrgan_parameters['upscale'][0],
@ -183,24 +202,54 @@ def handle_run_esrgan_event(original_image, esrgan_parameters):
seed=seed seed=seed
) )
progress['currentStatus'] = 'Saving image'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
esrgan_parameters['seed'] = seed esrgan_parameters['seed'] = seed
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan') path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan')
command = parameters_to_command(esrgan_parameters) command = parameters_to_command(esrgan_parameters)
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}') write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
progress['currentStatus'] = 'Finished'
progress['currentStep'] = 0
progress['totalSteps'] = 0
progress['currentIteration'] = 0
progress['totalIterations'] = 0
progress['isProcessing'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
socketio.emit( socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'esrgan', 'uuid': original_image['uuid'],'metadata': esrgan_parameters}) 'esrganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': esrgan_parameters})
@socketio.on('runGFPGAN') @socketio.on('runGFPGAN')
def handle_run_gfpgan_event(original_image, gfpgan_parameters): def handle_run_gfpgan_event(original_image, gfpgan_parameters):
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}') print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}')
progress = {
'currentStep': 1,
'totalSteps': 1,
'currentIteration': 1,
'totalIterations': 1,
'currentStatus': 'Preparing',
'isProcessing': True,
'currentStatusHasSteps': False
}
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = Image.open(original_image["url"]) image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed' seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
progress['currentStatus'] = 'Fixing faces'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = run_gfpgan( image = run_gfpgan(
image=image, image=image,
strength=gfpgan_parameters['gfpgan_strength'], strength=gfpgan_parameters['gfpgan_strength'],
@ -208,29 +257,42 @@ def handle_run_gfpgan_event(original_image, gfpgan_parameters):
upsampler_scale=1 upsampler_scale=1
) )
progress['currentStatus'] = 'Saving image'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
gfpgan_parameters['seed'] = seed gfpgan_parameters['seed'] = seed
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan') path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan')
command = parameters_to_command(gfpgan_parameters) command = parameters_to_command(gfpgan_parameters)
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}') write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
progress['currentStatus'] = 'Finished'
progress['currentStep'] = 0
progress['totalSteps'] = 0
progress['currentIteration'] = 0
progress['totalIterations'] = 0
progress['isProcessing'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
socketio.emit( socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'gfpgan', 'uuid': original_image['uuid'],'metadata': gfpgan_parameters}) 'gfpganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': gfpgan_parameters})
@socketio.on('cancel') @socketio.on('cancel')
def handle_cancel(): def handle_cancel():
print(f'>> Cancel processing requested') print(f'>> Cancel processing requested')
canceled.set() canceled.set()
return make_response("OK") socketio.emit('processingCanceled')
# TODO: I think this needs a safety mechanism. # TODO: I think this needs a safety mechanism.
@socketio.on('deleteImage') @socketio.on('deleteImage')
def handle_delete_image(path): def handle_delete_image(path, uuid):
print(f'>> Delete requested "{path}"') print(f'>> Delete requested "{path}"')
send2trash(path) send2trash(path)
return make_response("OK") socketio.emit('imageDeleted', {'url': path, 'uuid': uuid})
# TODO: I think this needs a safety mechanism. # TODO: I think this needs a safety mechanism.
@ -240,11 +302,11 @@ def handle_upload_initial_image(bytes, name):
uuid = uuid4().hex uuid = uuid4().hex
split = os.path.splitext(name) split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}' name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(init_path, name) file_path = os.path.join(init_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb") newFile = open(file_path, "wb")
newFile.write(bytes) newFile.write(bytes)
return make_response("OK", data=file_path) socketio.emit('initialImageUploaded', {'url': file_path, 'uuid': ''})
# TODO: I think this needs a safety mechanism. # TODO: I think this needs a safety mechanism.
@ -254,11 +316,11 @@ def handle_upload_mask_image(bytes, name):
uuid = uuid4().hex uuid = uuid4().hex
split = os.path.splitext(name) split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}' name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(mask_path, name) file_path = os.path.join(mask_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb") newFile = open(file_path, "wb")
newFile.write(bytes) newFile.write(bytes)
return make_response("OK", data=file_path) socketio.emit('maskImageUploaded', {'url': file_path, 'uuid': ''})
@ -273,6 +335,13 @@ ADDITIONAL FUNCTIONS
""" """
def make_unique_init_image_filename(name):
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
return name
def write_log_message(message, log_path=log_path): def write_log_message(message, log_path=log_path):
"""Logs the filename and parameters used to generate or process that image to log file""" """Logs the filename and parameters used to generate or process that image to log file"""
message = f'{message}\n' message = f'{message}\n'
@ -280,15 +349,6 @@ def write_log_message(message, log_path=log_path):
file.writelines(message) file.writelines(message)
def make_response(status, message=None, data=None):
response = {'status': status}
if message is not None:
response['message'] = message
if data is not None:
response['data'] = data
return response
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False): def save_image(image, parameters, output_dir, step_index=None, postprocessing=False):
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed' seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed'
@ -310,16 +370,69 @@ def save_image(image, parameters, output_dir, step_index=None, postprocessing=Fa
return path return path
def calculate_real_steps(steps, strength, has_init_image):
return math.floor(strength * steps) if has_init_image else steps
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters): def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
canceled.clear() canceled.clear()
step_index = 1 step_index = 1
"""
If a result image is used as an init image, and then deleted, we will want to be
able to use it as an init image in the future. Need to copy it.
If the init/mask image doesn't exist in the init_image_path/mask_image_path,
make a unique filename for it and copy it there.
"""
if ('init_img' in generation_parameters):
filename = os.path.basename(generation_parameters['init_img'])
if not os.path.exists(os.path.join(init_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters['init_img'], new_path)
generation_parameters['init_img'] = new_path
if ('init_mask' in generation_parameters):
filename = os.path.basename(generation_parameters['init_mask'])
if not os.path.exists(os.path.join(mask_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters['init_img'], new_path)
generation_parameters['init_mask'] = new_path
totalSteps = calculate_real_steps(
steps=generation_parameters['steps'],
strength=generation_parameters['strength'] if 'strength' in generation_parameters else None,
has_init_image='init_img' in generation_parameters
)
progress = {
'currentStep': 1,
'totalSteps': totalSteps,
'currentIteration': 1,
'totalIterations': generation_parameters['iterations'],
'currentStatus': 'Preparing',
'isProcessing': True,
'currentStatusHasSteps': False
}
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
def image_progress(sample, step): def image_progress(sample, step):
if canceled.is_set(): if canceled.is_set():
raise CanceledException raise CanceledException
nonlocal step_index nonlocal step_index
nonlocal generation_parameters nonlocal generation_parameters
nonlocal progress
progress['currentStep'] = step + 1
progress['currentStatus'] = 'Generating'
progress['currentStatusHasSteps'] = True
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1: if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1:
image = model.sample_to_image(sample) image = model.sample_to_image(sample)
path = save_image(image, generation_parameters, intermediate_path, step_index) path = save_image(image, generation_parameters, intermediate_path, step_index)
@ -327,18 +440,30 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
step_index += 1 step_index += 1
socketio.emit('intermediateResult', { socketio.emit('intermediateResult', {
'url': os.path.relpath(path), 'metadata': generation_parameters}) 'url': os.path.relpath(path), 'metadata': generation_parameters})
socketio.emit('progress', {'step': step + 1}) socketio.emit('progressUpdate', progress)
eventlet.sleep(0) eventlet.sleep(0)
def image_done(image, seed): def image_done(image, seed):
nonlocal generation_parameters nonlocal generation_parameters
nonlocal esrgan_parameters nonlocal esrgan_parameters
nonlocal gfpgan_parameters nonlocal gfpgan_parameters
nonlocal progress
step_index = 1
progress['currentStatus'] = 'Generation complete'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
all_parameters = generation_parameters all_parameters = generation_parameters
postprocessing = False postprocessing = False
if esrgan_parameters: if esrgan_parameters:
progress['currentStatus'] = 'Upscaling'
progress['currentStatusHasSteps'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = real_esrgan_upscale( image = real_esrgan_upscale(
image=image, image=image,
strength=esrgan_parameters['strength'], strength=esrgan_parameters['strength'],
@ -349,6 +474,11 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']] all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']]
if gfpgan_parameters: if gfpgan_parameters:
progress['currentStatus'] = 'Fixing faces'
progress['currentStatusHasSteps'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = run_gfpgan( image = run_gfpgan(
image=image, image=image,
strength=gfpgan_parameters['strength'], strength=gfpgan_parameters['strength'],
@ -359,6 +489,9 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength'] all_parameters["gfpgan_strength"] = gfpgan_parameters['strength']
all_parameters['seed'] = seed all_parameters['seed'] = seed
progress['currentStatus'] = 'Saving image'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing) path = save_image(image, all_parameters, result_path, postprocessing=postprocessing)
command = parameters_to_command(all_parameters) command = parameters_to_command(all_parameters)
@ -366,8 +499,24 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
print(f'Image generated: "{path}"') print(f'Image generated: "{path}"')
write_log_message(f'[Generated] "{path}": {command}') write_log_message(f'[Generated] "{path}": {command}')
if (progress['totalIterations'] > progress['currentIteration']):
progress['currentStep'] = 1
progress['currentIteration'] +=1
progress['currentStatus'] = 'Iteration finished'
progress['currentStatusHasSteps'] = False
else:
progress['currentStep'] = 0
progress['totalSteps'] = 0
progress['currentIteration'] = 0
progress['totalIterations'] = 0
progress['currentStatus'] = 'Finished'
progress['isProcessing'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
socketio.emit( socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'generation', 'metadata': all_parameters}) 'generationResult', {'url': os.path.relpath(path), 'metadata': all_parameters})
eventlet.sleep(0) eventlet.sleep(0)
try: try:
@ -382,7 +531,7 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
except CanceledException: except CanceledException:
pass pass
except Exception as e: except Exception as e:
socketio.emit('error', (str(e))) socketio.emit('error', {'message': (str(e))})
print("\n") print("\n")
traceback.print_exc() traceback.print_exc()
print("\n") print("\n")

View File

@ -1,17 +1,29 @@
# Stable Diffusion Web UI # Stable Diffusion Web UI
## Run ## Run
- `python backend/server.py` serves both frontend and backend at http://localhost:9090 - `python backend/server.py` serves both frontend and backend at http://localhost:9090
## Evironment
Install [node](https://nodejs.org/en/download/) (includes npm) and optionally
[yarn](https://yarnpkg.com/getting-started/install).
From `frontend/` run `npm install` / `yarn install` to install the frontend packages.
## Dev ## Dev
1. From `frontend/`, run `yarn dev` to start the dev server. 1. From `frontend/`, run `npm dev` / `yarn dev` to start the dev server.
2. Note the address it starts up on (probably `http://localhost:5173/`). 2. Note the address it starts up on (probably `http://localhost:5173/`).
3. Edit `backend/server.py`'s `additional_allowed_origins` to include this address, e.g. `additional_allowed_origins = ['http://localhost:5173']`. 3. Edit `backend/server.py`'s `additional_allowed_origins` to include this address, e.g.
`additional_allowed_origins = ['http://localhost:5173']`.
4. Leaving the dev server running, open a new terminal and go to the project root. 4. Leaving the dev server running, open a new terminal and go to the project root.
5. Run `python backend/server.py`. 5. Run `python backend/server.py`.
6. Navigate to the dev server address e.g. `http://localhost:5173/`. 6. Navigate to the dev server address e.g. `http://localhost:5173/`.
To build for dev: `npm build-dev` / `yarn build-dev`
To build for production: `npm build` / `yarn build`
## TODO ## TODO
@ -20,7 +32,6 @@
`framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on `framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on
the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on
this. Need to check in on this issue periodically. this. Need to check in on this issue periodically.
- More status info e.g. phase of processing, image we are on of the total count, etc
- Mobile friendly layout - Mobile friendly layout
- Proper image gallery/viewer/manager - Proper image gallery/viewer/manager
- Help tooltips and such - Help tooltips and such

File diff suppressed because one or more lines are too long

694
frontend/dist/assets/index.de730902.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -4,7 +4,7 @@
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title> <title>Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.3d2e59c5.js"></script> <script type="module" crossorigin src="/assets/index.de730902.js"></script>
<link rel="stylesheet" href="/assets/index.447eb2a9.css"> <link rel="stylesheet" href="/assets/index.447eb2a9.css">
</head> </head>
<body> <body>

View File

@ -1,25 +1,28 @@
import { Grid, GridItem } from '@chakra-ui/react'; import { Grid, GridItem } from '@chakra-ui/react';
import CurrentImageDisplay from './features/gallery/CurrentImageDisplay'; import { useEffect, useState } from 'react';
import LogViewer from './features/system/LogViewer'; import CurrentImageDisplay from '../features/gallery/CurrentImageDisplay';
import PromptInput from './features/sd/PromptInput'; import ImageGallery from '../features/gallery/ImageGallery';
import ProgressBar from './features/header/ProgressBar'; import ProgressBar from '../features/header/ProgressBar';
import { useEffect } from 'react'; import SiteHeader from '../features/header/SiteHeader';
import { useAppDispatch } from './app/hooks'; import OptionsAccordion from '../features/sd/OptionsAccordion';
import { requestAllImages } from './app/socketio'; import ProcessButtons from '../features/sd/ProcessButtons';
import ProcessButtons from './features/sd/ProcessButtons'; import PromptInput from '../features/sd/PromptInput';
import ImageGallery from './features/gallery/ImageGallery'; import LogViewer from '../features/system/LogViewer';
import SiteHeader from './features/header/SiteHeader'; import Loading from '../Loading';
import OptionsAccordion from './features/sd/OptionsAccordion'; import { useAppDispatch } from './store';
import { requestAllImages } from './socketio/actions';
const App = () => { const App = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [isReady, setIsReady] = useState<boolean>(false);
// Load images from the gallery once // Load images from the gallery once
useEffect(() => { useEffect(() => {
dispatch(requestAllImages()); dispatch(requestAllImages());
setIsReady(true);
}, [dispatch]); }, [dispatch]);
return ( return isReady ? (
<> <>
<Grid <Grid
width="100vw" width="100vw"
@ -57,6 +60,8 @@ const App = () => {
</Grid> </Grid>
<LogViewer /> <LogViewer />
</> </>
) : (
<Loading />
); );
}; };

View File

@ -1,7 +0,0 @@
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import type { RootState, AppDispatch } from './store';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -1,393 +0,0 @@
import { createAction, Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import {
addImage,
clearIntermediateImage,
removeImage,
SDImage,
SDMetadata,
setGalleryImages,
setIntermediateImage,
} from '../features/gallery/gallerySlice';
import {
addLogEntry,
setCurrentStep,
setIsConnected,
setIsProcessing,
} from '../features/system/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { setInitialImagePath, setMaskPath } from '../features/sd/sdSlice';
import {
backendToFrontendParameters,
frontendToBackendParameters,
} from './parameterTranslation';
export interface SocketIOResponse {
status: 'OK' | 'ERROR';
message?: string;
data?: any;
}
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const { dispatch, getState } = store;
if (!areListenersSet) {
// CONNECT
socketio.on('connect', () => {
try {
dispatch(setIsConnected(true));
} catch (e) {
console.error(e);
}
});
// DISCONNECT
socketio.on('disconnect', () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(addLogEntry(`Disconnected from server`));
} catch (e) {
console.error(e);
}
});
// PROCESSING RESULT
socketio.on(
'result',
(data: {
url: string;
type: 'generation' | 'esrgan' | 'gfpgan';
uuid?: string;
metadata: { [key: string]: any };
}) => {
try {
const newUuid = uuidv4();
const { type, url, uuid, metadata } = data;
switch (type) {
case 'generation': {
const translatedMetadata =
backendToFrontendParameters(metadata);
dispatch(
addImage({
uuid: newUuid,
url,
metadata: translatedMetadata,
})
);
dispatch(
addLogEntry(`Image generated: ${url}`)
);
break;
}
case 'esrgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunESRGAN = true;
newMetadata.upscalingLevel =
metadata.upscale[0];
newMetadata.upscalingStrength =
metadata.upscale[1];
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`ESRGAN upscaled: ${url}`)
);
break;
}
case 'gfpgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunGFPGAN = true;
newMetadata.gfpganStrength =
metadata.gfpgan_strength;
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`GFPGAN fixed faces: ${url}`)
);
break;
}
}
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
}
);
// PROGRESS UPDATE
socketio.on('progress', (data: { step: number }) => {
try {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(data.step));
} catch (e) {
console.error(e);
}
});
// INTERMEDIATE IMAGE
socketio.on(
'intermediateResult',
(data: { url: string; metadata: SDMetadata }) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry(`Intermediate image generated: ${url}`)
);
} catch (e) {
console.error(e);
}
}
);
// ERROR FROM BACKEND
socketio.on('error', (message) => {
try {
dispatch(addLogEntry(`Server error: ${message}`));
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
});
areListenersSet = true;
}
// HANDLE ACTIONS
switch (action.type) {
// GENERATE IMAGE
case 'socketio/generateImage': {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const {
generationParameters,
esrganParameters,
gfpganParameters,
} = frontendToBackendParameters(
getState().sd,
getState().system
);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry(
`Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`
)
);
break;
}
// RUN ESRGAN (UPSCALING)
case 'socketio/runESRGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { upscalingLevel, upscalingStrength } = getState().sd;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry(
`ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`
)
);
break;
}
// RUN GFPGAN (FIX FACES)
case 'socketio/runGFPGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { gfpganStrength } = getState().sd;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry(
`GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`
)
);
break;
}
// DELETE IMAGE
case 'socketio/deleteImage': {
const imageToDelete = action.payload;
const { url } = imageToDelete;
socketio.emit(
'deleteImage',
url,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(removeImage(imageToDelete));
dispatch(addLogEntry(`Image deleted: ${url}`));
}
}
);
break;
}
// GET ALL IMAGES FOR GALLERY
case 'socketio/requestAllImages': {
socketio.emit(
'requestAllImages',
(response: SocketIOResponse) => {
dispatch(setGalleryImages(response.data));
dispatch(
addLogEntry(`Loaded ${response.data.length} images`)
);
}
);
break;
}
// CANCEL PROCESSING
case 'socketio/cancelProcessing': {
socketio.emit('cancel', (response: SocketIOResponse) => {
const { intermediateImage } = getState().gallery;
if (response.status === 'OK') {
dispatch(setIsProcessing(false));
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry(
`Intermediate image saved: ${intermediateImage.url}`
)
);
dispatch(clearIntermediateImage());
}
dispatch(addLogEntry(`Processing canceled`));
}
});
break;
}
// UPLOAD INITIAL IMAGE
case 'socketio/uploadInitialImage': {
const file = action.payload;
socketio.emit(
'uploadInitialImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setInitialImagePath(response.data));
dispatch(
addLogEntry(
`Initial image uploaded: ${response.data}`
)
);
}
}
);
break;
}
// UPLOAD MASK IMAGE
case 'socketio/uploadMaskImage': {
const file = action.payload;
socketio.emit(
'uploadMaskImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setMaskPath(response.data));
dispatch(
addLogEntry(
`Mask image uploaded: ${response.data}`
)
);
}
}
);
break;
}
}
next(action);
};
return middleware;
};
// Actions to be used by app
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');

View File

@ -0,0 +1,24 @@
import { createAction } from '@reduxjs/toolkit';
import { SDImage } from '../../features/gallery/gallerySlice';
/**
* We can't use redux-toolkit's createSlice() to make these actions,
* because they have no associated reducer. They only exist to dispatch
* requests to the server via socketio. These actions will be handled
* by the middleware.
*/
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');

View File

@ -0,0 +1,101 @@
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { Socket } from 'socket.io-client';
import { frontendToBackendParameters } from '../../common/util/parameterTranslation';
import { SDImage } from '../../features/gallery/gallerySlice';
import {
addLogEntry,
setIsProcessing,
} from '../../features/system/systemSlice';
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, any>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
return {
emitGenerateImage: () => {
dispatch(setIsProcessing(true));
const { generationParameters, esrganParameters, gfpganParameters } =
frontendToBackendParameters(getState().sd, getState().system);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: SDImage) => {
dispatch(setIsProcessing(true));
const { upscalingLevel, upscalingStrength } = getState().sd;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunGFPGAN: (imageToProcess: SDImage) => {
dispatch(setIsProcessing(true));
const { gfpganStrength } = getState().sd;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`,
})
);
},
emitDeleteImage: (imageToDelete: SDImage) => {
const { url, uuid } = imageToDelete;
socketio.emit('deleteImage', url, uuid);
},
emitRequestAllImages: () => {
socketio.emit('requestAllImages');
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitUploadInitialImage: (file: File) => {
socketio.emit('uploadInitialImage', file, file.name);
},
emitUploadMaskImage: (file: File) => {
socketio.emit('uploadMaskImage', file, file.name);
},
};
};
export default makeSocketIOEmitters;

View File

@ -0,0 +1,338 @@
import { AnyAction, MiddlewareAPI, Dispatch } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import dateFormat from 'dateformat';
import {
addLogEntry,
setIsConnected,
setIsProcessing,
SystemStatus,
setSystemStatus,
setCurrentStatus,
} from '../../features/system/systemSlice';
import type {
ServerGenerationResult,
ServerESRGANResult,
ServerGFPGANResult,
ServerIntermediateResult,
ServerError,
ServerGalleryImages,
ServerImageUrlAndUuid,
ServerImageUrl,
} from './types';
import { backendToFrontendParameters } from '../../common/util/parameterTranslation';
import {
addImage,
clearIntermediateImage,
removeImage,
SDImage,
setGalleryImages,
setIntermediateImage,
} from '../../features/gallery/gallerySlice';
import { setInitialImagePath, setMaskPath } from '../../features/sd/sdSlice';
/**
* Returns an object containing listener callbacks for socketio events.
* TODO: This file is large, but simple. Should it be split up further?
*/
const makeSocketIOListeners = (
store: MiddlewareAPI<Dispatch<AnyAction>, any>
) => {
const { dispatch, getState } = store;
return {
/**
* Callback to run when we receive a 'connect' event.
*/
onConnect: () => {
try {
dispatch(setIsConnected(true));
dispatch(setCurrentStatus('Connected'));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'disconnect' event.
*/
onDisconnect: () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(setCurrentStatus('Disconnected'));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Disconnected from server`,
level: 'warning',
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'generationResult' event.
*/
onGenerationResult: (data: ServerGenerationResult) => {
try {
const { url, metadata } = data;
const newUuid = uuidv4();
const translatedMetadata = backendToFrontendParameters(metadata);
dispatch(
addImage({
uuid: newUuid,
url,
metadata: translatedMetadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generated: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'intermediateResult' event.
*/
onIntermediateResult: (data: ServerIntermediateResult) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image generated: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive an 'esrganResult' event.
*/
onESRGANResult: (data: ServerESRGANResult) => {
try {
const { url, uuid, metadata } = data;
const newUuid = uuidv4();
// This image was only ESRGAN'd, grab the original image's metadata
const originalImage = getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
// Retain the original metadata
const newMetadata = {
...originalImage.metadata,
};
// Update the ESRGAN-related fields
newMetadata.shouldRunESRGAN = true;
newMetadata.upscalingLevel = metadata.upscale[0];
newMetadata.upscalingStrength = metadata.upscale[1];
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Upscaled: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'gfpganResult' event.
*/
onGFPGANResult: (data: ServerGFPGANResult) => {
try {
const { url, uuid, metadata } = data;
const newUuid = uuidv4();
// This image was only GFPGAN'd, grab the original image's metadata
const originalImage = getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
// Retain the original metadata
const newMetadata = {
...originalImage.metadata,
};
// Update the GFPGAN-related fields
newMetadata.shouldRunGFPGAN = true;
newMetadata.gfpganStrength = metadata.gfpgan_strength;
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Fixed faces: ${url}`,
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
* TODO: Add additional progress phases
*/
onProgressUpdate: (data: SystemStatus) => {
try {
dispatch(setIsProcessing(true));
dispatch(setSystemStatus(data));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
*/
onError: (data: ServerError) => {
const { message, additionalData } = data;
if (additionalData) {
// TODO: handle more data than short message
}
try {
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Server error: ${message}`,
level: 'error',
})
);
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'galleryImages' event.
*/
onGalleryImages: (data: ServerGalleryImages) => {
const { images } = data;
const preparedImages = images.map((image): SDImage => {
return {
uuid: uuidv4(),
url: image.path,
metadata: backendToFrontendParameters(image.metadata),
};
});
dispatch(setGalleryImages(preparedImages));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Loaded ${images.length} images`,
})
);
},
/**
* Callback to run when we receive a 'processingCanceled' event.
*/
onProcessingCanceled: () => {
dispatch(setIsProcessing(false));
const { intermediateImage } = getState().gallery;
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image saved: ${intermediateImage.url}`,
})
);
dispatch(clearIntermediateImage());
}
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Processing canceled`,
level: 'warning',
})
);
},
/**
* Callback to run when we receive a 'imageDeleted' event.
*/
onImageDeleted: (data: ServerImageUrlAndUuid) => {
const { url, uuid } = data;
dispatch(removeImage(uuid));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image deleted: ${url}`,
})
);
},
/**
* Callback to run when we receive a 'initialImageUploaded' event.
*/
onInitialImageUploaded: (data: ServerImageUrl) => {
const { url } = data;
dispatch(setInitialImagePath(url));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Initial image uploaded: ${url}`,
})
);
},
/**
* Callback to run when we receive a 'maskImageUploaded' event.
*/
onMaskImageUploaded: (data: ServerImageUrl) => {
const { url } = data;
dispatch(setMaskPath(url));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Mask image uploaded: ${url}`,
})
);
},
};
};
export default makeSocketIOListeners;

View File

@ -0,0 +1,157 @@
import { Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import makeSocketIOListeners from './listeners';
import makeSocketIOEmitters from './emitters';
import type {
ServerGenerationResult,
ServerESRGANResult,
ServerGFPGANResult,
ServerIntermediateResult,
ServerError,
ServerGalleryImages,
ServerImageUrlAndUuid,
ServerImageUrl,
} from './types';
import { SystemStatus } from '../../features/system/systemSlice';
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const {
onConnect,
onDisconnect,
onError,
onESRGANResult,
onGFPGANResult,
onGenerationResult,
onIntermediateResult,
onProgressUpdate,
onGalleryImages,
onProcessingCanceled,
onImageDeleted,
onInitialImageUploaded,
onMaskImageUploaded,
} = makeSocketIOListeners(store);
const {
emitGenerateImage,
emitRunESRGAN,
emitRunGFPGAN,
emitDeleteImage,
emitRequestAllImages,
emitCancelProcessing,
emitUploadInitialImage,
emitUploadMaskImage,
} = makeSocketIOEmitters(store, socketio);
/**
* If this is the first time the middleware has been called (e.g. during store setup),
* initialize all our socket.io listeners.
*/
if (!areListenersSet) {
socketio.on('connect', () => onConnect());
socketio.on('disconnect', () => onDisconnect());
socketio.on('error', (data: ServerError) => onError(data));
socketio.on('generationResult', (data: ServerGenerationResult) =>
onGenerationResult(data)
);
socketio.on('esrganResult', (data: ServerESRGANResult) =>
onESRGANResult(data)
);
socketio.on('gfpganResult', (data: ServerGFPGANResult) =>
onGFPGANResult(data)
);
socketio.on('intermediateResult', (data: ServerIntermediateResult) =>
onIntermediateResult(data)
);
socketio.on('progressUpdate', (data: SystemStatus) =>
onProgressUpdate(data)
);
socketio.on('galleryImages', (data: ServerGalleryImages) =>
onGalleryImages(data)
);
socketio.on('processingCanceled', () => {
onProcessingCanceled();
});
socketio.on('imageDeleted', (data: ServerImageUrlAndUuid) => {
onImageDeleted(data);
});
socketio.on('initialImageUploaded', (data: ServerImageUrl) => {
onInitialImageUploaded(data);
});
socketio.on('maskImageUploaded', (data: ServerImageUrl) => {
onMaskImageUploaded(data);
});
areListenersSet = true;
}
/**
* Handle redux actions caught by middleware.
*/
switch (action.type) {
case 'socketio/generateImage': {
emitGenerateImage();
break;
}
case 'socketio/runESRGAN': {
emitRunESRGAN(action.payload);
break;
}
case 'socketio/runGFPGAN': {
emitRunGFPGAN(action.payload);
break;
}
case 'socketio/deleteImage': {
emitDeleteImage(action.payload);
break;
}
case 'socketio/requestAllImages': {
emitRequestAllImages();
break;
}
case 'socketio/cancelProcessing': {
emitCancelProcessing();
break;
}
case 'socketio/uploadInitialImage': {
emitUploadInitialImage(action.payload);
break;
}
case 'socketio/uploadMaskImage': {
emitUploadMaskImage(action.payload);
break;
}
}
next(action);
};
return middleware;
};

46
frontend/src/app/socketio/types.d.ts vendored Normal file
View File

@ -0,0 +1,46 @@
/**
* Interfaces used by the socketio middleware.
*/
export declare interface ServerGenerationResult {
url: string;
metadata: { [key: string]: any };
}
export declare interface ServerESRGANResult {
url: string;
uuid: string;
metadata: { [key: string]: any };
}
export declare interface ServerGFPGANResult {
url: string;
uuid: string;
metadata: { [key: string]: any };
}
export declare interface ServerIntermediateResult {
url: string;
metadata: { [key: string]: any };
}
export declare interface ServerError {
message: string;
additionalData?: string;
}
export declare interface ServerGalleryImages {
images: Array<{
path: string;
metadata: { [key: string]: any };
}>;
}
export declare interface ServerImageUrlAndUuid {
uuid: string;
url: string;
}
export declare interface ServerImageUrl {
url: string;
}

View File

@ -1,53 +1,78 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit'; import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import { persistReducer } from 'redux-persist'; import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import sdReducer from '../features/sd/sdSlice'; import sdReducer from '../features/sd/sdSlice';
import galleryReducer from '../features/gallery/gallerySlice'; import galleryReducer from '../features/gallery/gallerySlice';
import systemReducer from '../features/system/systemSlice'; import systemReducer from '../features/system/systemSlice';
import { socketioMiddleware } from './socketio'; import { socketioMiddleware } from './socketio/middleware';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
*
* While we definitely want generation parameters to be persisted, there are a number
* of things we do *not* want to be persisted across reloads:
* - Gallery/selected image (user may add/delete images from disk between page loads)
* - Connection/processing status
* - Availability of external libraries like ESRGAN/GFPGAN
*
* These can be blacklisted in redux-persist.
*
* The necesssary nested persistors with blacklists are configured below.
*
* TODO: Do we blacklist initialImagePath? If the image is deleted from disk we get an
* ugly 404. But if we blacklist it, then this is a valuable parameter that is lost
* on reload. Need to figure out a good way to handle this.
*/
const rootPersistConfig = {
key: 'root',
storage,
blacklist: ['gallery', 'system'],
};
const systemPersistConfig = {
key: 'system',
storage,
blacklist: [
'isConnected',
'isProcessing',
'currentStep',
'socketId',
'isESRGANAvailable',
'isGFPGANAvailable',
'currentStep',
'totalSteps',
'currentIteration',
'totalIterations',
'currentStatus',
],
};
const reducers = combineReducers({ const reducers = combineReducers({
sd: sdReducer, sd: sdReducer,
gallery: galleryReducer, gallery: galleryReducer,
system: systemReducer, system: persistReducer(systemPersistConfig, systemReducer),
}); });
const persistConfig = { const persistedReducer = persistReducer(rootPersistConfig, reducers);
key: 'root',
storage,
};
const persistedReducer = persistReducer(persistConfig, reducers);
/*
The frontend needs to be distributed as a production build, so
we cannot reasonably ask users to edit the JS and specify the
host and port on which the socket.io server will run.
The solution is to allow server script to be run with arguments
(or just edited) providing the host and port. Then, the server
serves a route `/socketio_config` which responds with the host
and port.
When the frontend loads, it synchronously requests that route
and thus gets the host and port. This requires a suspicious
fetch somewhere, and the store setup seems like as good a place
as any to make this fetch request.
*/
// Continue with store setup // Continue with store setup
export const store = configureStore({ export const store = configureStore({
reducer: persistedReducer, reducer: persistedReducer,
middleware: (getDefaultMiddleware) => middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({ getDefaultMiddleware({
// redux-persist sometimes needs to have a function in redux, need to disable this check // redux-persist sometimes needs to temporarily put a function in redux state, need to disable this check
serializableCheck: false, serializableCheck: false,
}).concat(socketioMiddleware()), }).concat(socketioMiddleware()),
}); });
// Infer the `RootState` and `AppDispatch` types from the store itself
export type RootState = ReturnType<typeof store.getState>; export type RootState = ReturnType<typeof store.getState>;
// Inferred type: {posts: PostsState, comments: CommentsState, users: UsersState}
export type AppDispatch = typeof store.dispatch; export type AppDispatch = typeof store.dispatch;
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -0,0 +1,171 @@
/**
* Defines common parameters required to generate an image.
* See #266 for the eventual maturation of this interface.
*/
interface CommonParameters {
/**
* The "txt2img" prompt. String. Minimum one character. No maximum.
*/
prompt: string;
/**
* The number of sampler steps. Integer. Minimum value 1. No maximum.
*/
steps: number;
/**
* Classifier-free guidance scale. Float. Minimum value 0. Maximum?
*/
cfgScale: number;
/**
* Height of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
*/
height: number;
/**
* Width of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
*/
width: number;
/**
* Name of the sampler to use. String. Restricted values.
*/
sampler:
| 'ddim'
| 'plms'
| 'k_lms'
| 'k_dpm_2'
| 'k_dpm_2_a'
| 'k_euler'
| 'k_euler_a'
| 'k_heun';
/**
* Seed used for randomness. Integer. 0 --> 4294967295, inclusive.
*/
seed: number;
/**
* Flag to enable seamless tiling image generation. Boolean.
*/
seamless: boolean;
}
/**
* Defines parameters needed to use the "img2img" generation method.
*/
interface ImageToImageParameters {
/**
* Folder path to the image used as the initial image. String.
*/
initialImagePath: string;
/**
* Flag to enable the use of a mask image during "img2img" generations.
* Requires valid ImageToImageParameters. Boolean.
*/
shouldUseMaskImage: boolean;
/**
* Folder path to the image used as a mask image. String.
*/
maskImagePath: string;
/**
* Strength of adherance to initial image. Float. 0 --> 1, exclusive.
*/
img2imgStrength: number;
/**
* Flag to enable the stretching of init image to desired output. Boolean.
*/
shouldFit: boolean;
}
/**
* Defines the parameters needed to generate variations.
*/
interface VariationParameters {
/**
* Variation amount. Float. 0 --> 1, exclusive.
* TODO: What does this really do?
*/
variationAmount: number;
/**
* List of seed-weight pairs formatted as "seed:weight,...".
* Seed is a valid seed. Weight is a float, 0 --> 1, exclusive.
* String, must be parseable into [[seed,weight],...] format.
*/
seedWeights: string;
}
/**
* Defines the parameters needed to use GFPGAN postprocessing.
*/
interface GFPGANParameters {
/**
* GFPGAN strength. Strength to apply face-fixing processing. Float. 0 --> 1, exclusive.
*/
gfpganStrength: number;
}
/**
* Defines the parameters needed to use ESRGAN postprocessing.
*/
interface ESRGANParameters {
/**
* ESRGAN strength. Strength to apply upscaling. Float. 0 --> 1, exclusive.
*/
esrganStrength: number;
/**
* ESRGAN upscaling scale. One of 2x | 4x. Represented as integer.
*/
esrganScale: 2 | 4;
}
/**
* Extends the generation and processing method parameters, adding flags to enable each.
*/
interface ProcessingParameters extends CommonParameters {
/**
* Flag to enable the generation of variations. Requires valid VariationParameters. Boolean.
*/
shouldGenerateVariations: boolean;
/**
* Variation parameters.
*/
variationParameters: VariationParameters;
/**
* Flag to enable the use of an initial image, i.e. to use "img2img" generation.
* Requires valid ImageToImageParameters. Boolean.
*/
shouldUseImageToImage: boolean;
/**
* ImageToImage parameters.
*/
imageToImageParameters: ImageToImageParameters;
/**
* Flag to enable GFPGAN postprocessing. Requires valid GFPGANParameters. Boolean.
*/
shouldRunGFPGAN: boolean;
/**
* GFPGAN parameters.
*/
gfpganParameters: GFPGANParameters;
/**
* Flag to enable ESRGAN postprocessing. Requires valid ESRGANParameters. Boolean.
*/
shouldRunESRGAN: boolean;
/**
* ESRGAN parameters.
*/
esrganParameters: GFPGANParameters;
}
/**
* Extends ProcessingParameters, adding items needed to request processing.
*/
interface ProcessingState extends ProcessingParameters {
/**
* Number of images to generate. Integer. Minimum 1.
*/
iterations: number;
/**
* Flag to enable the randomization of the seed on each generation. Boolean.
*/
shouldRandomizeSeed: boolean;
}
export {}

View File

@ -1,11 +1,11 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useAppSelector } from '../../app/hooks'; import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice'; import { SDState } from '../../features/sd/sdSlice';
import { validateSeedWeights } from '../sd/util/seedWeightPairs'; import { SystemState } from '../../features/system/systemSlice';
import { SystemState } from './systemSlice'; import { validateSeedWeights } from '../util/seedWeightPairs';
const sdSelector = createSelector( const sdSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.sd,

View File

@ -1,17 +1,15 @@
import { SDState } from '../features/sd/sdSlice';
import randomInt from '../features/sd/util/randomInt';
import {
seedWeightsToString,
stringToSeedWeights,
} from '../features/sd/util/seedWeightPairs';
import { SystemState } from '../features/system/systemSlice';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from './constants';
/* /*
These functions translate frontend state into parameters These functions translate frontend state into parameters
suitable for consumption by the backend, and vice-versa. suitable for consumption by the backend, and vice-versa.
*/ */
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from "../../app/constants";
import { SDState } from "../../features/sd/sdSlice";
import { SystemState } from "../../features/system/systemSlice";
import randomInt from "./randomInt";
import { seedWeightsToString, stringToSeedWeights } from "./seedWeightPairs";
export const frontendToBackendParameters = ( export const frontendToBackendParameters = (
sdState: SDState, sdState: SDState,
systemState: SystemState systemState: SystemState
@ -77,7 +75,7 @@ export const frontendToBackendParameters = (
stringToSeedWeights(seedWeights); stringToSeedWeights(seedWeights);
} }
} else { } else {
generationParameters.variation_amount = 0.1; generationParameters.variation_amount = 0;
} }
let esrganParameters: false | { [k: string]: any } = false; let esrganParameters: false | { [k: string]: any } = false;
@ -96,6 +94,8 @@ export const frontendToBackendParameters = (
}; };
} }
console.log(generationParameters)
return { return {
generationParameters, generationParameters,
esrganParameters, esrganParameters,

View File

@ -1,14 +1,14 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice'; import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
import DeleteImageModal from './DeleteImageModal'; import DeleteImageModal from './DeleteImageModal';
import SDButton from '../../components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { SystemState } from '../system/systemSlice'; import { SystemState } from '../system/systemSlice';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { SDImage } from './gallerySlice'; import { SDImage } from './gallerySlice';
import SDButton from '../../common/components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio/actions';
const systemSelector = createSelector( const systemSelector = createSelector(
(state: RootState) => state.system, (state: RootState) => state.system,

View File

@ -1,5 +1,5 @@
import { Center, Flex, Image, Text, useColorModeValue } from '@chakra-ui/react'; import { Center, Flex, Image, Text, useColorModeValue } from '@chakra-ui/react';
import { useAppSelector } from '../../app/hooks'; import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useState } from 'react'; import { useState } from 'react';
import ImageMetadataViewer from './ImageMetadataViewer'; import ImageMetadataViewer from './ImageMetadataViewer';

View File

@ -21,8 +21,8 @@ import {
SyntheticEvent, SyntheticEvent,
useRef, useRef,
} from 'react'; } from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { deleteImage } from '../../app/socketio'; import { deleteImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice'; import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import { SDImage } from './gallerySlice'; import { SDImage } from './gallerySlice';

View File

@ -6,7 +6,7 @@ import {
Image, Image,
useColorModeValue, useColorModeValue,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from '../../app/hooks'; import { useAppDispatch } from '../../app/store';
import { SDImage, setCurrentImage } from './gallerySlice'; import { SDImage, setCurrentImage } from './gallerySlice';
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa'; import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal'; import DeleteImageModal from './DeleteImageModal';

View File

@ -1,6 +1,6 @@
import { Flex } from '@chakra-ui/react'; import { Center, Flex, Text } from '@chakra-ui/react';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useAppSelector } from '../../app/hooks'; import { useAppSelector } from '../../app/store';
import HoverableImage from './HoverableImage'; import HoverableImage from './HoverableImage';
/** /**
@ -19,7 +19,7 @@ const ImageGallery = () => {
* TODO: Refactor if performance complaints, or after migrating to new API which supports pagination. * TODO: Refactor if performance complaints, or after migrating to new API which supports pagination.
*/ */
return ( return images.length ? (
<Flex gap={2} wrap="wrap" pb={2}> <Flex gap={2} wrap="wrap" pb={2}>
{[...images].reverse().map((image) => { {[...images].reverse().map((image) => {
const { uuid } = image; const { uuid } = image;
@ -29,6 +29,10 @@ const ImageGallery = () => {
); );
})} })}
</Flex> </Flex>
) : (
<Center height={'100%'} position={'relative'}>
<Text size={'xl'}>No images in gallery</Text>
</Center>
); );
}; };

View File

@ -10,8 +10,8 @@ import {
import { memo } from 'react'; import { memo } from 'react';
import { FaPlus } from 'react-icons/fa'; import { FaPlus } from 'react-icons/fa';
import { PARAMETERS } from '../../app/constants'; import { PARAMETERS } from '../../app/constants';
import { useAppDispatch } from '../../app/hooks'; import { useAppDispatch } from '../../app/store';
import SDButton from '../../components/SDButton'; import SDButton from '../../common/components/SDButton';
import { setAllParameters, setParameter } from '../sd/sdSlice'; import { setAllParameters, setParameter } from '../sd/sdSlice';
import { SDImage, SDMetadata } from './gallerySlice'; import { SDImage, SDMetadata } from './gallerySlice';

View File

@ -1,8 +1,7 @@
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import { UpscalingLevel } from '../sd/sdSlice'; import { UpscalingLevel } from '../sd/sdSlice';
import { backendToFrontendParameters } from '../../app/parameterTranslation'; import { clamp } from 'lodash';
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266 // TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
export interface SDMetadata { export interface SDMetadata {
@ -50,22 +49,38 @@ export const gallerySlice = createSlice({
state.currentImage = action.payload; state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid; state.currentImageUuid = action.payload.uuid;
}, },
removeImage: (state, action: PayloadAction<SDImage>) => { removeImage: (state, action: PayloadAction<string>) => {
const { uuid } = action.payload; const uuid = action.payload;
const newImages = state.images.filter((image) => image.uuid !== uuid); const newImages = state.images.filter((image) => image.uuid !== uuid);
if (uuid === state.currentImageUuid) {
/**
* We are deleting the currently selected image.
*
* We want the new currentl selected image to be under the cursor in the
* gallery, so we need to do some fanagling. The currently selected image
* is set by its UUID, not its index in the image list.
*
* Get the currently selected image's index.
*/
const imageToDeleteIndex = state.images.findIndex( const imageToDeleteIndex = state.images.findIndex(
(image) => image.uuid === uuid (image) => image.uuid === uuid
); );
const newCurrentImageIndex = Math.min( /**
Math.max(imageToDeleteIndex, 0), * New current image needs to be in the same spot, but because the gallery
* is sorted in reverse order, the new current image's index will actuall be
* one less than the deleted image's index.
*
* Clamp the new index to ensure it is valid..
*/
const newCurrentImageIndex = clamp(
imageToDeleteIndex - 1,
0,
newImages.length - 1 newImages.length - 1
); );
state.images = newImages;
state.currentImage = newImages.length state.currentImage = newImages.length
? newImages[newCurrentImageIndex] ? newImages[newCurrentImageIndex]
: undefined; : undefined;
@ -73,6 +88,9 @@ export const gallerySlice = createSlice({
state.currentImageUuid = newImages.length state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid ? newImages[newCurrentImageIndex].uuid
: ''; : '';
}
state.images = newImages;
}, },
addImage: (state, action: PayloadAction<SDImage>) => { addImage: (state, action: PayloadAction<SDImage>) => {
state.images.push(action.payload); state.images.push(action.payload);
@ -86,48 +104,14 @@ export const gallerySlice = createSlice({
clearIntermediateImage: (state) => { clearIntermediateImage: (state) => {
state.intermediateImage = undefined; state.intermediateImage = undefined;
}, },
setGalleryImages: ( setGalleryImages: (state, action: PayloadAction<Array<SDImage>>) => {
state, const newImages = action.payload;
action: PayloadAction< if (newImages.length) {
Array<{
path: string;
metadata: { [key: string]: string | number | boolean };
}>
>
) => {
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
const images = action.payload;
if (images.length === 0) {
// there are no images on disk, clear the gallery
state.images = [];
state.currentImageUuid = '';
state.currentImage = undefined;
} else {
// Filter image urls that are already in the rehydrated state
const filteredImages = action.payload.filter(
(image) => !state.images.find((i) => i.url === image.path)
);
const preparedImages = filteredImages.map((image): SDImage => {
return {
uuid: uuidv4(),
url: image.path,
metadata: backendToFrontendParameters(image.metadata),
};
});
const newImages = [...state.images].concat(preparedImages);
// if previous currentimage no longer exists, set a new one
if (!newImages.find((image) => image.uuid === state.currentImageUuid)) {
const newCurrentImage = newImages[newImages.length - 1]; const newCurrentImage = newImages[newImages.length - 1];
state.images = newImages;
state.currentImage = newCurrentImage; state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid; state.currentImageUuid = newCurrentImage.uuid;
} }
state.images = newImages;
}
}, },
}, },
}); });

View File

@ -1,33 +1,36 @@
import { Progress } from '@chakra-ui/react'; import { Progress } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { useAppSelector } from '../../app/hooks'; import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice'; import { SystemState } from '../system/systemSlice';
const sdSelector = createSelector( const systemSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.system,
(sd: SDState) => { (system: SystemState) => {
return { return {
realSteps: sd.realSteps, isProcessing: system.isProcessing,
currentStep: system.currentStep,
totalSteps: system.totalSteps,
currentStatusHasSteps: system.currentStatusHasSteps,
}; };
}, },
{ {
memoizeOptions: { memoizeOptions: { resultEqualityCheck: isEqual },
resultEqualityCheck: isEqual,
},
} }
); );
const ProgressBar = () => { const ProgressBar = () => {
const { realSteps } = useAppSelector(sdSelector); const { isProcessing, currentStep, totalSteps, currentStatusHasSteps } =
const { currentStep } = useAppSelector((state: RootState) => state.system); useAppSelector(systemSelector);
const progress = Math.round((currentStep * 100) / realSteps);
const value = currentStep ? Math.round((currentStep * 100) / totalSteps) : 0;
return ( return (
<Progress <Progress
height='10px' height="10px"
value={progress} value={value}
isIndeterminate={progress < 0 || currentStep === realSteps} isIndeterminate={isProcessing && !currentStatusHasSteps}
/> />
); );
}; };

View File

@ -12,15 +12,20 @@ import { isEqual } from 'lodash';
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa'; import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
import { MdHelp, MdSettings } from 'react-icons/md'; import { MdHelp, MdSettings } from 'react-icons/md';
import { useAppSelector } from '../../app/hooks'; import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import SettingsModal from '../system/SettingsModal'; import SettingsModal from '../system/SettingsModal';
import { SystemState } from '../system/systemSlice'; import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector( const systemSelector = createSelector(
(state: RootState) => state.system, (state: RootState) => state.system,
(system: SystemState) => { (system: SystemState) => {
return { isConnected: system.isConnected }; return {
isConnected: system.isConnected,
isProcessing: system.isProcessing,
currentIteration: system.currentIteration,
totalIterations: system.totalIterations,
currentStatus: system.currentStatus,
};
}, },
{ {
memoizeOptions: { resultEqualityCheck: isEqual }, memoizeOptions: { resultEqualityCheck: isEqual },
@ -32,11 +37,13 @@ const systemSelector = createSelector(
*/ */
const SiteHeader = () => { const SiteHeader = () => {
const { colorMode, toggleColorMode } = useColorMode(); const { colorMode, toggleColorMode } = useColorMode();
const { isConnected } = useAppSelector(systemSelector); const {
isConnected,
const statusMessage = isConnected isProcessing,
? `Connected to server` currentIteration,
: 'No connection to server'; totalIterations,
currentStatus,
} = useAppSelector(systemSelector);
const statusMessageTextColor = isConnected ? 'green.500' : 'red.500'; const statusMessageTextColor = isConnected ? 'green.500' : 'red.500';
@ -45,6 +52,14 @@ const SiteHeader = () => {
// Make FaMoon and FaSun icon apparent size consistent // Make FaMoon and FaSun icon apparent size consistent
const colorModeIconFontSize = colorMode == 'light' ? 18 : 20; const colorModeIconFontSize = colorMode == 'light' ? 18 : 20;
let statusMessage = currentStatus;
if (isProcessing) {
if (totalIterations > 1) {
statusMessage += ` [${currentIteration}/${totalIterations}]`;
}
}
return ( return (
<Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}> <Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}>
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading> <Heading size={'lg'}>Stable Diffusion Dream Server</Heading>

View File

@ -1,7 +1,7 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { import {
setUpscalingLevel, setUpscalingLevel,
@ -10,14 +10,14 @@ import {
SDState, SDState,
} from '../sd/sdSlice'; } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { UPSCALING_LEVELS } from '../../app/constants'; import { UPSCALING_LEVELS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice'; import { SystemState } from '../system/systemSlice';
import { ChangeEvent } from 'react'; import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const sdSelector = createSelector( const sdSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.sd,

View File

@ -1,15 +1,15 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { SDState, setGfpganStrength } from '../sd/sdSlice'; import { SDState, setGfpganStrength } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice'; import { SystemState } from '../system/systemSlice';
import SDNumberInput from '../../common/components/SDNumberInput';
const sdSelector = createSelector( const sdSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.sd,

View File

@ -1,10 +1,10 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { ChangeEvent } from 'react'; import { ChangeEvent } from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput'; import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch'; import SDSwitch from '../../common/components/SDSwitch';
import InitAndMaskImage from './InitAndMaskImage'; import InitAndMaskImage from './InitAndMaskImage';
import { import {
SDState, SDState,

View File

@ -1,6 +1,6 @@
import { Flex, Image } from '@chakra-ui/react'; import { Flex, Image } from '@chakra-ui/react';
import { useState } from 'react'; import { useState } from 'react';
import { useAppSelector } from '../../app/hooks'; import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { SDState } from '../../features/sd/sdSlice'; import { SDState } from '../../features/sd/sdSlice';
import './InitAndMaskImage.css'; import './InitAndMaskImage.css';

View File

@ -1,14 +1,14 @@
import { Button, Flex, IconButton, useToast } from '@chakra-ui/react'; import { Button, Flex, IconButton, useToast } from '@chakra-ui/react';
import { SyntheticEvent, useCallback } from 'react'; import { SyntheticEvent, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa'; import { FaTrash } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { import {
SDState, SDState,
setInitialImagePath, setInitialImagePath,
setMaskPath, setMaskPath,
} from '../../features/sd/sdSlice'; } from '../../features/sd/sdSlice';
import { uploadInitialImage, uploadMaskImage } from '../../app/socketio'; import { uploadInitialImage, uploadMaskImage } from '../../app/socketio/actions';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import ImageUploader from './ImageUploader'; import ImageUploader from './ImageUploader';

View File

@ -12,7 +12,7 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { import {
setShouldRunGFPGAN, setShouldRunGFPGAN,

View File

@ -1,17 +1,17 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice'; import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice';
import SDSelect from '../../components/SDSelect';
import { HEIGHTS, WIDTHS } from '../../app/constants'; import { HEIGHTS, WIDTHS } from '../../app/constants';
import SDSwitch from '../../components/SDSwitch';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { ChangeEvent } from 'react'; import { ChangeEvent } from 'react';
import SDSelect from '../../common/components/SDSelect';
import SDSwitch from '../../common/components/SDSwitch';
const sdSelector = createSelector( const sdSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.sd,

View File

@ -1,12 +1,12 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { cancelProcessing, generateImage } from '../../app/socketio'; import { cancelProcessing, generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton'; import SDButton from '../../common/components/SDButton';
import useCheckParameters from '../../common/hooks/useCheckParameters';
import { SystemState } from '../system/systemSlice'; import { SystemState } from '../system/systemSlice';
import useCheckParameters from '../system/useCheckParameters';
const systemSelector = createSelector( const systemSelector = createSelector(
(state: RootState) => state.system, (state: RootState) => state.system,

View File

@ -3,7 +3,8 @@ import {
ChangeEvent, ChangeEvent,
KeyboardEvent, KeyboardEvent,
} from 'react'; } from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { setPrompt } from '../sd/sdSlice'; import { setPrompt } from '../sd/sdSlice';

View File

@ -1,17 +1,17 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice'; import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { SAMPLERS } from '../../app/constants'; import { SAMPLERS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { ChangeEvent } from 'react'; import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const sdSelector = createSelector( const sdSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.sd,

View File

@ -11,10 +11,12 @@ import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { ChangeEvent } from 'react'; import { ChangeEvent } from 'react';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants'; import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput'; import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch'; import SDSwitch from '../../common/components/SDSwitch';
import randomInt from '../../common/util/randomInt';
import { validateSeedWeights } from '../../common/util/seedWeightPairs';
import { import {
SDState, SDState,
setIterations, setIterations,
@ -24,8 +26,6 @@ import {
setShouldRandomizeSeed, setShouldRandomizeSeed,
setVariationAmount, setVariationAmount,
} from './sdSlice'; } from './sdSlice';
import randomInt from './util/randomInt';
import { validateSeedWeights } from './util/seedWeightPairs';
const sdSelector = createSelector( const sdSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.sd,

View File

@ -2,21 +2,12 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { SDMetadata } from '../gallery/gallerySlice'; import { SDMetadata } from '../gallery/gallerySlice';
const calculateRealSteps = (
steps: number,
strength: number,
hasInitImage: boolean
): number => {
return hasInitImage ? Math.floor(strength * steps) : steps;
};
export type UpscalingLevel = 0 | 2 | 4; export type UpscalingLevel = 0 | 2 | 4;
export interface SDState { export interface SDState {
prompt: string; prompt: string;
iterations: number; iterations: number;
steps: number; steps: number;
realSteps: number;
cfgScale: number; cfgScale: number;
height: number; height: number;
width: number; width: number;
@ -43,7 +34,6 @@ const initialSDState: SDState = {
prompt: '', prompt: '',
iterations: 1, iterations: 1,
steps: 50, steps: 50,
realSteps: 50,
cfgScale: 7.5, cfgScale: 7.5,
height: 512, height: 512,
width: 512, width: 512,
@ -79,14 +69,7 @@ export const sdSlice = createSlice({
state.iterations = action.payload; state.iterations = action.payload;
}, },
setSteps: (state, action: PayloadAction<number>) => { setSteps: (state, action: PayloadAction<number>) => {
const { img2imgStrength, initialImagePath } = state; state.steps = action.payload;
const steps = action.payload;
state.steps = steps;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
}, },
setCfgScale: (state, action: PayloadAction<number>) => { setCfgScale: (state, action: PayloadAction<number>) => {
state.cfgScale = action.payload; state.cfgScale = action.payload;
@ -105,14 +88,7 @@ export const sdSlice = createSlice({
state.shouldRandomizeSeed = false; state.shouldRandomizeSeed = false;
}, },
setImg2imgStrength: (state, action: PayloadAction<number>) => { setImg2imgStrength: (state, action: PayloadAction<number>) => {
const img2imgStrength = action.payload; state.img2imgStrength = action.payload;
const { steps, initialImagePath } = state;
state.img2imgStrength = img2imgStrength;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
}, },
setGfpganStrength: (state, action: PayloadAction<number>) => { setGfpganStrength: (state, action: PayloadAction<number>) => {
state.gfpganStrength = action.payload; state.gfpganStrength = action.payload;
@ -127,15 +103,9 @@ export const sdSlice = createSlice({
state.shouldUseInitImage = action.payload; state.shouldUseInitImage = action.payload;
}, },
setInitialImagePath: (state, action: PayloadAction<string>) => { setInitialImagePath: (state, action: PayloadAction<string>) => {
const initialImagePath = action.payload; const newInitialImagePath = action.payload;
const { steps, img2imgStrength } = state; state.shouldUseInitImage = newInitialImagePath ? true : false;
state.shouldUseInitImage = initialImagePath ? true : false; state.initialImagePath = newInitialImagePath;
state.initialImagePath = initialImagePath;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
}, },
setMaskPath: (state, action: PayloadAction<string>) => { setMaskPath: (state, action: PayloadAction<string>) => {
state.maskPath = action.payload; state.maskPath = action.payload;
@ -153,6 +123,7 @@ export const sdSlice = createSlice({
state, state,
action: PayloadAction<{ key: string; value: string | number | boolean }> action: PayloadAction<{ key: string; value: string | number | boolean }>
) => { ) => {
// TODO: This probably needs to be refactored.
const { key, value } = action.payload; const { key, value } = action.payload;
const temp = { ...state, [key]: value }; const temp = { ...state, [key]: value };
if (key === 'seed') { if (key === 'seed') {
@ -173,6 +144,7 @@ export const sdSlice = createSlice({
state.seedWeights = action.payload; state.seedWeights = action.payload;
}, },
setAllParameters: (state, action: PayloadAction<SDMetadata>) => { setAllParameters: (state, action: PayloadAction<SDMetadata>) => {
// TODO: This probably needs to be refactored.
const { const {
prompt, prompt,
steps, steps,

View File

@ -5,7 +5,7 @@ import {
Text, Text,
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { setShouldShowLogViewer, SystemState } from './systemSlice'; import { setShouldShowLogViewer, SystemState } from './systemSlice';
import { useLayoutEffect, useRef, useState } from 'react'; import { useLayoutEffect, useRef, useState } from 'react';
@ -44,18 +44,42 @@ const LogViewer = () => {
const log = useAppSelector(logSelector); const log = useAppSelector(logSelector);
const { shouldShowLogViewer } = useAppSelector(systemSelector); const { shouldShowLogViewer } = useAppSelector(systemSelector);
// Set colors based on dark/light mode
const bg = useColorModeValue('gray.50', 'gray.900'); const bg = useColorModeValue('gray.50', 'gray.900');
const borderColor = useColorModeValue('gray.500', 'gray.500'); const borderColor = useColorModeValue('gray.500', 'gray.500');
const logTextColors = useColorModeValue(
{
info: undefined,
warning: 'yellow.500',
error: 'red.500',
},
{
info: undefined,
warning: 'yellow.300',
error: 'red.300',
}
);
// Rudimentary autoscroll // Rudimentary autoscroll
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true); const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
const viewerRef = useRef<HTMLDivElement>(null); const viewerRef = useRef<HTMLDivElement>(null);
/**
* If autoscroll is on, scroll to the bottom when:
* - log updates
* - viewer is toggled
*
* Also scroll to the bottom whenever autoscroll is turned on.
*/
useLayoutEffect(() => { useLayoutEffect(() => {
if (viewerRef.current !== null && shouldAutoscroll) { if (viewerRef.current !== null && shouldAutoscroll) {
viewerRef.current.scrollTop = viewerRef.current.scrollHeight; viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
} }
}, [shouldAutoscroll]); }, [shouldAutoscroll, log, shouldShowLogViewer]);
const handleClickLogViewerToggle = () => {
dispatch(setShouldShowLogViewer(!shouldShowLogViewer));
};
return ( return (
<> <>
@ -78,16 +102,19 @@ const LogViewer = () => {
background={bg} background={bg}
ref={viewerRef} ref={viewerRef}
> >
{log.map((entry, i) => ( {log.map((entry, i) => {
<Flex gap={2} key={i}> const { timestamp, message, level } = entry;
return (
<Flex gap={2} key={i} textColor={logTextColors[level]}>
<Text fontSize="sm" fontWeight={'semibold'}> <Text fontSize="sm" fontWeight={'semibold'}>
{entry.timestamp}: {timestamp}:
</Text> </Text>
<Text fontSize="sm" wordBreak={'break-all'}> <Text fontSize="sm" wordBreak={'break-all'}>
{entry.message} {message}
</Text> </Text>
</Flex> </Flex>
))} );
})}
</Flex> </Flex>
)} )}
{shouldShowLogViewer && ( {shouldShowLogViewer && (
@ -114,7 +141,7 @@ const LogViewer = () => {
variant={'solid'} variant={'solid'}
aria-label="Toggle Log Viewer" aria-label="Toggle Log Viewer"
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />} icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
onClick={() => dispatch(setShouldShowLogViewer(!shouldShowLogViewer))} onClick={handleClickLogViewerToggle}
/> />
</Tooltip> </Tooltip>
</> </>

View File

@ -16,7 +16,7 @@ import {
Text, Text,
useDisclosure, useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { import {
setShouldConfirmOnDelete, setShouldConfirmOnDelete,
setShouldDisplayInProgress, setShouldDisplayInProgress,

View File

@ -1,10 +1,12 @@
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { ExpandedIndex } from '@chakra-ui/react'; import { ExpandedIndex } from '@chakra-ui/react';
export type LogLevel = 'info' | 'warning' | 'error';
export interface LogEntry { export interface LogEntry {
timestamp: string; timestamp: string;
level: LogLevel;
message: string; message: string;
} }
@ -12,10 +14,18 @@ export interface Log {
[index: number]: LogEntry; [index: number]: LogEntry;
} }
export interface SystemState { export interface SystemStatus {
shouldDisplayInProgress: boolean;
isProcessing: boolean; isProcessing: boolean;
currentStep: number; currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
}
export interface SystemState extends SystemStatus {
shouldDisplayInProgress: boolean;
log: Array<LogEntry>; log: Array<LogEntry>;
shouldShowLogViewer: boolean; shouldShowLogViewer: boolean;
isGFPGANAvailable: boolean; isGFPGANAvailable: boolean;
@ -24,12 +34,17 @@ export interface SystemState {
socketId: string; socketId: string;
shouldConfirmOnDelete: boolean; shouldConfirmOnDelete: boolean;
openAccordions: ExpandedIndex; openAccordions: ExpandedIndex;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
} }
const initialSystemState = { const initialSystemState = {
isConnected: false, isConnected: false,
isProcessing: false, isProcessing: false,
currentStep: 0,
log: [], log: [],
shouldShowLogViewer: false, shouldShowLogViewer: false,
shouldDisplayInProgress: false, shouldDisplayInProgress: false,
@ -38,6 +53,12 @@ const initialSystemState = {
socketId: '', socketId: '',
shouldConfirmOnDelete: true, shouldConfirmOnDelete: true,
openAccordions: [0], openAccordions: [0],
currentStep: 0,
totalSteps: 0,
currentIteration: 0,
totalIterations: 0,
currentStatus: '',
currentStatusHasSteps: false,
}; };
const initialState: SystemState = initialSystemState; const initialState: SystemState = initialSystemState;
@ -51,18 +72,35 @@ export const systemSlice = createSlice({
}, },
setIsProcessing: (state, action: PayloadAction<boolean>) => { setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload; state.isProcessing = action.payload;
if (action.payload === false) {
state.currentStep = 0;
}
}, },
setCurrentStep: (state, action: PayloadAction<number>) => { setCurrentStatus: (state, action: PayloadAction<string>) => {
state.currentStep = action.payload; state.currentStatus = action.payload;
}, },
addLogEntry: (state, action: PayloadAction<string>) => { setSystemStatus: (state, action: PayloadAction<SystemStatus>) => {
const currentStatus =
!action.payload.isProcessing && state.isConnected
? 'Connected'
: action.payload.currentStatus;
return { ...state, ...action.payload, currentStatus };
},
addLogEntry: (
state,
action: PayloadAction<{
timestamp: string;
message: string;
level?: LogLevel;
}>
) => {
const { timestamp, message, level } = action.payload;
const logLevel = level || 'info';
const entry: LogEntry = { const entry: LogEntry = {
timestamp: dateFormat(new Date(), 'isoDateTime'), timestamp,
message: action.payload, message,
level: logLevel,
}; };
state.log.push(entry); state.log.push(entry);
}, },
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => { setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
@ -86,13 +124,14 @@ export const systemSlice = createSlice({
export const { export const {
setShouldDisplayInProgress, setShouldDisplayInProgress,
setIsProcessing, setIsProcessing,
setCurrentStep,
addLogEntry, addLogEntry,
setShouldShowLogViewer, setShouldShowLogViewer,
setIsConnected, setIsConnected,
setSocketId, setSocketId,
setShouldConfirmOnDelete, setShouldConfirmOnDelete,
setOpenAccordions, setOpenAccordions,
setSystemStatus,
setCurrentStatus,
} = systemSlice.actions; } = systemSlice.actions;
export default systemSlice.reducer; export default systemSlice.reducer;

View File

@ -8,9 +8,9 @@ import { persistStore } from 'redux-persist';
export const persistor = persistStore(store); export const persistor = persistStore(store);
import App from './App';
import { theme } from './app/theme'; import { theme } from './app/theme';
import Loading from './Loading'; import Loading from './Loading';
import App from './app/App';
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render( ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
<React.StrictMode> <React.StrictMode>