Merge with PR #602

- New and improved web api
- Author: @Kyle0654
Lincoln Stein 2022-09-16 16:35:34 -04:00
parent 00d2d0e90e
commit cbac95b02a
14 changed files with 899 additions and 358 deletions
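For orientation, here is a minimal client-side sketch of the job-based web API this commit introduces, using the requests library. The routes and response keys (/api/jobs, /api/cancel, 'jobId') are taken from the views and URL rules further down in this diff; the host, port and all field values are illustrative assumptions, and progress/result events are actually delivered over the Socket.IO channel shown in the JavaScript changes below.

import requests

BASE = 'http://localhost:9090'  # assumed host/port; adjust to wherever the server runs

# Queue a generation job. JobRequest.from_json expects the enable_* flags and,
# when generating, the iterations/variation fields as well.
job = {
    'enable_generate': True,
    'prompt': 'a lighthouse at dusk, oil painting',
    'seed': 0, 'steps': 50, 'width': 512, 'height': 512,
    'cfg_scale': 7.5, 'sampler_name': 'k_lms',
    'iterations': 1, 'variation_amount': 0, 'with_variations': '',
}
job_id = requests.post(f'{BASE}/api/jobs', json=job).json()['jobId']

# Results are pushed to the Socket.IO room named after the job id; a running
# job can be stopped through the cancel endpoint.
requests.get(f'{BASE}/api/cancel')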

.gitignore (vendored): 1 line changed
View File

@@ -191,3 +191,4 @@ checkpoints
.scratch/
.vscode/
gfpgan/
models/ldm/stable-diffusion-v1/model.sha256

View File

@@ -51,6 +51,7 @@ We thank them for all of their time and hard work.
- [Any Winter](https://github.com/any-winter-4079)
- [Doggettx](https://github.com/doggettx)
- [Matthias Wild](https://github.com/mauwii)
- [Kyle Schouviller](https://github.com/kyle0654)
## __Original CompVis Authors:__

View File

@@ -33,10 +33,11 @@ class PngWriter:
# saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output
def save_image_and_prompt_to_png(self, image, dream_prompt, metadata, name):
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None):
path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo()
info.add_text('Dream', dream_prompt)
if metadata: # TODO: merge command line app's method of writing metadata and always just write metadata
info.add_text('sd-metadata', json.dumps(metadata))
image.save(path, 'PNG', pnginfo=info)
return path
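For clarity, a small usage sketch of the updated signature (all values are illustrative); PngWriter takes the output directory, as it does in the services added later in this commit, and metadata is now an optional keyword written as JSON under the 'sd-metadata' key:

from PIL import Image
from ldm.dream.pngwriter import PngWriter

writer = PngWriter('./outputs')        # output directory
image = Image.new('RGB', (512, 512))   # stand-in for a generated image

path = writer.save_image_and_prompt_to_png(
    image,
    dream_prompt='a test prompt -S42',    # stored under the 'Dream' text key
    name='000001.42.png',                 # illustrative file name
    metadata={'seed': 42, 'steps': 50},   # optional; written as 'sd-metadata' JSON
)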

View File

@@ -230,7 +230,7 @@ class DreamServer(BaseHTTPRequestHandler):
image = self.model.sample_to_image(sample)
name = f'{prefix}.{opt.seed}.{step_index}.png'
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
path = step_writer.save_image_and_prompt_to_png(image, dream_prompt=metadata, name=name)
step_index += 1
self.wfile.write(bytes(json.dumps(
{'event': 'step', 'step': step + 1, 'url': path}

View File

@@ -181,7 +181,7 @@ class Generate:
for image, seed in results:
name = f'{prefix}.{seed}.png'
path = pngwriter.save_image_and_prompt_to_png(
image, f'{prompt} -S{seed}', name)
image, dream_prompt=f'{prompt} -S{seed}', name=name)
outputs.append([path, seed])
return outputs

View File

@@ -22,6 +22,11 @@ test-tube
torch-fidelity
torchmetrics
transformers
flask==2.1.3
flask_socketio==5.3.0
flask_cors==3.0.10
dependency_injector==4.40.0
eventlet
git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
git+https://github.com/lstein/GFPGAN@fix-dark-cast-images#egg=gfpgan

View File

@@ -7,9 +7,10 @@ import os
import sys
from flask import Flask
from flask_cors import CORS
from flask_socketio import SocketIO, join_room, leave_room
from flask_socketio import SocketIO
from omegaconf import OmegaConf
from dependency_injector.wiring import inject, Provide
from ldm.dream.args import Args
from server import views
from server.containers import Container
from server.services import GeneratorService, SignalService
@@ -58,6 +59,8 @@ def run_app(config, host, port) -> Flask:
# TODO: Get storage root from config
app.add_url_rule('/api/images/<string:dreamId>', view_func=views.ApiImages.as_view('api_images', '../'))
app.add_url_rule('/api/images/<string:dreamId>/metadata', view_func=views.ApiImagesMetadata.as_view('api_images_metadata', '../'))
app.add_url_rule('/api/images', view_func=views.ApiImagesList.as_view('api_images_list'))
app.add_url_rule('/api/intermediates/<string:dreamId>/<string:step>', view_func=views.ApiIntermediates.as_view('api_intermediates', '../'))
app.static_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../static/dream_web/'))
@@ -79,30 +82,28 @@ def run_app(config, host, port) -> Flask:
def main():
"""Initialize command-line parsers and the diffusion model"""
from scripts.dream import create_argv_parser
arg_parser = create_argv_parser()
arg_parser = Args()
opt = arg_parser.parse_args()
if opt.laion400m:
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1)
if opt.weights != 'model':
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
if opt.weights:
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
sys.exit(-1)
try:
models = OmegaConf.load(opt.config)
width = models[opt.model].width
height = models[opt.model].height
config = models[opt.model].config
weights = models[opt.model].weights
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
# try:
# models = OmegaConf.load(opt.config)
# width = models[opt.model].width
# height = models[opt.model].height
# config = models[opt.model].config
# weights = models[opt.model].weights
# except (FileNotFoundError, IOError, KeyError) as e:
# print(f'{e}. Aborting.')
# sys.exit(-1)
print('* Initializing, be patient...\n')
#print('* Initializing, be patient...\n')
sys.path.append('.')
from pytorch_lightning import logging
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
@@ -110,26 +111,28 @@ def main():
transformers.logging.set_verbosity_error()
appConfig = {
"model": {
"width": width,
"height": height,
"sampler_name": opt.sampler_name,
"weights": weights,
"full_precision": opt.full_precision,
"config": config,
"grid": opt.grid,
"latent_diffusion_weights": opt.laion400m,
"embedding_path": opt.embedding_path,
"device_type": opt.device
}
}
appConfig = opt.__dict__
# appConfig = {
# "model": {
# "width": width,
# "height": height,
# "sampler_name": opt.sampler_name,
# "weights": weights,
# "full_precision": opt.full_precision,
# "config": config,
# "grid": opt.grid,
# "latent_diffusion_weights": opt.laion400m,
# "embedding_path": opt.embedding_path
# }
# }
# make sure the output directory exists
if not os.path.exists(opt.outdir):
os.makedirs(opt.outdir)
# gets rid of annoying messages about random seed
from pytorch_lightning import logging
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
print('\n* starting api server...')

View File

@@ -17,18 +17,24 @@ class Container(containers.DeclarativeContainer):
app = None
)
# TODO: Add a model provider service that provides model(s) dynamically
model_singleton = providers.ThreadSafeSingleton(
Generate,
width = config.model.width,
height = config.model.height,
sampler_name = config.model.sampler_name,
weights = config.model.weights,
full_precision = config.model.full_precision,
config = config.model.config,
grid = config.model.grid,
seamless = config.model.seamless,
embedding_path = config.model.embedding_path,
device_type = config.model.device_type
model = config.model,
sampler_name = config.sampler_name,
embedding_path = config.embedding_path,
full_precision = config.full_precision
# config = config.model.config,
# width = config.model.width,
# height = config.model.height,
# sampler_name = config.model.sampler_name,
# weights = config.model.weights,
# full_precision = config.model.full_precision,
# grid = config.model.grid,
# seamless = config.model.seamless,
# embedding_path = config.model.embedding_path,
# device_type = config.model.device_type
)
# TODO: get location from config

View File

@@ -1,77 +1,182 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from base64 import urlsafe_b64encode
import json
import string
from copy import deepcopy
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Union
from uuid import uuid4
class DreamRequest():
prompt: string
initimg: string
strength: float
iterations: int
steps: int
width: int
height: int
fit = None
cfgscale: float
sampler_name: string
gfpgan_strength: float
upscale_level: int
upscale_strength: float
class DreamBase():
# Id
id: str
# Initial Image
enable_init_image: bool
initimg: string = None
# Img2Img
enable_img2img: bool # TODO: support this better
strength: float = 0 # TODO: name this something related to img2img to make it clearer?
fit = None # Fit initial image dimensions
# Generation
enable_generate: bool
prompt: string = ""
seed: int = 0 # 0 is random
steps: int = 10
width: int = 512
height: int = 512
cfg_scale: float = 7.5
sampler_name: string = 'klms'
seamless: bool = False
model: str = None # The model to use (currently unused)
embeddings = None # The embeddings to use (currently unused)
progress_images: bool = False
# GFPGAN
enable_gfpgan: bool
gfpgan_strength: float = 0
# Upscale
enable_upscale: bool
upscale: None
progress_images = None
seed: int
upscale_level: int = None
upscale_strength: float = 0.75
# Embiggen
enable_embiggen: bool
embiggen: Union[None, List[float]] = None
embiggen_tiles: Union[None, List[int]] = None
# Metadata
time: int
def __init__(self):
self.id = urlsafe_b64encode(uuid4().bytes).decode('ascii')
def parse_json(self, j, new_instance=False):
# Id
if 'id' in j and not new_instance:
self.id = j.get('id')
# Initial Image
self.enable_init_image = 'enable_init_image' in j and bool(j.get('enable_init_image'))
if self.enable_init_image:
self.initimg = j.get('initimg')
# Img2Img
self.enable_img2img = 'enable_img2img' in j and bool(j.get('enable_img2img'))
if self.enable_img2img:
self.strength = float(j.get('strength'))
self.fit = 'fit' in j
# Generation
self.enable_generate = 'enable_generate' in j and bool(j.get('enable_generate'))
if self.enable_generate:
self.prompt = j.get('prompt')
self.seed = int(j.get('seed'))
self.steps = int(j.get('steps'))
self.width = int(j.get('width'))
self.height = int(j.get('height'))
self.cfg_scale = float(j.get('cfgscale') or j.get('cfg_scale'))
self.sampler_name = j.get('sampler') or j.get('sampler_name')
# model: str = None # The model to use (currently unused)
# embeddings = None # The embeddings to use (currently unused)
self.seamless = 'seamless' in j
self.progress_images = 'progress_images' in j
# GFPGAN
self.enable_gfpgan = 'enable_gfpgan' in j and bool(j.get('enable_gfpgan'))
if self.enable_gfpgan:
self.gfpgan_strength = float(j.get('gfpgan_strength'))
# Upscale
self.enable_upscale = 'enable_upscale' in j and bool(j.get('enable_upscale'))
if self.enable_upscale:
self.upscale_level = j.get('upscale_level')
self.upscale_strength = j.get('upscale_strength')
self.upscale = None if self.upscale_level in {None,''} else [int(self.upscale_level),float(self.upscale_strength)]
# Embiggen
self.enable_embiggen = 'enable_embiggen' in j and bool(j.get('enable_embiggen'))
if self.enable_embiggen:
self.embiggen = j.get('embiggen')
self.embiggen_tiles = j.get('embiggen_tiles')
# Metadata
self.time = int(j.get('time')) if ('time' in j and not new_instance) else int(datetime.now(timezone.utc).timestamp())
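Because parse_json keys each section off its enable_* flag (and treats fit, seamless and progress_images as present/absent switches), a request body only needs the sections it turns on. A hedged, illustrative payload for an img2img plus GFPGAN request follows; the field names come from the parser above, all values are made up.

payload = {
    'enable_init_image': True,
    'initimg': 'data:image/png;base64,iVBORw0KGgo...',  # truncated data URL, illustrative
    'enable_img2img': True,
    'strength': 0.6,
    'fit': 'on',                       # presence alone enables fitting
    'enable_generate': True,
    'prompt': 'the same scene at night',
    'seed': 0, 'steps': 40, 'width': 512, 'height': 512,
    'cfg_scale': 7.5, 'sampler_name': 'k_lms',
    'iterations': 1, 'variation_amount': 0, 'with_variations': '',   # read by JobRequest.from_json below
    'enable_gfpgan': True,
    'gfpgan_strength': 0.8,
}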
class DreamResult(DreamBase):
# Result
has_upscaled: False
has_gfpgan: False
# TODO: use something else for state tracking
images_generated: int = 0
images_upscaled: int = 0
def id(self, seed = None, upscaled = False) -> str:
return f"{self.time}.{seed or self.seed}{'.u' if upscaled else ''}"
def __init__(self):
super().__init__()
# TODO: handle this more cleanly (probably by splitting this into a Job and Result class)
# TODO: Set iterations to 1 or remove it from the dream result? And just keep it on the job?
def clone_without_image(self, seed = None):
data = deepcopy(self)
data.initimg = None
if seed:
data.seed = seed
return data
def clone_without_img(self):
copy = deepcopy(self)
copy.initimg = None
return copy
def to_json(self, seed: int = None):
copy = self.clone_without_image(seed)
return json.dumps(copy.__dict__)
def to_json(self):
copy = deepcopy(self)
copy.initimg = None
j = json.dumps(copy.__dict__)
return j
@staticmethod
def from_json(j, newTime: bool = False):
d = DreamRequest()
d.prompt = j.get('prompt')
d.initimg = j.get('initimg')
d.strength = float(j.get('strength'))
d.iterations = int(j.get('iterations'))
d.steps = int(j.get('steps'))
d.width = int(j.get('width'))
d.height = int(j.get('height'))
d.fit = 'fit' in j
d.seamless = 'seamless' in j
d.cfgscale = float(j.get('cfgscale'))
d.sampler_name = j.get('sampler')
d.variation_amount = float(j.get('variation_amount'))
d.with_variations = j.get('with_variations')
d.gfpgan_strength = float(j.get('gfpgan_strength'))
d.upscale_level = j.get('upscale_level')
d.upscale_strength = j.get('upscale_strength')
d.upscale = [int(d.upscale_level),float(d.upscale_strength)] if d.upscale_level != '' else None
d.progress_images = 'progress_images' in j
d.seed = int(j.get('seed'))
d.time = int(datetime.now(timezone.utc).timestamp()) if newTime else int(j.get('time'))
d = DreamResult()
d.parse_json(j)
return d
# TODO: switch this to a pipelined request, with pluggable steps
# Will likely require generator code changes to accomplish
class JobRequest(DreamBase):
# Iteration
iterations: int = 1
variation_amount = None
with_variations = None
# Results
results: List[DreamResult] = []
def __init__(self):
super().__init__()
def newDreamResult(self) -> DreamResult:
result = DreamResult()
result.parse_json(self.__dict__, new_instance=True)
return result
@staticmethod
def from_json(j):
job = JobRequest()
job.parse_json(j)
# Metadata
job.time = int(j.get('time')) if ('time' in j) else int(datetime.now(timezone.utc).timestamp())
# Iteration
if job.enable_generate:
job.iterations = int(j.get('iterations'))
job.variation_amount = float(j.get('variation_amount'))
job.with_variations = j.get('with_variations')
return job
class ProgressType(Enum):
GENERATION = 1
UPSCALING_STARTED = 2
@@ -102,11 +207,11 @@ class Signal():
# TODO: use a result id or something? Like a sub-job
@staticmethod
def image_result(jobId: str, dreamId: str, dreamRequest: DreamRequest):
def image_result(jobId: str, dreamId: str, dreamResult: DreamResult):
return Signal('dream_result', {
'jobId': jobId,
'dreamId': dreamId,
'dreamRequest': dreamRequest.__dict__
'dreamRequest': dreamResult.clone_without_img().__dict__
}, room=jobId, broadcast=True)
@staticmethod
@@ -126,3 +231,21 @@ class Signal():
return Signal('job_canceled', {
'jobId': jobId
}, room=jobId, broadcast=True)
class PaginatedItems():
items: List[Any]
page: int # Current Page
pages: int # Total number of pages
per_page: int # Number of items per page
total: int # Total number of items in result
def __init__(self, items: List[Any], page: int, pages: int, per_page: int, total: int):
self.items = items
self.page = page
self.pages = pages
self.per_page = per_page
self.total = total
def to_json(self):
return json.dumps(self.__dict__)
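For reference, PaginatedItems serializes to a flat JSON object, which is what the image-list endpoint added below returns; the item ids here are placeholders (the real ids are the url-safe uuid strings assigned in DreamBase.__init__):

listing = PaginatedItems(
    items=['dreamId-1', 'dreamId-2'],   # placeholder dream ids
    page=0, pages=3, per_page=10, total=27,
)
print(listing.to_json())
# {"items": ["dreamId-1", "dreamId-2"], "page": 0, "pages": 3, "per_page": 10, "total": 27}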

View File

@@ -1,25 +1,33 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from argparse import ArgumentParser
import base64
from datetime import datetime, timezone
import glob
import json
import os
from pathlib import Path
from queue import Empty, Queue
import shlex
from threading import Thread
import time
from flask import app, url_for
from flask_socketio import SocketIO, join_room, leave_room
from ldm.dream.args import Args
from ldm.dream.generator import embiggen
from PIL import Image
from ldm.dream.pngwriter import PngWriter
from ldm.dream.server import CanceledException
from ldm.generate import Generate
from server.models import DreamRequest, ProgressType, Signal
from server.models import DreamResult, JobRequest, PaginatedItems, ProgressType, Signal
class JobQueueService:
__queue: Queue = Queue()
def push(self, dreamRequest: DreamRequest):
def push(self, dreamRequest: DreamResult):
self.__queue.put(dreamRequest)
def get(self, timeout: float = None) -> DreamRequest:
def get(self, timeout: float = None) -> DreamResult:
return self.__queue.get(timeout= timeout)
class SignalQueueService:
@@ -85,25 +93,28 @@ class LogService:
self.__location = location
self.__logFile = file
def log(self, dreamRequest: DreamRequest, seed = None, upscaled = False):
def log(self, dreamResult: DreamResult, seed = None, upscaled = False):
with open(os.path.join(self.__location, self.__logFile), "a") as log:
log.write(f"{dreamRequest.id(seed, upscaled)}: {dreamRequest.to_json(seed)}\n")
log.write(f"{dreamResult.id}: {dreamResult.to_json()}\n")
class ImageStorageService:
__location: str
__pngWriter: PngWriter
__legacyParser: ArgumentParser
def __init__(self, location):
self.__location = location
self.__pngWriter = PngWriter(self.__location)
self.__legacyParser = Args() # TODO: inject this?
def __getName(self, dreamId: str, postfix: str = '') -> str:
return f'{dreamId}{postfix}.png'
def save(self, image, dreamRequest, seed = None, upscaled = False, postfix: str = '', metadataPostfix: str = '') -> str:
name = self.__getName(dreamRequest.id(seed, upscaled), postfix)
path = self.__pngWriter.save_image_and_prompt_to_png(image, f'{dreamRequest.prompt} -S{seed or dreamRequest.seed}{metadataPostfix}', name)
def save(self, image, dreamResult: DreamResult, postfix: str = '') -> str:
name = self.__getName(dreamResult.id, postfix)
meta = dreamResult.to_json() # TODO: make all methods consistent with writing metadata. Standardize metadata.
path = self.__pngWriter.save_image_and_prompt_to_png(image, dream_prompt=meta, metadata=None, name=name)
return path
def path(self, dreamId: str, postfix: str = '') -> str:
@@ -111,6 +122,88 @@ class ImageStorageService:
path = os.path.join(self.__location, name)
return path
# Returns true if found, false if not found or error
def delete(self, dreamId: str, postfix: str = '') -> bool:
path = self.path(dreamId, postfix)
if (os.path.exists(path)):
os.remove(path)
return True
else:
return False
def getMetadata(self, dreamId: str, postfix: str = '') -> DreamResult:
path = self.path(dreamId, postfix)
image = Image.open(path)
text = image.text
if text.__contains__('Dream'):
dreamMeta = text.get('Dream')
try:
j = json.loads(dreamMeta)
return DreamResult.from_json(j)
except ValueError:
# Try to parse command-line format (legacy metadata format)
try:
opt = self.__parseLegacyMetadata(dreamMeta)
optd = opt.__dict__
if (not 'width' in optd) or (optd.get('width') is None):
optd['width'] = image.width
if (not 'height' in optd) or (optd.get('height') is None):
optd['height'] = image.height
if (not 'steps' in optd) or (optd.get('steps') is None):
optd['steps'] = 10 # No way around this unfortunately - seems like it wasn't storing this previously
optd['time'] = os.path.getmtime(path) # Set timestamp manually (won't be exactly correct though)
return DreamResult.from_json(optd)
except:
return None
else:
return None
def __parseLegacyMetadata(self, command: str) -> DreamResult:
# before splitting, escape single quotes so as not to mess
# up the parser
command = command.replace("'", "\\'")
try:
elements = shlex.split(command)
except ValueError as e:
return None
# rearrange the arguments to mimic how it works in the Dream bot.
switches = ['']
switches_started = False
for el in elements:
if el[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(el)
else:
switches[0] += el
switches[0] += ' '
switches[0] = switches[0][: len(switches[0]) - 1]
try:
opt = self.__legacyParser.parse_cmd(switches)
return opt
except SystemExit:
return None
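A hedged example of the legacy 'Dream' PNG text value this fallback handles, i.e. one that is not JSON: the prompt followed by command-line style switches (the exact switch set depends on which tool wrote the file, so treat the flags below as an assumption):

import shlex

legacy = "a fantasy landscape -s50 -W512 -H512 -C7.5 -Ak_lms -S2483964021"
elements = shlex.split(legacy.replace("'", "\\'"))
# __parseLegacyMetadata folds the leading non-switch tokens back into a single
# prompt element before handing the list to Args.parse_cmd, giving roughly:
# ['a fantasy landscape', '-s50', '-W512', '-H512', '-C7.5', '-Ak_lms', '-S2483964021']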
def list_files(self, page: int, perPage: int) -> PaginatedItems:
files = sorted(glob.glob(os.path.join(self.__location,'*.png')), key=os.path.getmtime, reverse=True)
count = len(files)
startId = page * perPage
pageCount = int(count / perPage) + 1
endId = min(startId + perPage, count)
items = [] if startId >= count else files[startId:endId]
items = list(map(lambda f: Path(f).stem, items))
return PaginatedItems(items, page, pageCount, perPage, count)
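A quick worked example of the paging arithmetic above, with illustrative numbers:

count, per_page = 27, 10                 # 27 PNGs on disk, 10 per page
pages = int(count / per_page) + 1        # = 3, as computed above
for page in range(pages):
    start = page * per_page
    end = min(start + per_page, count)
    print(f'page {page}: files[{start}:{end}]')
# page 0: files[0:10]
# page 1: files[10:20]
# page 2: files[20:27]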
class GeneratorService:
__model: Generate
@@ -144,13 +237,11 @@ class GeneratorService:
# TODO: Consider moving this to its own service if there's benefit in separating the generator
def __process(self):
# preload the model
# TODO: support multiple models
print('Preloading model')
tic = time.time()
self.__model.load_model()
print(
f'>> model loaded in', '%4.2fs' % (time.time() - tic)
)
print(f'>> model loaded in', '%4.2fs' % (time.time() - tic))
print('Started generation queue processor')
try:
@@ -162,103 +253,136 @@
print('Generation queue processor stopped')
def __start(self, dreamRequest: DreamRequest):
if dreamRequest.start_callback:
dreamRequest.start_callback()
self.__signal_service.emit(Signal.job_started(dreamRequest.id()))
def __on_start(self, jobRequest: JobRequest):
self.__signal_service.emit(Signal.job_started(jobRequest.id))
def __done(self, dreamRequest: DreamRequest, image, seed, upscaled=False):
self.__imageStorage.save(image, dreamRequest, seed, upscaled)
def __on_image_result(self, jobRequest: JobRequest, image, seed, upscaled=False):
dreamResult = jobRequest.newDreamResult()
dreamResult.seed = seed
dreamResult.has_upscaled = upscaled
dreamResult.iterations = 1
jobRequest.results.append(dreamResult)
# TODO: Separate status of GFPGAN?
self.__imageStorage.save(image, dreamResult)
# TODO: handle upscaling logic better (this is appending data to log, but only on first generation)
if not upscaled:
self.__log.log(dreamRequest, seed, upscaled)
self.__log.log(dreamResult)
self.__signal_service.emit(Signal.image_result(dreamRequest.id(), dreamRequest.id(seed, upscaled), dreamRequest.clone_without_image(seed)))
# Send result signal
self.__signal_service.emit(Signal.image_result(jobRequest.id, dreamResult.id, dreamResult))
upscaling_requested = dreamRequest.upscale or dreamRequest.gfpgan_strength>0
upscaling_requested = dreamResult.enable_upscale or dreamResult.enable_gfpgan
if upscaled:
dreamRequest.images_upscaled += 1
else:
dreamRequest.images_generated +=1
if upscaling_requested:
# action = None
if dreamRequest.images_generated >= dreamRequest.iterations:
progressType = ProgressType.UPSCALING_DONE
if dreamRequest.images_upscaled < dreamRequest.iterations:
progressType = ProgressType.UPSCALING_STARTED
self.__signal_service.emit(Signal.image_progress(dreamRequest.id(), dreamRequest.id(seed), dreamRequest.images_upscaled+1, dreamRequest.iterations, progressType))
# Report upscaling status
# TODO: this is very coupled to logic inside the generator. Fix that.
if upscaling_requested and any(result.has_upscaled for result in jobRequest.results):
progressType = ProgressType.UPSCALING_STARTED if len(jobRequest.results) < 2 * jobRequest.iterations else ProgressType.UPSCALING_DONE
upscale_count = sum(1 for i in jobRequest.results if i.has_upscaled)
self.__signal_service.emit(Signal.image_progress(jobRequest.id, dreamResult.id, upscale_count, jobRequest.iterations, progressType))
def __progress(self, dreamRequest, sample, step):
def __on_progress(self, jobRequest: JobRequest, sample, step):
if self.__cancellationRequested:
self.__cancellationRequested = False
raise CanceledException
# TODO: Progress per request will be easier once the seeds (and ids) can all be pre-generated
hasProgressImage = False
if dreamRequest.progress_images and step % 5 == 0 and step < dreamRequest.steps - 1:
s = str(len(jobRequest.results))
if jobRequest.progress_images and step % 5 == 0 and step < jobRequest.steps - 1:
image = self.__model._sample_to_image(sample)
self.__intermediateStorage.save(image, dreamRequest, self.__model.seed, postfix=f'.{step}', metadataPostfix=f' [intermediate]')
# TODO: clean this up, use a pre-defined dream result
result = DreamResult()
result.parse_json(jobRequest.__dict__, new_instance=False)
self.__intermediateStorage.save(image, result, postfix=f'.{s}.{step}')
hasProgressImage = True
self.__signal_service.emit(Signal.image_progress(dreamRequest.id(), dreamRequest.id(self.__model.seed), step, dreamRequest.steps, ProgressType.GENERATION, hasProgressImage))
self.__signal_service.emit(Signal.image_progress(jobRequest.id, f'{jobRequest.id}.{s}', step, jobRequest.steps, ProgressType.GENERATION, hasProgressImage))
def __generate(self, dreamRequest: DreamRequest):
def __generate(self, jobRequest: JobRequest):
try:
initimgfile = None
if dreamRequest.initimg is not None:
# TODO: handle this file a file service for init images
initimgfile = None # TODO: support this on the model directly?
if (jobRequest.enable_init_image):
if jobRequest.initimg is not None:
with open("./img2img-tmp.png", "wb") as f:
initimg = dreamRequest.initimg.split(",")[1] # Ignore mime type
initimg = jobRequest.initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
initimgfile = "./img2img-tmp.png"
# Get a random seed if we don't have one yet
# TODO: handle "previous" seed usage?
if dreamRequest.seed == -1:
dreamRequest.seed = self.__model.seed
# Use previous seed if set to -1
initSeed = jobRequest.seed
if initSeed == -1:
initSeed = self.__model.seed
# Zero gfpgan strength if the model doesn't exist
# TODO: determine if this could be at the top now? Used to cause circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
if not gfpgan_model_exists:
dreamRequest.gfpgan_strength = 0
jobRequest.enable_gfpgan = False
# Signal start
self.__on_start(jobRequest)
# Generate in model
# TODO: Split job generation requests instead of fitting all parameters here
# TODO: Support no generation (just upscaling/gfpgan)
upscale = None if not jobRequest.enable_upscale else jobRequest.upscale
gfpgan_strength = 0 if not jobRequest.enable_gfpgan else jobRequest.gfpgan_strength
if not jobRequest.enable_generate:
# If not generating, check if we're upscaling or running gfpgan
if not upscale and not gfpgan_strength:
# Invalid settings (TODO: Add message to help user)
raise CanceledException()
image = Image.open(initimgfile)
# TODO: support progress for upscale?
self.__model.upscale_and_reconstruct(
image_list = [[image,0]],
upscale = upscale,
strength = gfpgan_strength,
save_original = False,
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
else:
# Generating - run the generation
init_img = None if (not jobRequest.enable_img2img or jobRequest.strength == 0) else initimgfile
self.__start(dreamRequest)
self.__model.prompt2image(
prompt = dreamRequest.prompt,
init_img = initimgfile, # TODO: ensure this works
strength = None if initimgfile is None else dreamRequest.strength,
fit = None if initimgfile is None else dreamRequest.fit,
iterations = dreamRequest.iterations,
cfg_scale = dreamRequest.cfgscale,
width = dreamRequest.width,
height = dreamRequest.height,
seed = dreamRequest.seed,
steps = dreamRequest.steps,
variation_amount = dreamRequest.variation_amount,
with_variations = dreamRequest.with_variations,
gfpgan_strength = dreamRequest.gfpgan_strength,
upscale = dreamRequest.upscale,
sampler_name = dreamRequest.sampler_name,
seamless = dreamRequest.seamless,
step_callback = lambda sample, step: self.__progress(dreamRequest, sample, step),
image_callback = lambda image, seed, upscaled=False: self.__done(dreamRequest, image, seed, upscaled))
prompt = jobRequest.prompt,
init_img = init_img, # TODO: ensure this works
strength = None if init_img is None else jobRequest.strength,
fit = None if init_img is None else jobRequest.fit,
iterations = jobRequest.iterations,
cfg_scale = jobRequest.cfg_scale,
width = jobRequest.width,
height = jobRequest.height,
seed = jobRequest.seed,
steps = jobRequest.steps,
variation_amount = jobRequest.variation_amount,
with_variations = jobRequest.with_variations,
gfpgan_strength = gfpgan_strength,
upscale = upscale,
sampler_name = jobRequest.sampler_name,
seamless = jobRequest.seamless,
embiggen = jobRequest.embiggen,
embiggen_tiles = jobRequest.embiggen_tiles,
step_callback = lambda sample, step: self.__on_progress(jobRequest, sample, step),
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
except CanceledException:
if dreamRequest.cancelled_callback:
dreamRequest.cancelled_callback()
self.__signal_service.emit(Signal.job_canceled(dreamRequest.id()))
self.__signal_service.emit(Signal.job_canceled(jobRequest.id))
finally:
if dreamRequest.done_callback:
dreamRequest.done_callback()
self.__signal_service.emit(Signal.job_done(dreamRequest.id()))
self.__signal_service.emit(Signal.job_done(jobRequest.id))
# Remove the temp file
if (initimgfile is not None):

View File

@@ -8,7 +8,7 @@ from flask import current_app, jsonify, request, Response, send_from_directory,
from flask.views import MethodView
from dependency_injector.wiring import inject, Provide
from server.models import DreamRequest
from server.models import DreamResult, JobRequest
from server.services import GeneratorService, ImageStorageService, JobQueueService
from server.containers import Container
@@ -16,23 +16,14 @@ class ApiJobs(MethodView):
@inject
def post(self, job_queue_service: JobQueueService = Provide[Container.generation_queue_service]):
dreamRequest = DreamRequest.from_json(request.json, newTime = True)
jobRequest = JobRequest.from_json(request.json)
#self.canceled.clear()
print(f">> Request to generate with prompt: {dreamRequest.prompt}")
q = Queue()
dreamRequest.start_callback = None
dreamRequest.image_callback = None
dreamRequest.progress_callback = None
dreamRequest.cancelled_callback = None
dreamRequest.done_callback = None
print(f">> Request to generate with prompt: {jobRequest.prompt}")
# Push the request
job_queue_service.push(dreamRequest)
job_queue_service.push(jobRequest)
return { 'dreamId': dreamRequest.id() }
return { 'jobId': jobRequest.id }
class WebIndex(MethodView):
@@ -68,6 +59,7 @@ class ApiCancel(MethodView):
return Response(status=204)
# TODO: Combine all image storage access
class ApiImages(MethodView):
init_every_request = False
__pathRoot = None
@@ -83,6 +75,27 @@
fullpath=os.path.join(self.__pathRoot, name)
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))
def delete(self, dreamId):
result = self.__storage.delete(dreamId)
return Response(status=204) if result else Response(status=404)
class ApiImagesMetadata(MethodView):
init_every_request = False
__pathRoot = None
__storage: ImageStorageService
@inject
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_storage_service]):
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
self.__storage = storage
def get(self, dreamId):
meta = self.__storage.getMetadata(dreamId)
j = {} if meta is None else meta.__dict__
return j
class ApiIntermediates(MethodView):
init_every_request = False
__pathRoot = None
@@ -97,3 +110,23 @@ class ApiIntermediates(MethodView):
name = self.__storage.path(dreamId, postfix=f'.{step}')
fullpath=os.path.join(self.__pathRoot, name)
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))
def delete(self, dreamId):
result = self.__storage.delete(dreamId)
return Response(status=204) if result else Response(status=404)
class ApiImagesList(MethodView):
init_every_request = False
__storage: ImageStorageService
@inject
def __init__(self, storage: ImageStorageService = Provide[Container.image_storage_service]):
self.__storage = storage
def get(self):
page = request.args.get("page", default=0, type=int)
perPage = request.args.get("per_page", default=10, type=int)
result = self.__storage.list_files(page, perPage)
return result.__dict__
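A hedged client-side sketch of the listing flow these views add, again assuming a localhost:9090 base URL: page through /api/images, then fetch each image and its metadata by dream id.

import requests

BASE = 'http://localhost:9090'   # assumed host/port

page = requests.get(f'{BASE}/api/images', params={'page': 0, 'per_page': 10}).json()
for dream_id in page['items']:
    meta = requests.get(f'{BASE}/api/images/{dream_id}/metadata').json()   # {} if metadata is unknown
    png = requests.get(f'{BASE}/api/images/{dream_id}').content            # raw PNG bytes
    print(dream_id, meta.get('prompt'), len(png))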

View File

@@ -1,3 +1,8 @@
:root {
--fields-dark:#DCDCDC;
--fields-light:#F5F5F5;
}
* {
font-family: 'Arial';
font-size: 100%;
@@ -18,15 +23,26 @@ fieldset {
border: none;
line-height: 2.2em;
}
fieldset > legend {
width: auto;
margin-left: 0;
margin-right: auto;
font-weight:bold;
}
select, input {
margin-right: 10px;
padding: 2px;
}
input:disabled {
cursor:auto;
}
input[type=submit] {
cursor: pointer;
background-color: #666;
color: white;
}
input[type=checkbox] {
cursor: pointer;
margin-right: 0px;
width: 20px;
height: 20px;
@@ -87,11 +103,11 @@ header h1 {
}
#results img {
border-radius: 5px;
object-fit: cover;
object-fit: contain;
background-color: var(--fields-dark);
}
#fieldset-config {
line-height:2em;
background-color: #F0F0F0;
}
input[type="number"] {
width: 60px;
@@ -118,35 +134,46 @@ label {
#progress-image {
width: 30vh;
height: 30vh;
object-fit: contain;
background-color: var(--fields-dark);
}
#cancel-button {
cursor: pointer;
color: red;
}
#basic-parameters {
background-color: #EEEEEE;
}
#txt2img {
background-color: #DCDCDC;
background-color: var(--fields-dark);
}
#variations {
background-color: #EEEEEE;
background-color: var(--fields-light);
}
#initimg {
background-color: var(--fields-dark);
}
#img2img {
background-color: #DCDCDC;
background-color: var(--fields-light);
}
#gfpgan {
background-color: #EEEEEE;
#initimg > :not(legend) {
background-color: var(--fields-light);
margin: .5em;
}
#postprocess, #initimg {
display:flex;
flex-wrap:wrap;
padding: 0;
margin-top: 1em;
background-color: var(--fields-dark);
}
#postprocess > fieldset, #initimg > * {
flex-grow: 1;
}
#postprocess > fieldset {
background-color: var(--fields-dark);
}
#progress-section {
background-color: #F5F5F5;
}
.section-header {
text-align: left;
font-weight: bold;
padding: 0 0 0 0;
background-color: var(--fields-light);
}
#no-results-message:not(:only-child) {
display: none;
}

View File

@@ -1,41 +1,50 @@
<html lang="en">
<head>
<head>
<title>Stable Diffusion Dream Server</title>
<meta charset="utf-8">
<link rel="icon" href="data:,">
<link rel="icon" type="image/x-icon" href="static/dream_web/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="index.css">
<script src="config.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js" integrity="sha512-q/dWJ3kcmjBLU4Qc47E4A9kTB4m3wuTY7vkFJDTZKjTs8jhyGQnaUrxa0Ytd0ssMZhbNua9hE+E7Qv1j+DyZwA==" crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"
integrity="sha512-q/dWJ3kcmjBLU4Qc47E4A9kTB4m3wuTY7vkFJDTZKjTs8jhyGQnaUrxa0Ytd0ssMZhbNua9hE+E7Qv1j+DyZwA=="
crossorigin="anonymous"></script>
<link rel="stylesheet" href="index.css">
<script src="index.js"></script>
</head>
<body>
</head>
<body>
<header>
<h1>Stable Diffusion Dream Server</h1>
<div id="about">
For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a>
For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub
site</a>
</div>
</header>
<main>
<!--
<div id="dropper" style="background-color:red;width:200px;height:200px;">
</div>
-->
<form id="generate-form" method="post" action="api/jobs">
<fieldset id="txt2img">
<legend>
<input type="checkbox" name="enable_generate" id="enable_generate" checked>
<label for="enable_generate">Generate</label>
</legend>
<div id="search-box">
<textarea rows="3" id="prompt" name="prompt"></textarea>
<input type="submit" id="submit" value="Generate">
</div>
</fieldset>
<fieldset id="fieldset-config">
<div class="section-header">Basic options</div>
<label for="iterations">Images to generate:</label>
<input value="1" type="number" id="iterations" name="iterations" size="4">
<label for="steps">Steps:</label>
<input value="50" type="number" id="steps" name="steps">
<label for="cfgscale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfgscale" name="cfgscale" step="any">
<label for="sampler">Sampler:</label>
<select id="sampler" name="sampler" value="k_lms">
<label for="cfg_scale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfg_scale" name="cfg_scale" step="any">
<label for="sampler_name">Sampler:</label>
<select id="sampler_name" name="sampler_name" value="k_lms">
<option value="ddim">DDIM</option>
<option value="plms">PLMS</option>
<option value="k_lms" selected>KLMS</option>
@@ -50,25 +59,41 @@
<br>
<label title="Set to multiple of 64" for="width">Width:</label>
<select id="width" name="width" value="512">
<option value="64">64</option> <option value="128">128</option>
<option value="192">192</option> <option value="256">256</option>
<option value="320">320</option> <option value="384">384</option>
<option value="448">448</option> <option value="512" selected>512</option>
<option value="576">576</option> <option value="640">640</option>
<option value="704">704</option> <option value="768">768</option>
<option value="832">832</option> <option value="896">896</option>
<option value="960">960</option> <option value="1024">1024</option>
<option value="64">64</option>
<option value="128">128</option>
<option value="192">192</option>
<option value="256">256</option>
<option value="320">320</option>
<option value="384">384</option>
<option value="448">448</option>
<option value="512" selected>512</option>
<option value="576">576</option>
<option value="640">640</option>
<option value="704">704</option>
<option value="768">768</option>
<option value="832">832</option>
<option value="896">896</option>
<option value="960">960</option>
<option value="1024">1024</option>
</select>
<label title="Set to multiple of 64" for="height">Height:</label>
<select id="height" name="height" value="512">
<option value="64">64</option> <option value="128">128</option>
<option value="192">192</option> <option value="256">256</option>
<option value="320">320</option> <option value="384">384</option>
<option value="448">448</option> <option value="512" selected>512</option>
<option value="576">576</option> <option value="640">640</option>
<option value="704">704</option> <option value="768">768</option>
<option value="832">832</option> <option value="896">896</option>
<option value="960">960</option> <option value="1024">1024</option>
<option value="64">64</option>
<option value="128">128</option>
<option value="192">192</option>
<option value="256">256</option>
<option value="320">320</option>
<option value="384">384</option>
<option value="448">448</option>
<option value="512" selected>512</option>
<option value="576">576</option>
<option value="640">640</option>
<option value="704">704</option>
<option value="768">768</option>
<option value="832">832</option>
<option value="896">896</option>
<option value="960">960</option>
<option value="1024">1024</option>
</select>
<label title="Set to 0 for random seed" for="seed">Seed:</label>
<input value="0" type="number" id="seed" name="seed">
@@ -76,29 +101,52 @@
<input type="checkbox" name="progress_images" id="progress_images">
<label for="progress_images">Display in-progress images (slower)</label>
<button type="button" id="reset-all">Reset to Defaults</button>
<span id="variations">
<label title="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different." for="variation_amount">Variation amount (0 to disable):</label>
<div id="variations">
<label
title="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different."
for="variation_amount">Variation amount (0 to disable):</label>
<input value="0" type="number" id="variation_amount" name="variation_amount" step="0.01" min="0" max="1">
<label title="list of variations to apply, in the format `seed:weight,seed:weight,..." for="with_variations">With variations (seed:weight,seed:weight,...):</label>
<label title="list of variations to apply, in the format `seed:weight,seed:weight,..."
for="with_variations">With variations (seed:weight,seed:weight,...):</label>
<input value="" type="text" id="with_variations" name="with_variations">
</span>
</div>
</fieldset>
<fieldset id="img2img">
<div class="section-header">Image-to-image options</div>
<fieldset id="initimg">
<legend>
<input type="checkbox" name="enable_init_image" id="enable_init_image" checked>
<label for="enable_init_image">Enable init image</label>
</legend>
<div>
<label title="Upload an image to use img2img" for="initimg">Initial image:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<button type="button" id="remove-image">Remove Image</button>
<br>
</div>
<fieldset id="img2img">
<legend>
<input type="checkbox" name="enable_img2img" id="enable_img2img" checked>
<label for="enable_img2img">Enable Img2Img</label>
</legend>
<label for="strength">Img2Img Strength:</label>
<input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
<input type="checkbox" id="fit" name="fit" checked>
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height</label>
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height:</label>
</fieldset>
</fieldset>
<div id="postprocess">
<fieldset id="gfpgan">
<div class="section-header">Post-processing options</div>
<label title="Strength of the gfpgan (face fixing) algorithm." for="gfpgan_strength">GPFGAN Strength (0 to disable):</label>
<input value="0.0" min="0" max="1" type="number" id="gfpgan_strength" name="gfpgan_strength" step="0.1">
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level</label>
<legend>
<input type="checkbox" name="enable_gfpgan" id="enable_gfpgan">
<label for="enable_gfpgan">Enable gfpgan</label>
</legend>
<label title="Strength of the gfpgan (face fixing) algorithm." for="gfpgan_strength">GPFGAN Strength:</label>
<input value="0.8" min="0" max="1" type="number" id="gfpgan_strength" name="gfpgan_strength" step="0.05">
</fieldset>
<fieldset id="upscale">
<legend>
<input type="checkbox" name="enable_upscale" id="enable_upscale">
<label for="enable_upscale">Enable Upscaling</label>
</legend>
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level:</label>
<select id="upscale_level" name="upscale_level" value="">
<option value="" selected>None</option>
<option value="2">2x</option>
@@ -107,6 +155,8 @@
<label title="Strength of the esrgan (upscaling) algorithm." for="upscale_strength">Upscale Strength:</label>
<input value="0.75" min="0" max="1" type="number" id="upscale_strength" name="upscale_strength" step="0.05">
</fieldset>
</div>
<input type="submit" id="submit" value="Generate">
</form>
<br>
<section id="progress-section">
@@ -118,14 +168,12 @@
<div id="scaling-inprocess-message">
<i><span>Postprocessing...</span><span id="processing_cnt">1</span>/<span id="processing_total">3</span></i>
</div>
</span>
</div>
</section>
<div id="results">
<div id="no-results-message">
<i><p>No results...</p></i>
</div>
</div>
</main>
</body>
</body>
</html>

View File

@@ -1,5 +1,73 @@
const socket = io();
var priorResultsLoadState = {
page: 0,
pages: 1,
per_page: 10,
total: 20,
offset: 0, // number of items generated since last load
loading: false,
initialized: false
};
function loadPriorResults() {
// Fix next page by offset
let offsetPages = priorResultsLoadState.offset / priorResultsLoadState.per_page;
priorResultsLoadState.page += offsetPages;
priorResultsLoadState.pages += offsetPages;
priorResultsLoadState.total += priorResultsLoadState.offset;
priorResultsLoadState.offset = 0;
if (priorResultsLoadState.loading) {
return;
}
if (priorResultsLoadState.page >= priorResultsLoadState.pages) {
return; // Nothing more to load
}
// Load
priorResultsLoadState.loading = true
let url = new URL('/api/images', document.baseURI);
url.searchParams.append('page', priorResultsLoadState.initialized ? priorResultsLoadState.page + 1 : priorResultsLoadState.page);
url.searchParams.append('per_page', priorResultsLoadState.per_page);
fetch(url.href, {
method: 'GET',
headers: new Headers({'content-type': 'application/json'})
})
.then(response => response.json())
.then(data => {
priorResultsLoadState.page = data.page;
priorResultsLoadState.pages = data.pages;
priorResultsLoadState.per_page = data.per_page;
priorResultsLoadState.total = data.total;
data.items.forEach(function(dreamId, index) {
let src = 'api/images/' + dreamId;
fetch('/api/images/' + dreamId + '/metadata', {
method: 'GET',
headers: new Headers({'content-type': 'application/json'})
})
.then(response => response.json())
.then(metadata => {
let seed = metadata.seed || 0; // TODO: Parse old metadata
appendOutput(src, seed, metadata, true);
});
});
// Load until page is full
if (!priorResultsLoadState.initialized) {
if (document.body.scrollHeight <= window.innerHeight) {
loadPriorResults();
}
}
})
.finally(() => {
priorResultsLoadState.loading = false;
priorResultsLoadState.initialized = true;
});
}
function resetForm() {
var form = document.getElementById('generate-form');
form.querySelector('fieldset').removeAttribute('disabled');
@@ -45,48 +113,64 @@ function toBase64(file) {
});
}
function appendOutput(src, seed, config) {
let outputNode = document.createElement("figure");
let altText = seed.toString() + " | " + config.prompt;
function ondragdream(event) {
let dream = event.target.dataset.dream;
event.dataTransfer.setData("dream", dream);
}
const figureContents = `
<a href="${src}" target="_blank">
<img src="${src}" alt="${altText}" title="${altText}">
</a>
<figcaption>${seed}</figcaption>
`;
function seedClick(event) {
// Get element
var image = event.target.closest('figure').querySelector('img');
var dream = JSON.parse(decodeURIComponent(image.dataset.dream));
outputNode.innerHTML = figureContents;
let figcaption = outputNode.querySelector('figcaption')
// Reload image config
figcaption.addEventListener('click', () => {
let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) {
if (k == 'initimg') { continue; }
form.querySelector(`*[name=${k}]`).value = config[k];
}
if (config.variation_amount > 0 || config.with_variations != '') {
document.querySelector("#seed").value = config.seed;
} else {
document.querySelector("#seed").value = seed;
let formElem = form.querySelector(`*[name=${k}]`);
formElem.value = dream[k] !== undefined ? dream[k] : formElem.defaultValue;
}
if (config.variation_amount > 0) {
let oldVarAmt = document.querySelector("#variation_amount").value
let oldVariations = document.querySelector("#with_variations").value
let varSep = ''
document.querySelector("#variation_amount").value = 0;
if (document.querySelector("#with_variations").value != '') {
varSep = ","
}
document.querySelector("#with_variations").value = oldVariations + varSep + seed + ':' + config.variation_amount
}
document.querySelector("#seed").value = dream.seed;
document.querySelector('#iterations').value = 1; // Reset to 1 iteration since we clicked a single image (not a full job)
// NOTE: leaving this manual for the user for now - it was very confusing with this behavior
// document.querySelector("#with_variations").value = variations || '';
// if (document.querySelector("#variation_amount").value <= 0) {
// document.querySelector("#variation_amount").value = 0.2;
// }
saveFields(document.querySelector("#generate-form"));
});
}
function appendOutput(src, seed, config, toEnd=false) {
let outputNode = document.createElement("figure");
let altText = seed.toString() + " | " + config.prompt;
// img needs width and height for lazy loading to work
// TODO: store the full config in a data attribute on the image?
const figureContents = `
<a href="${src}" target="_blank">
<img src="${src}"
alt="${altText}"
title="${altText}"
loading="lazy"
width="256"
height="256"
draggable="true"
ondragstart="ondragdream(event, this)"
data-dream="${encodeURIComponent(JSON.stringify(config))}"
data-dreamId="${encodeURIComponent(config.dreamId)}">
</a>
<figcaption onclick="seedClick(event, this)">${seed}</figcaption>
`;
outputNode.innerHTML = figureContents;
if (toEnd) {
document.querySelector("#results").append(outputNode);
} else {
document.querySelector("#results").prepend(outputNode);
}
document.querySelector("#no-results-message")?.remove();
}
@@ -119,14 +203,33 @@ async function generateSubmit(form) {
// Convert file data to base64
// TODO: Should probably upload files with formdata or something, and store them in the backend?
let formData = Object.fromEntries(new FormData(form));
if (!formData.enable_generate && !formData.enable_init_image) {
gen_label = document.querySelector("label[for=enable_generate]").innerHTML;
initimg_label = document.querySelector("label[for=enable_init_image]").innerHTML;
alert(`Error: one of "${gen_label}" or "${initimg_label}" must be set`);
}
formData.initimg_name = formData.initimg.name
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
// Evaluate all checkboxes
let checkboxes = form.querySelectorAll('input[type=checkbox]');
checkboxes.forEach(function (checkbox) {
if (checkbox.checked) {
formData[checkbox.name] = 'true';
}
});
let strength = formData.strength;
let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
let showProgressImages = formData.progress_images;
// Set enabling flags
// Initialize the progress bar
initProgress(totalSteps);
initProgress(totalSteps, showProgressImages);
// POST, use response to listen for events
fetch(form.action, {
@@ -136,13 +239,19 @@
})
.then(response => response.json())
.then(data => {
var dreamId = data.dreamId;
socket.emit('join_room', { 'room': dreamId });
var jobId = data.jobId;
socket.emit('join_room', { 'room': jobId });
});
form.querySelector('fieldset').setAttribute('disabled','');
}
function fieldSetEnableChecked(event) {
cb = event.target;
fields = cb.closest('fieldset');
fields.disabled = !cb.checked;
}
// Socket listeners
socket.on('job_started', (data) => {})
@@ -152,6 +261,7 @@ socket.on('dream_result', (data) => {
var dreamRequest = data.dreamRequest;
var src = 'api/images/' + dreamId;
priorResultsLoadState.offset += 1;
appendOutput(src, dreamRequest.seed, dreamRequest);
resetProgress(false);
@@ -193,7 +303,13 @@ socket.on('job_done', (data) => {
resetProgress();
})
window.onload = () => {
window.onload = async () => {
document.querySelector("#prompt").addEventListener("keydown", (e) => {
if (e.key === "Enter" && !e.shiftKey) {
const form = e.target.form;
generateSubmit(form);
}
});
document.querySelector("#generate-form").addEventListener('submit', (e) => {
e.preventDefault();
const form = e.target;
@@ -216,12 +332,65 @@ window.onload = () => {
loadFields(document.querySelector("#generate-form"));
document.querySelector('#cancel-button').addEventListener('click', () => {
fetch('/cancel').catch(e => {
fetch('/api/cancel').catch(e => {
console.error(e);
});
});
document.documentElement.addEventListener('keydown', (e) => {
if (e.key === "Escape")
fetch('/api/cancel').catch(err => {
console.error(err);
});
});
if (!config.gfpgan_model_exists) {
document.querySelector("#gfpgan").style.display = 'none';
}
window.addEventListener("scroll", () => {
if ((window.innerHeight + window.pageYOffset) >= document.body.offsetHeight) {
loadPriorResults();
}
});
// Enable/disable forms by checkboxes
document.querySelectorAll("legend > input[type=checkbox]").forEach(function(cb) {
cb.addEventListener('change', fieldSetEnableChecked);
fieldSetEnableChecked({ target: cb})
});
// Load some of the previous results
loadPriorResults();
// Image drop/upload WIP
/*
let drop = document.getElementById('dropper');
function ondrop(event) {
let dreamData = event.dataTransfer.getData('dream');
if (dreamData) {
var dream = JSON.parse(decodeURIComponent(dreamData));
alert(dream.dreamId);
}
};
function ondragenter(event) {
event.preventDefault();
};
function ondragover(event) {
event.preventDefault();
};
function ondragleave(event) {
}
drop.addEventListener('drop', ondrop);
drop.addEventListener('dragenter', ondragenter);
drop.addEventListener('dragover', ondragover);
drop.addEventListener('dragleave', ondragleave);
*/
};