Improves state/API code structure, formatting, etc

This commit is contained in:
psychedelicious 2022-09-18 17:33:09 +10:00
parent ef6609abcb
commit 6e927acd58
49 changed files with 2073 additions and 1401 deletions

View File

@ -7,6 +7,8 @@ import eventlet
import glob
import shlex
import argparse
import math
import shutil
from flask_socketio import SocketIO
from flask import Flask, send_from_directory, url_for, jsonify
@ -119,15 +121,15 @@ result_path = os.path.join(output_dir, 'img-samples/')
intermediate_path = os.path.join(result_path, 'intermediates/')
# path for user-uploaded init images and masks
init_path = os.path.join(result_path, 'init-images/')
mask_path = os.path.join(result_path, 'mask-images/')
init_image_path = os.path.join(result_path, 'init-images/')
mask_image_path = os.path.join(result_path, 'mask-images/')
# txt log
log_path = os.path.join(result_path, 'dream_log.txt')
# make all output paths
[os.makedirs(path, exist_ok=True)
for path in [result_path, intermediate_path, init_path, mask_path]]
for path in [result_path, intermediate_path, init_image_path, mask_image_path]]
"""
@ -155,7 +157,8 @@ def handle_request_all_images():
else:
metadata = all_metadata['sd-metadata']
image_array.append({'path': path, 'metadata': metadata})
return make_response("OK", data=image_array)
socketio.emit('galleryImages', {'images': image_array})
eventlet.sleep(0)
@socketio.on('generateImage')
@ -166,16 +169,32 @@ def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan
esrgan_parameters,
gfpgan_parameters
)
return make_response("OK")
@socketio.on('runESRGAN')
def handle_run_esrgan_event(original_image, esrgan_parameters):
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}')
progress = {
'currentStep': 1,
'totalSteps': 1,
'currentIteration': 1,
'totalIterations': 1,
'currentStatus': 'Preparing',
'isProcessing': True,
'currentStatusHasSteps': False
}
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
progress['currentStatus'] = 'Upscaling'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = real_esrgan_upscale(
image=image,
upsampler_scale=esrgan_parameters['upscale'][0],
@ -183,24 +202,54 @@ def handle_run_esrgan_event(original_image, esrgan_parameters):
seed=seed
)
progress['currentStatus'] = 'Saving image'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
esrgan_parameters['seed'] = seed
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan')
command = parameters_to_command(esrgan_parameters)
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
progress['currentStatus'] = 'Finished'
progress['currentStep'] = 0
progress['totalSteps'] = 0
progress['currentIteration'] = 0
progress['totalIterations'] = 0
progress['isProcessing'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'esrgan', 'uuid': original_image['uuid'],'metadata': esrgan_parameters})
'esrganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': esrgan_parameters})
@socketio.on('runGFPGAN')
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}')
progress = {
'currentStep': 1,
'totalSteps': 1,
'currentIteration': 1,
'totalIterations': 1,
'currentStatus': 'Preparing',
'isProcessing': True,
'currentStatusHasSteps': False
}
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
progress['currentStatus'] = 'Fixing faces'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = run_gfpgan(
image=image,
strength=gfpgan_parameters['gfpgan_strength'],
@ -208,29 +257,42 @@ def handle_run_gfpgan_event(original_image, gfpgan_parameters):
upsampler_scale=1
)
progress['currentStatus'] = 'Saving image'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
gfpgan_parameters['seed'] = seed
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan')
command = parameters_to_command(gfpgan_parameters)
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
progress['currentStatus'] = 'Finished'
progress['currentStep'] = 0
progress['totalSteps'] = 0
progress['currentIteration'] = 0
progress['totalIterations'] = 0
progress['isProcessing'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'gfpgan', 'uuid': original_image['uuid'],'metadata': gfpgan_parameters})
'gfpganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': gfpgan_parameters})
@socketio.on('cancel')
def handle_cancel():
print(f'>> Cancel processing requested')
canceled.set()
return make_response("OK")
socketio.emit('processingCanceled')
# TODO: I think this needs a safety mechanism.
@socketio.on('deleteImage')
def handle_delete_image(path):
def handle_delete_image(path, uuid):
print(f'>> Delete requested "{path}"')
send2trash(path)
return make_response("OK")
socketio.emit('imageDeleted', {'url': path, 'uuid': uuid})
# TODO: I think this needs a safety mechanism.
@ -240,11 +302,11 @@ def handle_upload_initial_image(bytes, name):
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(init_path, name)
file_path = os.path.join(init_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb")
newFile.write(bytes)
return make_response("OK", data=file_path)
socketio.emit('initialImageUploaded', {'url': file_path, 'uuid': ''})
# TODO: I think this needs a safety mechanism.
@ -254,11 +316,11 @@ def handle_upload_mask_image(bytes, name):
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(mask_path, name)
file_path = os.path.join(mask_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb")
newFile.write(bytes)
return make_response("OK", data=file_path)
socketio.emit('maskImageUploaded', {'url': file_path, 'uuid': ''})
@ -273,6 +335,13 @@ ADDITIONAL FUNCTIONS
"""
def make_unique_init_image_filename(name):
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
return name
def write_log_message(message, log_path=log_path):
"""Logs the filename and parameters used to generate or process that image to log file"""
message = f'{message}\n'
@ -280,15 +349,6 @@ def write_log_message(message, log_path=log_path):
file.writelines(message)
def make_response(status, message=None, data=None):
response = {'status': status}
if message is not None:
response['message'] = message
if data is not None:
response['data'] = data
return response
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False):
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed'
@ -310,16 +370,69 @@ def save_image(image, parameters, output_dir, step_index=None, postprocessing=Fa
return path
def calculate_real_steps(steps, strength, has_init_image):
return math.floor(strength * steps) if has_init_image else steps
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
canceled.clear()
step_index = 1
"""
If a result image is used as an init image, and then deleted, we will want to be
able to use it as an init image in the future. Need to copy it.
If the init/mask image doesn't exist in the init_image_path/mask_image_path,
make a unique filename for it and copy it there.
"""
if ('init_img' in generation_parameters):
filename = os.path.basename(generation_parameters['init_img'])
if not os.path.exists(os.path.join(init_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters['init_img'], new_path)
generation_parameters['init_img'] = new_path
if ('init_mask' in generation_parameters):
filename = os.path.basename(generation_parameters['init_mask'])
if not os.path.exists(os.path.join(mask_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters['init_img'], new_path)
generation_parameters['init_mask'] = new_path
totalSteps = calculate_real_steps(
steps=generation_parameters['steps'],
strength=generation_parameters['strength'] if 'strength' in generation_parameters else None,
has_init_image='init_img' in generation_parameters
)
progress = {
'currentStep': 1,
'totalSteps': totalSteps,
'currentIteration': 1,
'totalIterations': generation_parameters['iterations'],
'currentStatus': 'Preparing',
'isProcessing': True,
'currentStatusHasSteps': False
}
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
def image_progress(sample, step):
if canceled.is_set():
raise CanceledException
nonlocal step_index
nonlocal generation_parameters
nonlocal progress
progress['currentStep'] = step + 1
progress['currentStatus'] = 'Generating'
progress['currentStatusHasSteps'] = True
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1:
image = model.sample_to_image(sample)
path = save_image(image, generation_parameters, intermediate_path, step_index)
@ -327,18 +440,30 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
step_index += 1
socketio.emit('intermediateResult', {
'url': os.path.relpath(path), 'metadata': generation_parameters})
socketio.emit('progress', {'step': step + 1})
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
def image_done(image, seed):
nonlocal generation_parameters
nonlocal esrgan_parameters
nonlocal gfpgan_parameters
nonlocal progress
step_index = 1
progress['currentStatus'] = 'Generation complete'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
all_parameters = generation_parameters
postprocessing = False
if esrgan_parameters:
progress['currentStatus'] = 'Upscaling'
progress['currentStatusHasSteps'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = real_esrgan_upscale(
image=image,
strength=esrgan_parameters['strength'],
@ -349,6 +474,11 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']]
if gfpgan_parameters:
progress['currentStatus'] = 'Fixing faces'
progress['currentStatusHasSteps'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
image = run_gfpgan(
image=image,
strength=gfpgan_parameters['strength'],
@ -359,6 +489,9 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength']
all_parameters['seed'] = seed
progress['currentStatus'] = 'Saving image'
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing)
command = parameters_to_command(all_parameters)
@ -366,8 +499,24 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
print(f'Image generated: "{path}"')
write_log_message(f'[Generated] "{path}": {command}')
if (progress['totalIterations'] > progress['currentIteration']):
progress['currentStep'] = 1
progress['currentIteration'] +=1
progress['currentStatus'] = 'Iteration finished'
progress['currentStatusHasSteps'] = False
else:
progress['currentStep'] = 0
progress['totalSteps'] = 0
progress['currentIteration'] = 0
progress['totalIterations'] = 0
progress['currentStatus'] = 'Finished'
progress['isProcessing'] = False
socketio.emit('progressUpdate', progress)
eventlet.sleep(0)
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'generation', 'metadata': all_parameters})
'generationResult', {'url': os.path.relpath(path), 'metadata': all_parameters})
eventlet.sleep(0)
try:
@ -382,7 +531,7 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
except CanceledException:
pass
except Exception as e:
socketio.emit('error', (str(e)))
socketio.emit('error', {'message': (str(e))})
print("\n")
traceback.print_exc()
print("\n")

View File

@ -1,17 +1,29 @@
# Stable Diffusion Web UI
## Run
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
## Evironment
Install [node](https://nodejs.org/en/download/) (includes npm) and optionally
[yarn](https://yarnpkg.com/getting-started/install).
From `frontend/` run `npm install` / `yarn install` to install the frontend packages.
## Dev
1. From `frontend/`, run `yarn dev` to start the dev server.
1. From `frontend/`, run `npm dev` / `yarn dev` to start the dev server.
2. Note the address it starts up on (probably `http://localhost:5173/`).
3. Edit `backend/server.py`'s `additional_allowed_origins` to include this address, e.g. `additional_allowed_origins = ['http://localhost:5173']`.
3. Edit `backend/server.py`'s `additional_allowed_origins` to include this address, e.g.
`additional_allowed_origins = ['http://localhost:5173']`.
4. Leaving the dev server running, open a new terminal and go to the project root.
5. Run `python backend/server.py`.
6. Navigate to the dev server address e.g. `http://localhost:5173/`.
To build for dev: `npm build-dev` / `yarn build-dev`
To build for production: `npm build` / `yarn build`
## TODO
@ -20,7 +32,6 @@
`framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on
the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on
this. Need to check in on this issue periodically.
- More status info e.g. phase of processing, image we are on of the total count, etc
- Mobile friendly layout
- Proper image gallery/viewer/manager
- Help tooltips and such

File diff suppressed because one or more lines are too long

694
frontend/dist/assets/index.de730902.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -4,7 +4,7 @@
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.3d2e59c5.js"></script>
<script type="module" crossorigin src="/assets/index.de730902.js"></script>
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
</head>
<body>

View File

@ -1,25 +1,28 @@
import { Grid, GridItem } from '@chakra-ui/react';
import CurrentImageDisplay from './features/gallery/CurrentImageDisplay';
import LogViewer from './features/system/LogViewer';
import PromptInput from './features/sd/PromptInput';
import ProgressBar from './features/header/ProgressBar';
import { useEffect } from 'react';
import { useAppDispatch } from './app/hooks';
import { requestAllImages } from './app/socketio';
import ProcessButtons from './features/sd/ProcessButtons';
import ImageGallery from './features/gallery/ImageGallery';
import SiteHeader from './features/header/SiteHeader';
import OptionsAccordion from './features/sd/OptionsAccordion';
import { useEffect, useState } from 'react';
import CurrentImageDisplay from '../features/gallery/CurrentImageDisplay';
import ImageGallery from '../features/gallery/ImageGallery';
import ProgressBar from '../features/header/ProgressBar';
import SiteHeader from '../features/header/SiteHeader';
import OptionsAccordion from '../features/sd/OptionsAccordion';
import ProcessButtons from '../features/sd/ProcessButtons';
import PromptInput from '../features/sd/PromptInput';
import LogViewer from '../features/system/LogViewer';
import Loading from '../Loading';
import { useAppDispatch } from './store';
import { requestAllImages } from './socketio/actions';
const App = () => {
const dispatch = useAppDispatch();
const [isReady, setIsReady] = useState<boolean>(false);
// Load images from the gallery once
useEffect(() => {
dispatch(requestAllImages());
setIsReady(true);
}, [dispatch]);
return (
return isReady ? (
<>
<Grid
width="100vw"
@ -57,6 +60,8 @@ const App = () => {
</Grid>
<LogViewer />
</>
) : (
<Loading />
);
};

View File

@ -2,52 +2,52 @@
// Valid samplers
export const SAMPLERS: Array<string> = [
'ddim',
'plms',
'k_lms',
'k_dpm_2',
'k_dpm_2_a',
'k_euler',
'k_euler_a',
'k_heun',
'ddim',
'plms',
'k_lms',
'k_dpm_2',
'k_dpm_2_a',
'k_euler',
'k_euler_a',
'k_heun',
];
// Valid image widths
export const WIDTHS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
];
// Valid image heights
export const HEIGHTS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
];
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
];
// Internal to human-readable parameters
export const PARAMETERS: { [key: string]: string } = {
prompt: 'Prompt',
iterations: 'Iterations',
steps: 'Steps',
cfgScale: 'CFG Scale',
height: 'Height',
width: 'Width',
sampler: 'Sampler',
seed: 'Seed',
img2imgStrength: 'img2img Strength',
gfpganStrength: 'GFPGAN Strength',
upscalingLevel: 'Upscaling Level',
upscalingStrength: 'Upscaling Strength',
initialImagePath: 'Initial Image',
maskPath: 'Initial Image Mask',
shouldFitToWidthHeight: 'Fit Initial Image',
seamless: 'Seamless Tiling',
prompt: 'Prompt',
iterations: 'Iterations',
steps: 'Steps',
cfgScale: 'CFG Scale',
height: 'Height',
width: 'Width',
sampler: 'Sampler',
seed: 'Seed',
img2imgStrength: 'img2img Strength',
gfpganStrength: 'GFPGAN Strength',
upscalingLevel: 'Upscaling Level',
upscalingStrength: 'Upscaling Strength',
initialImagePath: 'Initial Image',
maskPath: 'Initial Image Mask',
shouldFitToWidthHeight: 'Fit Initial Image',
seamless: 'Seamless Tiling',
};
export const NUMPY_RAND_MIN = 0;

View File

@ -1,7 +0,0 @@
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import type { RootState, AppDispatch } from './store';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -1,393 +0,0 @@
import { createAction, Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import {
addImage,
clearIntermediateImage,
removeImage,
SDImage,
SDMetadata,
setGalleryImages,
setIntermediateImage,
} from '../features/gallery/gallerySlice';
import {
addLogEntry,
setCurrentStep,
setIsConnected,
setIsProcessing,
} from '../features/system/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { setInitialImagePath, setMaskPath } from '../features/sd/sdSlice';
import {
backendToFrontendParameters,
frontendToBackendParameters,
} from './parameterTranslation';
export interface SocketIOResponse {
status: 'OK' | 'ERROR';
message?: string;
data?: any;
}
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const { dispatch, getState } = store;
if (!areListenersSet) {
// CONNECT
socketio.on('connect', () => {
try {
dispatch(setIsConnected(true));
} catch (e) {
console.error(e);
}
});
// DISCONNECT
socketio.on('disconnect', () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(addLogEntry(`Disconnected from server`));
} catch (e) {
console.error(e);
}
});
// PROCESSING RESULT
socketio.on(
'result',
(data: {
url: string;
type: 'generation' | 'esrgan' | 'gfpgan';
uuid?: string;
metadata: { [key: string]: any };
}) => {
try {
const newUuid = uuidv4();
const { type, url, uuid, metadata } = data;
switch (type) {
case 'generation': {
const translatedMetadata =
backendToFrontendParameters(metadata);
dispatch(
addImage({
uuid: newUuid,
url,
metadata: translatedMetadata,
})
);
dispatch(
addLogEntry(`Image generated: ${url}`)
);
break;
}
case 'esrgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunESRGAN = true;
newMetadata.upscalingLevel =
metadata.upscale[0];
newMetadata.upscalingStrength =
metadata.upscale[1];
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`ESRGAN upscaled: ${url}`)
);
break;
}
case 'gfpgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunGFPGAN = true;
newMetadata.gfpganStrength =
metadata.gfpgan_strength;
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`GFPGAN fixed faces: ${url}`)
);
break;
}
}
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
}
);
// PROGRESS UPDATE
socketio.on('progress', (data: { step: number }) => {
try {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(data.step));
} catch (e) {
console.error(e);
}
});
// INTERMEDIATE IMAGE
socketio.on(
'intermediateResult',
(data: { url: string; metadata: SDMetadata }) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry(`Intermediate image generated: ${url}`)
);
} catch (e) {
console.error(e);
}
}
);
// ERROR FROM BACKEND
socketio.on('error', (message) => {
try {
dispatch(addLogEntry(`Server error: ${message}`));
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
});
areListenersSet = true;
}
// HANDLE ACTIONS
switch (action.type) {
// GENERATE IMAGE
case 'socketio/generateImage': {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const {
generationParameters,
esrganParameters,
gfpganParameters,
} = frontendToBackendParameters(
getState().sd,
getState().system
);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry(
`Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`
)
);
break;
}
// RUN ESRGAN (UPSCALING)
case 'socketio/runESRGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { upscalingLevel, upscalingStrength } = getState().sd;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry(
`ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`
)
);
break;
}
// RUN GFPGAN (FIX FACES)
case 'socketio/runGFPGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { gfpganStrength } = getState().sd;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry(
`GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`
)
);
break;
}
// DELETE IMAGE
case 'socketio/deleteImage': {
const imageToDelete = action.payload;
const { url } = imageToDelete;
socketio.emit(
'deleteImage',
url,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(removeImage(imageToDelete));
dispatch(addLogEntry(`Image deleted: ${url}`));
}
}
);
break;
}
// GET ALL IMAGES FOR GALLERY
case 'socketio/requestAllImages': {
socketio.emit(
'requestAllImages',
(response: SocketIOResponse) => {
dispatch(setGalleryImages(response.data));
dispatch(
addLogEntry(`Loaded ${response.data.length} images`)
);
}
);
break;
}
// CANCEL PROCESSING
case 'socketio/cancelProcessing': {
socketio.emit('cancel', (response: SocketIOResponse) => {
const { intermediateImage } = getState().gallery;
if (response.status === 'OK') {
dispatch(setIsProcessing(false));
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry(
`Intermediate image saved: ${intermediateImage.url}`
)
);
dispatch(clearIntermediateImage());
}
dispatch(addLogEntry(`Processing canceled`));
}
});
break;
}
// UPLOAD INITIAL IMAGE
case 'socketio/uploadInitialImage': {
const file = action.payload;
socketio.emit(
'uploadInitialImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setInitialImagePath(response.data));
dispatch(
addLogEntry(
`Initial image uploaded: ${response.data}`
)
);
}
}
);
break;
}
// UPLOAD MASK IMAGE
case 'socketio/uploadMaskImage': {
const file = action.payload;
socketio.emit(
'uploadMaskImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setMaskPath(response.data));
dispatch(
addLogEntry(
`Mask image uploaded: ${response.data}`
)
);
}
}
);
break;
}
}
next(action);
};
return middleware;
};
// Actions to be used by app
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');

View File

@ -0,0 +1,24 @@
import { createAction } from '@reduxjs/toolkit';
import { SDImage } from '../../features/gallery/gallerySlice';
/**
* We can't use redux-toolkit's createSlice() to make these actions,
* because they have no associated reducer. They only exist to dispatch
* requests to the server via socketio. These actions will be handled
* by the middleware.
*/
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');

View File

@ -0,0 +1,101 @@
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { Socket } from 'socket.io-client';
import { frontendToBackendParameters } from '../../common/util/parameterTranslation';
import { SDImage } from '../../features/gallery/gallerySlice';
import {
addLogEntry,
setIsProcessing,
} from '../../features/system/systemSlice';
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, any>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
return {
emitGenerateImage: () => {
dispatch(setIsProcessing(true));
const { generationParameters, esrganParameters, gfpganParameters } =
frontendToBackendParameters(getState().sd, getState().system);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: SDImage) => {
dispatch(setIsProcessing(true));
const { upscalingLevel, upscalingStrength } = getState().sd;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunGFPGAN: (imageToProcess: SDImage) => {
dispatch(setIsProcessing(true));
const { gfpganStrength } = getState().sd;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`,
})
);
},
emitDeleteImage: (imageToDelete: SDImage) => {
const { url, uuid } = imageToDelete;
socketio.emit('deleteImage', url, uuid);
},
emitRequestAllImages: () => {
socketio.emit('requestAllImages');
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitUploadInitialImage: (file: File) => {
socketio.emit('uploadInitialImage', file, file.name);
},
emitUploadMaskImage: (file: File) => {
socketio.emit('uploadMaskImage', file, file.name);
},
};
};
export default makeSocketIOEmitters;

View File

@ -0,0 +1,338 @@
import { AnyAction, MiddlewareAPI, Dispatch } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import dateFormat from 'dateformat';
import {
addLogEntry,
setIsConnected,
setIsProcessing,
SystemStatus,
setSystemStatus,
setCurrentStatus,
} from '../../features/system/systemSlice';
import type {
ServerGenerationResult,
ServerESRGANResult,
ServerGFPGANResult,
ServerIntermediateResult,
ServerError,
ServerGalleryImages,
ServerImageUrlAndUuid,
ServerImageUrl,
} from './types';
import { backendToFrontendParameters } from '../../common/util/parameterTranslation';
import {
addImage,
clearIntermediateImage,
removeImage,
SDImage,
setGalleryImages,
setIntermediateImage,
} from '../../features/gallery/gallerySlice';
import { setInitialImagePath, setMaskPath } from '../../features/sd/sdSlice';
/**
* Returns an object containing listener callbacks for socketio events.
* TODO: This file is large, but simple. Should it be split up further?
*/
const makeSocketIOListeners = (
store: MiddlewareAPI<Dispatch<AnyAction>, any>
) => {
const { dispatch, getState } = store;
return {
/**
* Callback to run when we receive a 'connect' event.
*/
onConnect: () => {
try {
dispatch(setIsConnected(true));
dispatch(setCurrentStatus('Connected'));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'disconnect' event.
*/
onDisconnect: () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(setCurrentStatus('Disconnected'));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Disconnected from server`,
level: 'warning',
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'generationResult' event.
*/
onGenerationResult: (data: ServerGenerationResult) => {
try {
const { url, metadata } = data;
const newUuid = uuidv4();
const translatedMetadata = backendToFrontendParameters(metadata);
dispatch(
addImage({
uuid: newUuid,
url,
metadata: translatedMetadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generated: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'intermediateResult' event.
*/
onIntermediateResult: (data: ServerIntermediateResult) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image generated: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive an 'esrganResult' event.
*/
onESRGANResult: (data: ServerESRGANResult) => {
try {
const { url, uuid, metadata } = data;
const newUuid = uuidv4();
// This image was only ESRGAN'd, grab the original image's metadata
const originalImage = getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
// Retain the original metadata
const newMetadata = {
...originalImage.metadata,
};
// Update the ESRGAN-related fields
newMetadata.shouldRunESRGAN = true;
newMetadata.upscalingLevel = metadata.upscale[0];
newMetadata.upscalingStrength = metadata.upscale[1];
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Upscaled: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'gfpganResult' event.
*/
onGFPGANResult: (data: ServerGFPGANResult) => {
try {
const { url, uuid, metadata } = data;
const newUuid = uuidv4();
// This image was only GFPGAN'd, grab the original image's metadata
const originalImage = getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
// Retain the original metadata
const newMetadata = {
...originalImage.metadata,
};
// Update the GFPGAN-related fields
newMetadata.shouldRunGFPGAN = true;
newMetadata.gfpganStrength = metadata.gfpgan_strength;
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Fixed faces: ${url}`,
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
* TODO: Add additional progress phases
*/
onProgressUpdate: (data: SystemStatus) => {
try {
dispatch(setIsProcessing(true));
dispatch(setSystemStatus(data));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
*/
onError: (data: ServerError) => {
const { message, additionalData } = data;
if (additionalData) {
// TODO: handle more data than short message
}
try {
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Server error: ${message}`,
level: 'error',
})
);
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'galleryImages' event.
*/
onGalleryImages: (data: ServerGalleryImages) => {
const { images } = data;
const preparedImages = images.map((image): SDImage => {
return {
uuid: uuidv4(),
url: image.path,
metadata: backendToFrontendParameters(image.metadata),
};
});
dispatch(setGalleryImages(preparedImages));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Loaded ${images.length} images`,
})
);
},
/**
* Callback to run when we receive a 'processingCanceled' event.
*/
onProcessingCanceled: () => {
dispatch(setIsProcessing(false));
const { intermediateImage } = getState().gallery;
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image saved: ${intermediateImage.url}`,
})
);
dispatch(clearIntermediateImage());
}
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Processing canceled`,
level: 'warning',
})
);
},
/**
* Callback to run when we receive a 'imageDeleted' event.
*/
onImageDeleted: (data: ServerImageUrlAndUuid) => {
const { url, uuid } = data;
dispatch(removeImage(uuid));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image deleted: ${url}`,
})
);
},
/**
* Callback to run when we receive a 'initialImageUploaded' event.
*/
onInitialImageUploaded: (data: ServerImageUrl) => {
const { url } = data;
dispatch(setInitialImagePath(url));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Initial image uploaded: ${url}`,
})
);
},
/**
* Callback to run when we receive a 'maskImageUploaded' event.
*/
onMaskImageUploaded: (data: ServerImageUrl) => {
const { url } = data;
dispatch(setMaskPath(url));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Mask image uploaded: ${url}`,
})
);
},
};
};
export default makeSocketIOListeners;

View File

@ -0,0 +1,157 @@
import { Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import makeSocketIOListeners from './listeners';
import makeSocketIOEmitters from './emitters';
import type {
ServerGenerationResult,
ServerESRGANResult,
ServerGFPGANResult,
ServerIntermediateResult,
ServerError,
ServerGalleryImages,
ServerImageUrlAndUuid,
ServerImageUrl,
} from './types';
import { SystemStatus } from '../../features/system/systemSlice';
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const {
onConnect,
onDisconnect,
onError,
onESRGANResult,
onGFPGANResult,
onGenerationResult,
onIntermediateResult,
onProgressUpdate,
onGalleryImages,
onProcessingCanceled,
onImageDeleted,
onInitialImageUploaded,
onMaskImageUploaded,
} = makeSocketIOListeners(store);
const {
emitGenerateImage,
emitRunESRGAN,
emitRunGFPGAN,
emitDeleteImage,
emitRequestAllImages,
emitCancelProcessing,
emitUploadInitialImage,
emitUploadMaskImage,
} = makeSocketIOEmitters(store, socketio);
/**
* If this is the first time the middleware has been called (e.g. during store setup),
* initialize all our socket.io listeners.
*/
if (!areListenersSet) {
socketio.on('connect', () => onConnect());
socketio.on('disconnect', () => onDisconnect());
socketio.on('error', (data: ServerError) => onError(data));
socketio.on('generationResult', (data: ServerGenerationResult) =>
onGenerationResult(data)
);
socketio.on('esrganResult', (data: ServerESRGANResult) =>
onESRGANResult(data)
);
socketio.on('gfpganResult', (data: ServerGFPGANResult) =>
onGFPGANResult(data)
);
socketio.on('intermediateResult', (data: ServerIntermediateResult) =>
onIntermediateResult(data)
);
socketio.on('progressUpdate', (data: SystemStatus) =>
onProgressUpdate(data)
);
socketio.on('galleryImages', (data: ServerGalleryImages) =>
onGalleryImages(data)
);
socketio.on('processingCanceled', () => {
onProcessingCanceled();
});
socketio.on('imageDeleted', (data: ServerImageUrlAndUuid) => {
onImageDeleted(data);
});
socketio.on('initialImageUploaded', (data: ServerImageUrl) => {
onInitialImageUploaded(data);
});
socketio.on('maskImageUploaded', (data: ServerImageUrl) => {
onMaskImageUploaded(data);
});
areListenersSet = true;
}
/**
* Handle redux actions caught by middleware.
*/
switch (action.type) {
case 'socketio/generateImage': {
emitGenerateImage();
break;
}
case 'socketio/runESRGAN': {
emitRunESRGAN(action.payload);
break;
}
case 'socketio/runGFPGAN': {
emitRunGFPGAN(action.payload);
break;
}
case 'socketio/deleteImage': {
emitDeleteImage(action.payload);
break;
}
case 'socketio/requestAllImages': {
emitRequestAllImages();
break;
}
case 'socketio/cancelProcessing': {
emitCancelProcessing();
break;
}
case 'socketio/uploadInitialImage': {
emitUploadInitialImage(action.payload);
break;
}
case 'socketio/uploadMaskImage': {
emitUploadMaskImage(action.payload);
break;
}
}
next(action);
};
return middleware;
};

46
frontend/src/app/socketio/types.d.ts vendored Normal file
View File

@ -0,0 +1,46 @@
/**
* Interfaces used by the socketio middleware.
*/
export declare interface ServerGenerationResult {
url: string;
metadata: { [key: string]: any };
}
export declare interface ServerESRGANResult {
url: string;
uuid: string;
metadata: { [key: string]: any };
}
export declare interface ServerGFPGANResult {
url: string;
uuid: string;
metadata: { [key: string]: any };
}
export declare interface ServerIntermediateResult {
url: string;
metadata: { [key: string]: any };
}
export declare interface ServerError {
message: string;
additionalData?: string;
}
export declare interface ServerGalleryImages {
images: Array<{
path: string;
metadata: { [key: string]: any };
}>;
}
export declare interface ServerImageUrlAndUuid {
uuid: string;
url: string;
}
export declare interface ServerImageUrl {
url: string;
}

View File

@ -1,53 +1,78 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import sdReducer from '../features/sd/sdSlice';
import galleryReducer from '../features/gallery/gallerySlice';
import systemReducer from '../features/system/systemSlice';
import { socketioMiddleware } from './socketio';
import { socketioMiddleware } from './socketio/middleware';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
*
* While we definitely want generation parameters to be persisted, there are a number
* of things we do *not* want to be persisted across reloads:
* - Gallery/selected image (user may add/delete images from disk between page loads)
* - Connection/processing status
* - Availability of external libraries like ESRGAN/GFPGAN
*
* These can be blacklisted in redux-persist.
*
* The necesssary nested persistors with blacklists are configured below.
*
* TODO: Do we blacklist initialImagePath? If the image is deleted from disk we get an
* ugly 404. But if we blacklist it, then this is a valuable parameter that is lost
* on reload. Need to figure out a good way to handle this.
*/
const rootPersistConfig = {
key: 'root',
storage,
blacklist: ['gallery', 'system'],
};
const systemPersistConfig = {
key: 'system',
storage,
blacklist: [
'isConnected',
'isProcessing',
'currentStep',
'socketId',
'isESRGANAvailable',
'isGFPGANAvailable',
'currentStep',
'totalSteps',
'currentIteration',
'totalIterations',
'currentStatus',
],
};
const reducers = combineReducers({
sd: sdReducer,
gallery: galleryReducer,
system: systemReducer,
system: persistReducer(systemPersistConfig, systemReducer),
});
const persistConfig = {
key: 'root',
storage,
};
const persistedReducer = persistReducer(persistConfig, reducers);
/*
The frontend needs to be distributed as a production build, so
we cannot reasonably ask users to edit the JS and specify the
host and port on which the socket.io server will run.
The solution is to allow server script to be run with arguments
(or just edited) providing the host and port. Then, the server
serves a route `/socketio_config` which responds with the host
and port.
When the frontend loads, it synchronously requests that route
and thus gets the host and port. This requires a suspicious
fetch somewhere, and the store setup seems like as good a place
as any to make this fetch request.
*/
const persistedReducer = persistReducer(rootPersistConfig, reducers);
// Continue with store setup
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
// redux-persist sometimes needs to have a function in redux, need to disable this check
// redux-persist sometimes needs to temporarily put a function in redux state, need to disable this check
serializableCheck: false,
}).concat(socketioMiddleware()),
});
// Infer the `RootState` and `AppDispatch` types from the store itself
export type RootState = ReturnType<typeof store.getState>;
// Inferred type: {posts: PostsState, comments: CommentsState, users: UsersState}
export type AppDispatch = typeof store.dispatch;
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -0,0 +1,171 @@
/**
* Defines common parameters required to generate an image.
* See #266 for the eventual maturation of this interface.
*/
interface CommonParameters {
/**
* The "txt2img" prompt. String. Minimum one character. No maximum.
*/
prompt: string;
/**
* The number of sampler steps. Integer. Minimum value 1. No maximum.
*/
steps: number;
/**
* Classifier-free guidance scale. Float. Minimum value 0. Maximum?
*/
cfgScale: number;
/**
* Height of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
*/
height: number;
/**
* Width of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
*/
width: number;
/**
* Name of the sampler to use. String. Restricted values.
*/
sampler:
| 'ddim'
| 'plms'
| 'k_lms'
| 'k_dpm_2'
| 'k_dpm_2_a'
| 'k_euler'
| 'k_euler_a'
| 'k_heun';
/**
* Seed used for randomness. Integer. 0 --> 4294967295, inclusive.
*/
seed: number;
/**
* Flag to enable seamless tiling image generation. Boolean.
*/
seamless: boolean;
}
/**
* Defines parameters needed to use the "img2img" generation method.
*/
interface ImageToImageParameters {
/**
* Folder path to the image used as the initial image. String.
*/
initialImagePath: string;
/**
* Flag to enable the use of a mask image during "img2img" generations.
* Requires valid ImageToImageParameters. Boolean.
*/
shouldUseMaskImage: boolean;
/**
* Folder path to the image used as a mask image. String.
*/
maskImagePath: string;
/**
* Strength of adherance to initial image. Float. 0 --> 1, exclusive.
*/
img2imgStrength: number;
/**
* Flag to enable the stretching of init image to desired output. Boolean.
*/
shouldFit: boolean;
}
/**
* Defines the parameters needed to generate variations.
*/
interface VariationParameters {
/**
* Variation amount. Float. 0 --> 1, exclusive.
* TODO: What does this really do?
*/
variationAmount: number;
/**
* List of seed-weight pairs formatted as "seed:weight,...".
* Seed is a valid seed. Weight is a float, 0 --> 1, exclusive.
* String, must be parseable into [[seed,weight],...] format.
*/
seedWeights: string;
}
/**
* Defines the parameters needed to use GFPGAN postprocessing.
*/
interface GFPGANParameters {
/**
* GFPGAN strength. Strength to apply face-fixing processing. Float. 0 --> 1, exclusive.
*/
gfpganStrength: number;
}
/**
* Defines the parameters needed to use ESRGAN postprocessing.
*/
interface ESRGANParameters {
/**
* ESRGAN strength. Strength to apply upscaling. Float. 0 --> 1, exclusive.
*/
esrganStrength: number;
/**
* ESRGAN upscaling scale. One of 2x | 4x. Represented as integer.
*/
esrganScale: 2 | 4;
}
/**
* Extends the generation and processing method parameters, adding flags to enable each.
*/
interface ProcessingParameters extends CommonParameters {
/**
* Flag to enable the generation of variations. Requires valid VariationParameters. Boolean.
*/
shouldGenerateVariations: boolean;
/**
* Variation parameters.
*/
variationParameters: VariationParameters;
/**
* Flag to enable the use of an initial image, i.e. to use "img2img" generation.
* Requires valid ImageToImageParameters. Boolean.
*/
shouldUseImageToImage: boolean;
/**
* ImageToImage parameters.
*/
imageToImageParameters: ImageToImageParameters;
/**
* Flag to enable GFPGAN postprocessing. Requires valid GFPGANParameters. Boolean.
*/
shouldRunGFPGAN: boolean;
/**
* GFPGAN parameters.
*/
gfpganParameters: GFPGANParameters;
/**
* Flag to enable ESRGAN postprocessing. Requires valid ESRGANParameters. Boolean.
*/
shouldRunESRGAN: boolean;
/**
* ESRGAN parameters.
*/
esrganParameters: GFPGANParameters;
}
/**
* Extends ProcessingParameters, adding items needed to request processing.
*/
interface ProcessingState extends ProcessingParameters {
/**
* Number of images to generate. Integer. Minimum 1.
*/
iterations: number;
/**
* Flag to enable the randomization of the seed on each generation. Boolean.
*/
shouldRandomizeSeed: boolean;
}
export {}

View File

@ -1,11 +1,11 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useMemo } from 'react';
import { useAppSelector } from '../../app/hooks';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
import { validateSeedWeights } from '../sd/util/seedWeightPairs';
import { SystemState } from './systemSlice';
import { SDState } from '../../features/sd/sdSlice';
import { SystemState } from '../../features/system/systemSlice';
import { validateSeedWeights } from '../util/seedWeightPairs';
const sdSelector = createSelector(
(state: RootState) => state.sd,

View File

@ -1,17 +1,15 @@
import { SDState } from '../features/sd/sdSlice';
import randomInt from '../features/sd/util/randomInt';
import {
seedWeightsToString,
stringToSeedWeights,
} from '../features/sd/util/seedWeightPairs';
import { SystemState } from '../features/system/systemSlice';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from './constants';
/*
These functions translate frontend state into parameters
suitable for consumption by the backend, and vice-versa.
*/
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from "../../app/constants";
import { SDState } from "../../features/sd/sdSlice";
import { SystemState } from "../../features/system/systemSlice";
import randomInt from "./randomInt";
import { seedWeightsToString, stringToSeedWeights } from "./seedWeightPairs";
export const frontendToBackendParameters = (
sdState: SDState,
systemState: SystemState
@ -77,7 +75,7 @@ export const frontendToBackendParameters = (
stringToSeedWeights(seedWeights);
}
} else {
generationParameters.variation_amount = 0.1;
generationParameters.variation_amount = 0;
}
let esrganParameters: false | { [k: string]: any } = false;
@ -96,6 +94,8 @@ export const frontendToBackendParameters = (
};
}
console.log(generationParameters)
return {
generationParameters,
esrganParameters,

View File

@ -1,14 +1,14 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
import DeleteImageModal from './DeleteImageModal';
import SDButton from '../../components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit';
import { SystemState } from '../system/systemSlice';
import { isEqual } from 'lodash';
import { SDImage } from './gallerySlice';
import SDButton from '../../common/components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio/actions';
const systemSelector = createSelector(
(state: RootState) => state.system,

View File

@ -1,5 +1,5 @@
import { Center, Flex, Image, Text, useColorModeValue } from '@chakra-ui/react';
import { useAppSelector } from '../../app/hooks';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { useState } from 'react';
import ImageMetadataViewer from './ImageMetadataViewer';

View File

@ -21,8 +21,8 @@ import {
SyntheticEvent,
useRef,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { deleteImage } from '../../app/socketio';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { deleteImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import { SDImage } from './gallerySlice';

View File

@ -6,7 +6,7 @@ import {
Image,
useColorModeValue,
} from '@chakra-ui/react';
import { useAppDispatch } from '../../app/hooks';
import { useAppDispatch } from '../../app/store';
import { SDImage, setCurrentImage } from './gallerySlice';
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal';

View File

@ -1,6 +1,6 @@
import { Flex } from '@chakra-ui/react';
import { Center, Flex, Text } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppSelector } from '../../app/hooks';
import { useAppSelector } from '../../app/store';
import HoverableImage from './HoverableImage';
/**
@ -19,7 +19,7 @@ const ImageGallery = () => {
* TODO: Refactor if performance complaints, or after migrating to new API which supports pagination.
*/
return (
return images.length ? (
<Flex gap={2} wrap="wrap" pb={2}>
{[...images].reverse().map((image) => {
const { uuid } = image;
@ -29,6 +29,10 @@ const ImageGallery = () => {
);
})}
</Flex>
) : (
<Center height={'100%'} position={'relative'}>
<Text size={'xl'}>No images in gallery</Text>
</Center>
);
};

View File

@ -10,8 +10,8 @@ import {
import { memo } from 'react';
import { FaPlus } from 'react-icons/fa';
import { PARAMETERS } from '../../app/constants';
import { useAppDispatch } from '../../app/hooks';
import SDButton from '../../components/SDButton';
import { useAppDispatch } from '../../app/store';
import SDButton from '../../common/components/SDButton';
import { setAllParameters, setParameter } from '../sd/sdSlice';
import { SDImage, SDMetadata } from './gallerySlice';

View File

@ -1,8 +1,7 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import { UpscalingLevel } from '../sd/sdSlice';
import { backendToFrontendParameters } from '../../app/parameterTranslation';
import { clamp } from 'lodash';
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
export interface SDMetadata {
@ -50,29 +49,48 @@ export const gallerySlice = createSlice({
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
removeImage: (state, action: PayloadAction<SDImage>) => {
const { uuid } = action.payload;
removeImage: (state, action: PayloadAction<string>) => {
const uuid = action.payload;
const newImages = state.images.filter((image) => image.uuid !== uuid);
const imageToDeleteIndex = state.images.findIndex(
(image) => image.uuid === uuid
);
if (uuid === state.currentImageUuid) {
/**
* We are deleting the currently selected image.
*
* We want the new currentl selected image to be under the cursor in the
* gallery, so we need to do some fanagling. The currently selected image
* is set by its UUID, not its index in the image list.
*
* Get the currently selected image's index.
*/
const imageToDeleteIndex = state.images.findIndex(
(image) => image.uuid === uuid
);
const newCurrentImageIndex = Math.min(
Math.max(imageToDeleteIndex, 0),
newImages.length - 1
);
/**
* New current image needs to be in the same spot, but because the gallery
* is sorted in reverse order, the new current image's index will actuall be
* one less than the deleted image's index.
*
* Clamp the new index to ensure it is valid..
*/
const newCurrentImageIndex = clamp(
imageToDeleteIndex - 1,
0,
newImages.length - 1
);
state.currentImage = newImages.length
? newImages[newCurrentImageIndex]
: undefined;
state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid
: '';
}
state.images = newImages;
state.currentImage = newImages.length
? newImages[newCurrentImageIndex]
: undefined;
state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid
: '';
},
addImage: (state, action: PayloadAction<SDImage>) => {
state.images.push(action.payload);
@ -86,47 +104,13 @@ export const gallerySlice = createSlice({
clearIntermediateImage: (state) => {
state.intermediateImage = undefined;
},
setGalleryImages: (
state,
action: PayloadAction<
Array<{
path: string;
metadata: { [key: string]: string | number | boolean };
}>
>
) => {
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
const images = action.payload;
if (images.length === 0) {
// there are no images on disk, clear the gallery
state.images = [];
state.currentImageUuid = '';
state.currentImage = undefined;
} else {
// Filter image urls that are already in the rehydrated state
const filteredImages = action.payload.filter(
(image) => !state.images.find((i) => i.url === image.path)
);
const preparedImages = filteredImages.map((image): SDImage => {
return {
uuid: uuidv4(),
url: image.path,
metadata: backendToFrontendParameters(image.metadata),
};
});
const newImages = [...state.images].concat(preparedImages);
// if previous currentimage no longer exists, set a new one
if (!newImages.find((image) => image.uuid === state.currentImageUuid)) {
const newCurrentImage = newImages[newImages.length - 1];
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
setGalleryImages: (state, action: PayloadAction<Array<SDImage>>) => {
const newImages = action.payload;
if (newImages.length) {
const newCurrentImage = newImages[newImages.length - 1];
state.images = newImages;
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
},
},

View File

@ -1,35 +1,38 @@
import { Progress } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppSelector } from '../../app/hooks';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
import { SystemState } from '../system/systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
realSteps: sd.realSteps,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
currentStep: system.currentStep,
totalSteps: system.totalSteps,
currentStatusHasSteps: system.currentStatusHasSteps,
};
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
const ProgressBar = () => {
const { realSteps } = useAppSelector(sdSelector);
const { currentStep } = useAppSelector((state: RootState) => state.system);
const progress = Math.round((currentStep * 100) / realSteps);
return (
<Progress
height='10px'
value={progress}
isIndeterminate={progress < 0 || currentStep === realSteps}
/>
);
const { isProcessing, currentStep, totalSteps, currentStatusHasSteps } =
useAppSelector(systemSelector);
const value = currentStep ? Math.round((currentStep * 100) / totalSteps) : 0;
return (
<Progress
height="10px"
value={value}
isIndeterminate={isProcessing && !currentStatusHasSteps}
/>
);
};
export default ProgressBar;

View File

@ -12,15 +12,20 @@ import { isEqual } from 'lodash';
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
import { MdHelp, MdSettings } from 'react-icons/md';
import { useAppSelector } from '../../app/hooks';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SettingsModal from '../system/SettingsModal';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return { isConnected: system.isConnected };
return {
isConnected: system.isConnected,
isProcessing: system.isProcessing,
currentIteration: system.currentIteration,
totalIterations: system.totalIterations,
currentStatus: system.currentStatus,
};
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
@ -32,11 +37,13 @@ const systemSelector = createSelector(
*/
const SiteHeader = () => {
const { colorMode, toggleColorMode } = useColorMode();
const { isConnected } = useAppSelector(systemSelector);
const statusMessage = isConnected
? `Connected to server`
: 'No connection to server';
const {
isConnected,
isProcessing,
currentIteration,
totalIterations,
currentStatus,
} = useAppSelector(systemSelector);
const statusMessageTextColor = isConnected ? 'green.500' : 'red.500';
@ -45,6 +52,14 @@ const SiteHeader = () => {
// Make FaMoon and FaSun icon apparent size consistent
const colorModeIconFontSize = colorMode == 'light' ? 18 : 20;
let statusMessage = currentStatus;
if (isProcessing) {
if (totalIterations > 1) {
statusMessage += ` [${currentIteration}/${totalIterations}]`;
}
}
return (
<Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}>
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading>

View File

@ -1,7 +1,7 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setUpscalingLevel,
@ -10,14 +10,14 @@ import {
SDState,
} from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { UPSCALING_LEVELS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const sdSelector = createSelector(
(state: RootState) => state.sd,

View File

@ -1,15 +1,15 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { SDState, setGfpganStrength } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
import SDNumberInput from '../../common/components/SDNumberInput';
const sdSelector = createSelector(
(state: RootState) => state.sd,

View File

@ -1,10 +1,10 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { ChangeEvent } from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../common/components/SDSwitch';
import InitAndMaskImage from './InitAndMaskImage';
import {
SDState,

View File

@ -1,6 +1,6 @@
import { Flex, Image } from '@chakra-ui/react';
import { useState } from 'react';
import { useAppSelector } from '../../app/hooks';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { SDState } from '../../features/sd/sdSlice';
import './InitAndMaskImage.css';

View File

@ -1,14 +1,14 @@
import { Button, Flex, IconButton, useToast } from '@chakra-ui/react';
import { SyntheticEvent, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import {
SDState,
setInitialImagePath,
setMaskPath,
} from '../../features/sd/sdSlice';
import { uploadInitialImage, uploadMaskImage } from '../../app/socketio';
import { uploadInitialImage, uploadMaskImage } from '../../app/socketio/actions';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import ImageUploader from './ImageUploader';

View File

@ -12,7 +12,7 @@ import {
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setShouldRunGFPGAN,

View File

@ -1,17 +1,17 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice';
import SDSelect from '../../components/SDSelect';
import { HEIGHTS, WIDTHS } from '../../app/constants';
import SDSwitch from '../../components/SDSwitch';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import SDSelect from '../../common/components/SDSelect';
import SDSwitch from '../../common/components/SDSwitch';
const sdSelector = createSelector(
(state: RootState) => state.sd,

View File

@ -1,12 +1,12 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { cancelProcessing, generateImage } from '../../app/socketio';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { cancelProcessing, generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import SDButton from '../../common/components/SDButton';
import useCheckParameters from '../../common/hooks/useCheckParameters';
import { SystemState } from '../system/systemSlice';
import useCheckParameters from '../system/useCheckParameters';
const systemSelector = createSelector(
(state: RootState) => state.system,

View File

@ -3,7 +3,8 @@ import {
ChangeEvent,
KeyboardEvent,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import { setPrompt } from '../sd/sdSlice';

View File

@ -1,17 +1,17 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { SAMPLERS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const sdSelector = createSelector(
(state: RootState) => state.sd,

View File

@ -11,10 +11,12 @@ import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../common/components/SDSwitch';
import randomInt from '../../common/util/randomInt';
import { validateSeedWeights } from '../../common/util/seedWeightPairs';
import {
SDState,
setIterations,
@ -24,8 +26,6 @@ import {
setShouldRandomizeSeed,
setVariationAmount,
} from './sdSlice';
import randomInt from './util/randomInt';
import { validateSeedWeights } from './util/seedWeightPairs';
const sdSelector = createSelector(
(state: RootState) => state.sd,

View File

@ -2,21 +2,12 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import { SDMetadata } from '../gallery/gallerySlice';
const calculateRealSteps = (
steps: number,
strength: number,
hasInitImage: boolean
): number => {
return hasInitImage ? Math.floor(strength * steps) : steps;
};
export type UpscalingLevel = 0 | 2 | 4;
export interface SDState {
prompt: string;
iterations: number;
steps: number;
realSteps: number;
cfgScale: number;
height: number;
width: number;
@ -43,7 +34,6 @@ const initialSDState: SDState = {
prompt: '',
iterations: 1,
steps: 50,
realSteps: 50,
cfgScale: 7.5,
height: 512,
width: 512,
@ -79,14 +69,7 @@ export const sdSlice = createSlice({
state.iterations = action.payload;
},
setSteps: (state, action: PayloadAction<number>) => {
const { img2imgStrength, initialImagePath } = state;
const steps = action.payload;
state.steps = steps;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
state.steps = action.payload;
},
setCfgScale: (state, action: PayloadAction<number>) => {
state.cfgScale = action.payload;
@ -105,14 +88,7 @@ export const sdSlice = createSlice({
state.shouldRandomizeSeed = false;
},
setImg2imgStrength: (state, action: PayloadAction<number>) => {
const img2imgStrength = action.payload;
const { steps, initialImagePath } = state;
state.img2imgStrength = img2imgStrength;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
state.img2imgStrength = action.payload;
},
setGfpganStrength: (state, action: PayloadAction<number>) => {
state.gfpganStrength = action.payload;
@ -127,15 +103,9 @@ export const sdSlice = createSlice({
state.shouldUseInitImage = action.payload;
},
setInitialImagePath: (state, action: PayloadAction<string>) => {
const initialImagePath = action.payload;
const { steps, img2imgStrength } = state;
state.shouldUseInitImage = initialImagePath ? true : false;
state.initialImagePath = initialImagePath;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
const newInitialImagePath = action.payload;
state.shouldUseInitImage = newInitialImagePath ? true : false;
state.initialImagePath = newInitialImagePath;
},
setMaskPath: (state, action: PayloadAction<string>) => {
state.maskPath = action.payload;
@ -153,6 +123,7 @@ export const sdSlice = createSlice({
state,
action: PayloadAction<{ key: string; value: string | number | boolean }>
) => {
// TODO: This probably needs to be refactored.
const { key, value } = action.payload;
const temp = { ...state, [key]: value };
if (key === 'seed') {
@ -173,6 +144,7 @@ export const sdSlice = createSlice({
state.seedWeights = action.payload;
},
setAllParameters: (state, action: PayloadAction<SDMetadata>) => {
// TODO: This probably needs to be refactored.
const {
prompt,
steps,

View File

@ -5,7 +5,7 @@ import {
Text,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { setShouldShowLogViewer, SystemState } from './systemSlice';
import { useLayoutEffect, useRef, useState } from 'react';
@ -44,18 +44,42 @@ const LogViewer = () => {
const log = useAppSelector(logSelector);
const { shouldShowLogViewer } = useAppSelector(systemSelector);
// Set colors based on dark/light mode
const bg = useColorModeValue('gray.50', 'gray.900');
const borderColor = useColorModeValue('gray.500', 'gray.500');
const logTextColors = useColorModeValue(
{
info: undefined,
warning: 'yellow.500',
error: 'red.500',
},
{
info: undefined,
warning: 'yellow.300',
error: 'red.300',
}
);
// Rudimentary autoscroll
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
const viewerRef = useRef<HTMLDivElement>(null);
/**
* If autoscroll is on, scroll to the bottom when:
* - log updates
* - viewer is toggled
*
* Also scroll to the bottom whenever autoscroll is turned on.
*/
useLayoutEffect(() => {
if (viewerRef.current !== null && shouldAutoscroll) {
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
}
}, [shouldAutoscroll]);
}, [shouldAutoscroll, log, shouldShowLogViewer]);
const handleClickLogViewerToggle = () => {
dispatch(setShouldShowLogViewer(!shouldShowLogViewer));
};
return (
<>
@ -78,16 +102,19 @@ const LogViewer = () => {
background={bg}
ref={viewerRef}
>
{log.map((entry, i) => (
<Flex gap={2} key={i}>
<Text fontSize="sm" fontWeight={'semibold'}>
{entry.timestamp}:
</Text>
<Text fontSize="sm" wordBreak={'break-all'}>
{entry.message}
</Text>
</Flex>
))}
{log.map((entry, i) => {
const { timestamp, message, level } = entry;
return (
<Flex gap={2} key={i} textColor={logTextColors[level]}>
<Text fontSize="sm" fontWeight={'semibold'}>
{timestamp}:
</Text>
<Text fontSize="sm" wordBreak={'break-all'}>
{message}
</Text>
</Flex>
);
})}
</Flex>
)}
{shouldShowLogViewer && (
@ -114,7 +141,7 @@ const LogViewer = () => {
variant={'solid'}
aria-label="Toggle Log Viewer"
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
onClick={() => dispatch(setShouldShowLogViewer(!shouldShowLogViewer))}
onClick={handleClickLogViewerToggle}
/>
</Tooltip>
</>

View File

@ -16,7 +16,7 @@ import {
Text,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setShouldConfirmOnDelete,
setShouldDisplayInProgress,

View File

@ -1,10 +1,12 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { ExpandedIndex } from '@chakra-ui/react';
export type LogLevel = 'info' | 'warning' | 'error';
export interface LogEntry {
timestamp: string;
level: LogLevel;
message: string;
}
@ -12,10 +14,18 @@ export interface Log {
[index: number]: LogEntry;
}
export interface SystemState {
shouldDisplayInProgress: boolean;
export interface SystemStatus {
isProcessing: boolean;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
}
export interface SystemState extends SystemStatus {
shouldDisplayInProgress: boolean;
log: Array<LogEntry>;
shouldShowLogViewer: boolean;
isGFPGANAvailable: boolean;
@ -24,12 +34,17 @@ export interface SystemState {
socketId: string;
shouldConfirmOnDelete: boolean;
openAccordions: ExpandedIndex;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
}
const initialSystemState = {
isConnected: false,
isProcessing: false,
currentStep: 0,
log: [],
shouldShowLogViewer: false,
shouldDisplayInProgress: false,
@ -38,6 +53,12 @@ const initialSystemState = {
socketId: '',
shouldConfirmOnDelete: true,
openAccordions: [0],
currentStep: 0,
totalSteps: 0,
currentIteration: 0,
totalIterations: 0,
currentStatus: '',
currentStatusHasSteps: false,
};
const initialState: SystemState = initialSystemState;
@ -51,18 +72,35 @@ export const systemSlice = createSlice({
},
setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload;
if (action.payload === false) {
state.currentStep = 0;
}
},
setCurrentStep: (state, action: PayloadAction<number>) => {
state.currentStep = action.payload;
setCurrentStatus: (state, action: PayloadAction<string>) => {
state.currentStatus = action.payload;
},
addLogEntry: (state, action: PayloadAction<string>) => {
setSystemStatus: (state, action: PayloadAction<SystemStatus>) => {
const currentStatus =
!action.payload.isProcessing && state.isConnected
? 'Connected'
: action.payload.currentStatus;
return { ...state, ...action.payload, currentStatus };
},
addLogEntry: (
state,
action: PayloadAction<{
timestamp: string;
message: string;
level?: LogLevel;
}>
) => {
const { timestamp, message, level } = action.payload;
const logLevel = level || 'info';
const entry: LogEntry = {
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: action.payload,
timestamp,
message,
level: logLevel,
};
state.log.push(entry);
},
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
@ -86,13 +124,14 @@ export const systemSlice = createSlice({
export const {
setShouldDisplayInProgress,
setIsProcessing,
setCurrentStep,
addLogEntry,
setShouldShowLogViewer,
setIsConnected,
setSocketId,
setShouldConfirmOnDelete,
setOpenAccordions,
setSystemStatus,
setCurrentStatus,
} = systemSlice.actions;
export default systemSlice.reducer;

View File

@ -8,9 +8,9 @@ import { persistStore } from 'redux-persist';
export const persistor = persistStore(store);
import App from './App';
import { theme } from './app/theme';
import Loading from './Loading';
import App from './app/App';
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
<React.StrictMode>