Merge with PR #602

- New and improved web API
- Author: @Kyle0654
Lincoln Stein 2022-09-16 16:35:34 -04:00
parent 00d2d0e90e
commit cbac95b02a
14 changed files with 899 additions and 358 deletions
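
The endpoints added in this change (POST /api/jobs, GET /api/images, GET /api/images/<dreamId>/metadata, GET /api/intermediates/<dreamId>/<step>, /api/cancel) can be exercised from Python with just the standard library. A minimal sketch, assuming the server is reachable at http://localhost:9090 (host and port are whatever run_app() was started with; they are not fixed by this diff):

import json
from urllib.request import Request, urlopen

payload = {
    'enable_generate': True,
    'prompt': 'a watercolor fox',
    'seed': 0,                     # 0 asks for a random seed
    'steps': 50,
    'width': 512,
    'height': 512,
    'cfg_scale': 7.5,
    'sampler_name': 'k_lms',
    'iterations': 1,
    'variation_amount': 0,
    'with_variations': '',
}
req = Request('http://localhost:9090/api/jobs',
              data=json.dumps(payload).encode('utf-8'),
              headers={'Content-Type': 'application/json'})
job = json.loads(urlopen(req).read())
print(job['jobId'])                # progress and results arrive in the socket.io room named after this id

# Previously generated images can be paged through without submitting a job:
listing = json.loads(urlopen('http://localhost:9090/api/images?page=0&per_page=10').read())
print(listing['items'])            # dream ids; fetch api/images/<id> and api/images/<id>/metadata for each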

.gitignore vendored

@@ -191,3 +191,4 @@ checkpoints
.scratch/ .scratch/
.vscode/ .vscode/
gfpgan/ gfpgan/
models/ldm/stable-diffusion-v1/model.sha256


@@ -51,6 +51,7 @@ We thank them for all of their time and hard work.
- [Any Winter](https://github.com/any-winter-4079) - [Any Winter](https://github.com/any-winter-4079)
- [Doggettx](https://github.com/doggettx) - [Doggettx](https://github.com/doggettx)
- [Matthias Wild](https://github.com/mauwii) - [Matthias Wild](https://github.com/mauwii)
- [Kyle Schouviller](https://github.com/kyle0654)
## __Original CompVis Authors:__ ## __Original CompVis Authors:__


@@ -33,10 +33,11 @@ class PngWriter:
# saves image named _image_ to outdir/name, writing metadata from prompt # saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output # returns full path of output
def save_image_and_prompt_to_png(self, image, dream_prompt, metadata, name): def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None):
path = os.path.join(self.outdir, name) path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo() info = PngImagePlugin.PngInfo()
info.add_text('Dream', dream_prompt) info.add_text('Dream', dream_prompt)
if metadata: # TODO: merge command line app's method of writing metadata and always just write metadata
info.add_text('sd-metadata', json.dumps(metadata)) info.add_text('sd-metadata', json.dumps(metadata))
image.save(path, 'PNG', pnginfo=info) image.save(path, 'PNG', pnginfo=info)
return path return path
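
Since the new signature is save_image_and_prompt_to_png(image, dream_prompt, name, metadata=None), callers that previously passed metadata as the third positional argument now need keywords. A minimal sketch of the new call shape (the output directory and file name are illustrative):

import os
from PIL import Image
from ldm.dream.pngwriter import PngWriter

os.makedirs('outputs', exist_ok=True)
writer = PngWriter('outputs')
image = Image.new('RGB', (64, 64))              # stand-in for a generated sample
path = writer.save_image_and_prompt_to_png(
    image,
    dream_prompt='"a watercolor fox" -S42',     # stored in the 'Dream' text chunk
    name='000001.42.png',
    metadata={'seed': 42})                      # optional; written to 'sd-metadata' only when provided
print(path)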


@@ -230,7 +230,7 @@ class DreamServer(BaseHTTPRequestHandler):
image = self.model.sample_to_image(sample) image = self.model.sample_to_image(sample)
name = f'{prefix}.{opt.seed}.{step_index}.png' name = f'{prefix}.{opt.seed}.{step_index}.png'
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]' metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
path = step_writer.save_image_and_prompt_to_png(image, metadata, name) path = step_writer.save_image_and_prompt_to_png(image, dream_prompt=metadata, name=name)
step_index += 1 step_index += 1
self.wfile.write(bytes(json.dumps( self.wfile.write(bytes(json.dumps(
{'event': 'step', 'step': step + 1, 'url': path} {'event': 'step', 'step': step + 1, 'url': path}


@@ -181,7 +181,7 @@ class Generate:
for image, seed in results: for image, seed in results:
name = f'{prefix}.{seed}.png' name = f'{prefix}.{seed}.png'
path = pngwriter.save_image_and_prompt_to_png( path = pngwriter.save_image_and_prompt_to_png(
image, f'{prompt} -S{seed}', name) image, dream_prompt=f'{prompt} -S{seed}', name=name)
outputs.append([path, seed]) outputs.append([path, seed])
return outputs return outputs


@@ -22,6 +22,11 @@ test-tube
torch-fidelity torch-fidelity
torchmetrics torchmetrics
transformers transformers
flask==2.1.3
flask_socketio==5.3.0
flask_cors==3.0.10
dependency_injector==4.40.0
eventlet
git+https://github.com/openai/CLIP.git@main#egg=clip git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
git+https://github.com/lstein/GFPGAN@fix-dark-cast-images#egg=gfpgan git+https://github.com/lstein/GFPGAN@fix-dark-cast-images#egg=gfpgan


@@ -7,9 +7,10 @@ import os
import sys import sys
from flask import Flask from flask import Flask
from flask_cors import CORS from flask_cors import CORS
from flask_socketio import SocketIO, join_room, leave_room from flask_socketio import SocketIO
from omegaconf import OmegaConf from omegaconf import OmegaConf
from dependency_injector.wiring import inject, Provide from dependency_injector.wiring import inject, Provide
from ldm.dream.args import Args
from server import views from server import views
from server.containers import Container from server.containers import Container
from server.services import GeneratorService, SignalService from server.services import GeneratorService, SignalService
@@ -58,6 +59,8 @@ def run_app(config, host, port) -> Flask:
# TODO: Get storage root from config # TODO: Get storage root from config
app.add_url_rule('/api/images/<string:dreamId>', view_func=views.ApiImages.as_view('api_images', '../')) app.add_url_rule('/api/images/<string:dreamId>', view_func=views.ApiImages.as_view('api_images', '../'))
app.add_url_rule('/api/images/<string:dreamId>/metadata', view_func=views.ApiImagesMetadata.as_view('api_images_metadata', '../'))
app.add_url_rule('/api/images', view_func=views.ApiImagesList.as_view('api_images_list'))
app.add_url_rule('/api/intermediates/<string:dreamId>/<string:step>', view_func=views.ApiIntermediates.as_view('api_intermediates', '../')) app.add_url_rule('/api/intermediates/<string:dreamId>/<string:step>', view_func=views.ApiIntermediates.as_view('api_intermediates', '../'))
app.static_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../static/dream_web/')) app.static_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../static/dream_web/'))
@@ -79,30 +82,28 @@ def run_app(config, host, port) -> Flask:
def main(): def main():
"""Initialize command-line parsers and the diffusion model""" """Initialize command-line parsers and the diffusion model"""
from scripts.dream import create_argv_parser arg_parser = Args()
arg_parser = create_argv_parser()
opt = arg_parser.parse_args() opt = arg_parser.parse_args()
if opt.laion400m: if opt.laion400m:
print('--laion400m flag has been deprecated. Please use --model laion400m instead.') print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1) sys.exit(-1)
if opt.weights != 'model': if opt.weights:
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.') print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
sys.exit(-1) sys.exit(-1)
try: # try:
models = OmegaConf.load(opt.config) # models = OmegaConf.load(opt.config)
width = models[opt.model].width # width = models[opt.model].width
height = models[opt.model].height # height = models[opt.model].height
config = models[opt.model].config # config = models[opt.model].config
weights = models[opt.model].weights # weights = models[opt.model].weights
except (FileNotFoundError, IOError, KeyError) as e: # except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.') # print(f'{e}. Aborting.')
sys.exit(-1) # sys.exit(-1)
print('* Initializing, be patient...\n') #print('* Initializing, be patient...\n')
sys.path.append('.') sys.path.append('.')
from pytorch_lightning import logging
# these two lines prevent a horrible warning message from appearing # these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported # when the frozen CLIP tokenizer is imported
@@ -110,26 +111,28 @@ def main():
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
appConfig = { appConfig = opt.__dict__
"model": {
"width": width, # appConfig = {
"height": height, # "model": {
"sampler_name": opt.sampler_name, # "width": width,
"weights": weights, # "height": height,
"full_precision": opt.full_precision, # "sampler_name": opt.sampler_name,
"config": config, # "weights": weights,
"grid": opt.grid, # "full_precision": opt.full_precision,
"latent_diffusion_weights": opt.laion400m, # "config": config,
"embedding_path": opt.embedding_path, # "grid": opt.grid,
"device_type": opt.device # "latent_diffusion_weights": opt.laion400m,
} # "embedding_path": opt.embedding_path
} # }
# }
# make sure the output directory exists # make sure the output directory exists
if not os.path.exists(opt.outdir): if not os.path.exists(opt.outdir):
os.makedirs(opt.outdir) os.makedirs(opt.outdir)
# gets rid of annoying messages about random seed # gets rid of annoying messages about random seed
from pytorch_lightning import logging
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
print('\n* starting api server...') print('\n* starting api server...')


@@ -17,18 +17,24 @@ class Container(containers.DeclarativeContainer):
app = None app = None
) )
# TODO: Add a model provider service that provides model(s) dynamically
model_singleton = providers.ThreadSafeSingleton( model_singleton = providers.ThreadSafeSingleton(
Generate, Generate,
width = config.model.width, model = config.model,
height = config.model.height, sampler_name = config.sampler_name,
sampler_name = config.model.sampler_name, embedding_path = config.embedding_path,
weights = config.model.weights, full_precision = config.full_precision
full_precision = config.model.full_precision, # config = config.model.config,
config = config.model.config,
grid = config.model.grid, # width = config.model.width,
seamless = config.model.seamless, # height = config.model.height,
embedding_path = config.model.embedding_path, # sampler_name = config.model.sampler_name,
device_type = config.model.device_type # weights = config.model.weights,
# full_precision = config.model.full_precision,
# grid = config.model.grid,
# seamless = config.model.seamless,
# embedding_path = config.model.embedding_path,
# device_type = config.model.device_type
) )
# TODO: get location from config # TODO: get location from config
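
Because appConfig is now simply opt.__dict__, the container reads flat keys (model, sampler_name, embedding_path, full_precision) instead of a nested model block. A minimal sketch of how the provider resolves, using a stand-in class in place of Generate so it runs without model weights:

from dependency_injector import containers, providers

class FakeGenerate:                              # stand-in for ldm.generate.Generate
    def __init__(self, model, sampler_name, embedding_path, full_precision):
        self.model = model
        self.sampler_name = sampler_name

class SketchContainer(containers.DeclarativeContainer):
    config = providers.Configuration()
    model_singleton = providers.ThreadSafeSingleton(
        FakeGenerate,
        model=config.model,
        sampler_name=config.sampler_name,
        embedding_path=config.embedding_path,
        full_precision=config.full_precision,
    )

c = SketchContainer()
c.config.from_dict({'model': 'stable-diffusion-1.4', 'sampler_name': 'k_lms',
                    'embedding_path': None, 'full_precision': False})
assert c.model_singleton() is c.model_singleton()   # same instance on every resolution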


@@ -1,77 +1,182 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from base64 import urlsafe_b64encode
import json import json
import string import string
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Union
from uuid import uuid4
class DreamRequest():
prompt: string class DreamBase():
initimg: string # Id
strength: float id: str
iterations: int
steps: int # Initial Image
width: int enable_init_image: bool
height: int initimg: string = None
fit = None
cfgscale: float # Img2Img
sampler_name: string enable_img2img: bool # TODO: support this better
gfpgan_strength: float strength: float = 0 # TODO: name this something related to img2img to make it clearer?
upscale_level: int fit = None # Fit initial image dimensions
upscale_strength: float
# Generation
enable_generate: bool
prompt: string = ""
seed: int = 0 # 0 is random
steps: int = 10
width: int = 512
height: int = 512
cfg_scale: float = 7.5
sampler_name: string = 'klms'
seamless: bool = False
model: str = None # The model to use (currently unused)
embeddings = None # The embeddings to use (currently unused)
progress_images: bool = False
# GFPGAN
enable_gfpgan: bool
gfpgan_strength: float = 0
# Upscale
enable_upscale: bool
upscale: None upscale: None
progress_images = None upscale_level: int = None
seed: int upscale_strength: float = 0.75
# Embiggen
enable_embiggen: bool
embiggen: Union[None, List[float]] = None
embiggen_tiles: Union[None, List[int]] = None
# Metadata
time: int time: int
def __init__(self):
self.id = urlsafe_b64encode(uuid4().bytes).decode('ascii')
def parse_json(self, j, new_instance=False):
# Id
if 'id' in j and not new_instance:
self.id = j.get('id')
# Initial Image
self.enable_init_image = 'enable_init_image' in j and bool(j.get('enable_init_image'))
if self.enable_init_image:
self.initimg = j.get('initimg')
# Img2Img
self.enable_img2img = 'enable_img2img' in j and bool(j.get('enable_img2img'))
if self.enable_img2img:
self.strength = float(j.get('strength'))
self.fit = 'fit' in j
# Generation
self.enable_generate = 'enable_generate' in j and bool(j.get('enable_generate'))
if self.enable_generate:
self.prompt = j.get('prompt')
self.seed = int(j.get('seed'))
self.steps = int(j.get('steps'))
self.width = int(j.get('width'))
self.height = int(j.get('height'))
self.cfg_scale = float(j.get('cfgscale') or j.get('cfg_scale'))
self.sampler_name = j.get('sampler') or j.get('sampler_name')
# model: str = None # The model to use (currently unused)
# embeddings = None # The embeddings to use (currently unused)
self.seamless = 'seamless' in j
self.progress_images = 'progress_images' in j
# GFPGAN
self.enable_gfpgan = 'enable_gfpgan' in j and bool(j.get('enable_gfpgan'))
if self.enable_gfpgan:
self.gfpgan_strength = float(j.get('gfpgan_strength'))
# Upscale
self.enable_upscale = 'enable_upscale' in j and bool(j.get('enable_upscale'))
if self.enable_upscale:
self.upscale_level = j.get('upscale_level')
self.upscale_strength = j.get('upscale_strength')
self.upscale = None if self.upscale_level in {None,''} else [int(self.upscale_level),float(self.upscale_strength)]
# Embiggen
self.enable_embiggen = 'enable_embiggen' in j and bool(j.get('enable_embiggen'))
if self.enable_embiggen:
self.embiggen = j.get('embiggen')
self.embiggen_tiles = j.get('embiggen_tiles')
# Metadata
self.time = int(j.get('time')) if ('time' in j and not new_instance) else int(datetime.now(timezone.utc).timestamp())
class DreamResult(DreamBase):
# Result
has_upscaled: False
has_gfpgan: False
# TODO: use something else for state tracking # TODO: use something else for state tracking
images_generated: int = 0 images_generated: int = 0
images_upscaled: int = 0 images_upscaled: int = 0
def id(self, seed = None, upscaled = False) -> str: def __init__(self):
return f"{self.time}.{seed or self.seed}{'.u' if upscaled else ''}" super().__init__()
# TODO: handle this more cleanly (probably by splitting this into a Job and Result class) def clone_without_img(self):
# TODO: Set iterations to 1 or remove it from the dream result? And just keep it on the job? copy = deepcopy(self)
def clone_without_image(self, seed = None): copy.initimg = None
data = deepcopy(self) return copy
data.initimg = None
if seed:
data.seed = seed
return data def to_json(self):
copy = deepcopy(self)
def to_json(self, seed: int = None): copy.initimg = None
copy = self.clone_without_image(seed) j = json.dumps(copy.__dict__)
return json.dumps(copy.__dict__) return j
@staticmethod @staticmethod
def from_json(j, newTime: bool = False): def from_json(j, newTime: bool = False):
d = DreamRequest() d = DreamResult()
d.prompt = j.get('prompt') d.parse_json(j)
d.initimg = j.get('initimg')
d.strength = float(j.get('strength'))
d.iterations = int(j.get('iterations'))
d.steps = int(j.get('steps'))
d.width = int(j.get('width'))
d.height = int(j.get('height'))
d.fit = 'fit' in j
d.seamless = 'seamless' in j
d.cfgscale = float(j.get('cfgscale'))
d.sampler_name = j.get('sampler')
d.variation_amount = float(j.get('variation_amount'))
d.with_variations = j.get('with_variations')
d.gfpgan_strength = float(j.get('gfpgan_strength'))
d.upscale_level = j.get('upscale_level')
d.upscale_strength = j.get('upscale_strength')
d.upscale = [int(d.upscale_level),float(d.upscale_strength)] if d.upscale_level != '' else None
d.progress_images = 'progress_images' in j
d.seed = int(j.get('seed'))
d.time = int(datetime.now(timezone.utc).timestamp()) if newTime else int(j.get('time'))
return d return d
# TODO: switch this to a pipelined request, with pluggable steps
# Will likely require generator code changes to accomplish
class JobRequest(DreamBase):
# Iteration
iterations: int = 1
variation_amount = None
with_variations = None
# Results
results: List[DreamResult] = []
def __init__(self):
super().__init__()
def newDreamResult(self) -> DreamResult:
result = DreamResult()
result.parse_json(self.__dict__, new_instance=True)
return result
@staticmethod
def from_json(j):
job = JobRequest()
job.parse_json(j)
# Metadata
job.time = int(j.get('time')) if ('time' in j) else int(datetime.now(timezone.utc).timestamp())
# Iteration
if job.enable_generate:
job.iterations = int(j.get('iterations'))
job.variation_amount = float(j.get('variation_amount'))
job.with_variations = j.get('with_variations')
return job
class ProgressType(Enum): class ProgressType(Enum):
GENERATION = 1 GENERATION = 1
UPSCALING_STARTED = 2 UPSCALING_STARTED = 2
@@ -102,11 +207,11 @@ class Signal():
# TODO: use a result id or something? Like a sub-job # TODO: use a result id or something? Like a sub-job
@staticmethod @staticmethod
def image_result(jobId: str, dreamId: str, dreamRequest: DreamRequest): def image_result(jobId: str, dreamId: str, dreamResult: DreamResult):
return Signal('dream_result', { return Signal('dream_result', {
'jobId': jobId, 'jobId': jobId,
'dreamId': dreamId, 'dreamId': dreamId,
'dreamRequest': dreamRequest.__dict__ 'dreamRequest': dreamResult.clone_without_img().__dict__
}, room=jobId, broadcast=True) }, room=jobId, broadcast=True)
@staticmethod @staticmethod
@@ -126,3 +231,21 @@ class Signal():
return Signal('job_canceled', { return Signal('job_canceled', {
'jobId': jobId 'jobId': jobId
}, room=jobId, broadcast=True) }, room=jobId, broadcast=True)
class PaginatedItems():
items: List[Any]
page: int # Current Page
pages: int # Total number of pages
per_page: int # Number of items per page
total: int # Total number of items in result
def __init__(self, items: List[Any], page: int, pages: int, per_page: int, total: int):
self.items = items
self.page = page
self.pages = pages
self.per_page = per_page
self.total = total
def to_json(self):
return json.dumps(self.__dict__)
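
A minimal sketch of how these classes are meant to flow together (request JSON in, one DreamResult per generated image out); the field values are illustrative:

from server.models import JobRequest

request_json = {
    'enable_generate': True,
    'prompt': 'a watercolor fox',
    'seed': 0, 'steps': 50, 'width': 512, 'height': 512,
    'cfg_scale': 7.5, 'sampler_name': 'k_lms',
    'iterations': 2, 'variation_amount': 0, 'with_variations': '',
}

job = JobRequest.from_json(request_json)    # gets a fresh url-safe id and a UTC timestamp
result = job.newDreamResult()               # copies the job's settings under a new id
result.seed = 1234                          # the generator fills in the actual seed per image
print(result.to_json())                     # this JSON is what ends up in the PNG 'Dream' chunk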


@@ -1,25 +1,33 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from argparse import ArgumentParser
import base64 import base64
from datetime import datetime, timezone
import glob
import json
import os import os
from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
import shlex
from threading import Thread from threading import Thread
import time import time
from flask import app, url_for
from flask_socketio import SocketIO, join_room, leave_room from flask_socketio import SocketIO, join_room, leave_room
from ldm.dream.args import Args
from ldm.dream.generator import embiggen
from PIL import Image
from ldm.dream.pngwriter import PngWriter from ldm.dream.pngwriter import PngWriter
from ldm.dream.server import CanceledException from ldm.dream.server import CanceledException
from ldm.generate import Generate from ldm.generate import Generate
from server.models import DreamRequest, ProgressType, Signal from server.models import DreamResult, JobRequest, PaginatedItems, ProgressType, Signal
class JobQueueService: class JobQueueService:
__queue: Queue = Queue() __queue: Queue = Queue()
def push(self, dreamRequest: DreamRequest): def push(self, dreamRequest: DreamResult):
self.__queue.put(dreamRequest) self.__queue.put(dreamRequest)
def get(self, timeout: float = None) -> DreamRequest: def get(self, timeout: float = None) -> DreamResult:
return self.__queue.get(timeout= timeout) return self.__queue.get(timeout= timeout)
class SignalQueueService: class SignalQueueService:
@@ -85,25 +93,28 @@ class LogService:
self.__location = location self.__location = location
self.__logFile = file self.__logFile = file
def log(self, dreamRequest: DreamRequest, seed = None, upscaled = False): def log(self, dreamResult: DreamResult, seed = None, upscaled = False):
with open(os.path.join(self.__location, self.__logFile), "a") as log: with open(os.path.join(self.__location, self.__logFile), "a") as log:
log.write(f"{dreamRequest.id(seed, upscaled)}: {dreamRequest.to_json(seed)}\n") log.write(f"{dreamResult.id}: {dreamResult.to_json()}\n")
class ImageStorageService: class ImageStorageService:
__location: str __location: str
__pngWriter: PngWriter __pngWriter: PngWriter
__legacyParser: ArgumentParser
def __init__(self, location): def __init__(self, location):
self.__location = location self.__location = location
self.__pngWriter = PngWriter(self.__location) self.__pngWriter = PngWriter(self.__location)
self.__legacyParser = Args() # TODO: inject this?
def __getName(self, dreamId: str, postfix: str = '') -> str: def __getName(self, dreamId: str, postfix: str = '') -> str:
return f'{dreamId}{postfix}.png' return f'{dreamId}{postfix}.png'
def save(self, image, dreamRequest, seed = None, upscaled = False, postfix: str = '', metadataPostfix: str = '') -> str: def save(self, image, dreamResult: DreamResult, postfix: str = '') -> str:
name = self.__getName(dreamRequest.id(seed, upscaled), postfix) name = self.__getName(dreamResult.id, postfix)
path = self.__pngWriter.save_image_and_prompt_to_png(image, f'{dreamRequest.prompt} -S{seed or dreamRequest.seed}{metadataPostfix}', name) meta = dreamResult.to_json() # TODO: make all methods consistent with writing metadata. Standardize metadata.
path = self.__pngWriter.save_image_and_prompt_to_png(image, dream_prompt=meta, metadata=None, name=name)
return path return path
def path(self, dreamId: str, postfix: str = '') -> str: def path(self, dreamId: str, postfix: str = '') -> str:
@@ -111,6 +122,88 @@ class ImageStorageService:
path = os.path.join(self.__location, name) path = os.path.join(self.__location, name)
return path return path
# Returns true if found, false if not found or error
def delete(self, dreamId: str, postfix: str = '') -> bool:
path = self.path(dreamId, postfix)
if (os.path.exists(path)):
os.remove(path)
return True
else:
return False
def getMetadata(self, dreamId: str, postfix: str = '') -> DreamResult:
path = self.path(dreamId, postfix)
image = Image.open(path)
text = image.text
if text.__contains__('Dream'):
dreamMeta = text.get('Dream')
try:
j = json.loads(dreamMeta)
return DreamResult.from_json(j)
except ValueError:
# Try to parse command-line format (legacy metadata format)
try:
opt = self.__parseLegacyMetadata(dreamMeta)
optd = opt.__dict__
if (not 'width' in optd) or (optd.get('width') is None):
optd['width'] = image.width
if (not 'height' in optd) or (optd.get('height') is None):
optd['height'] = image.height
if (not 'steps' in optd) or (optd.get('steps') is None):
optd['steps'] = 10 # No way around this unfortunately - seems like it wasn't storing this previously
optd['time'] = os.path.getmtime(path) # Set timestamp manually (won't be exactly correct though)
return DreamResult.from_json(optd)
except:
return None
else:
return None
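
The 'Dream' text chunk that getMetadata() reads back can be inspected directly with Pillow; a minimal sketch (the path is illustrative):

from PIL import Image

im = Image.open('outputs/abc123.png')
print(im.text.get('Dream'))    # JSON for images written by this change, a dream command line for legacy images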
def __parseLegacyMetadata(self, command: str) -> DreamResult:
# before splitting, escape single quotes so as not to mess
# up the parser
command = command.replace("'", "\\'")
try:
elements = shlex.split(command)
except ValueError as e:
return None
# rearrange the arguments to mimic how it works in the Dream bot.
switches = ['']
switches_started = False
for el in elements:
if el[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(el)
else:
switches[0] += el
switches[0] += ' '
switches[0] = switches[0][: len(switches[0]) - 1]
try:
opt = self.__legacyParser.parse_cmd(switches)
return opt
except SystemExit:
return None
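
For a legacy chunk such as "a watercolor fox" -S42 -W512, the loop above folds the prompt tokens into element 0 and keeps everything from the first switch onward as separate arguments, giving ['a watercolor fox', '-S42', '-W512'] for Args.parse_cmd to consume. A condensed, runnable restatement of that shuffle:

import shlex

def rearrange(command: str):
    elements = shlex.split(command.replace("'", "\\'"))
    switches, started = [''], False
    for el in elements:
        if el[0] == '-' and not started:
            started = True
        if started:
            switches.append(el)
        else:
            switches[0] += el + ' '
    switches[0] = switches[0].rstrip()
    return switches

print(rearrange('"a watercolor fox" -S42 -W512'))
# ['a watercolor fox', '-S42', '-W512']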
def list_files(self, page: int, perPage: int) -> PaginatedItems:
files = sorted(glob.glob(os.path.join(self.__location,'*.png')), key=os.path.getmtime, reverse=True)
count = len(files)
startId = page * perPage
pageCount = int(count / perPage) + 1
endId = min(startId + perPage, count)
items = [] if startId >= count else files[startId:endId]
items = list(map(lambda f: Path(f).stem, items))
return PaginatedItems(items, page, pageCount, perPage, count)
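
Worked example of the paging arithmetic above: with 25 files and per_page=10, pageCount is int(25/10)+1 = 3, so page 0 covers indices 0-9, page 1 covers 10-19, and page 2 covers 20-24. Note that an exact multiple of per_page yields one trailing empty page, since the page count is count//per_page + 1. A quick check:

count, per_page = 25, 10
pages = int(count / per_page) + 1            # 3
for page in range(pages):
    start = page * per_page
    end = min(start + per_page, count)
    print(page, start, end)                  # 0 0 10 / 1 10 20 / 2 20 25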
class GeneratorService: class GeneratorService:
__model: Generate __model: Generate
@@ -144,13 +237,11 @@ class GeneratorService:
# TODO: Consider moving this to its own service if there's benefit in separating the generator # TODO: Consider moving this to its own service if there's benefit in separating the generator
def __process(self): def __process(self):
# preload the model # preload the model
# TODO: support multiple models
print('Preloading model') print('Preloading model')
tic = time.time() tic = time.time()
self.__model.load_model() self.__model.load_model()
print( print(f'>> model loaded in', '%4.2fs' % (time.time() - tic))
f'>> model loaded in', '%4.2fs' % (time.time() - tic)
)
print('Started generation queue processor') print('Started generation queue processor')
try: try:
@@ -162,103 +253,136 @@ class GeneratorService:
print('Generation queue processor stopped') print('Generation queue processor stopped')
def __start(self, dreamRequest: DreamRequest): def __on_start(self, jobRequest: JobRequest):
if dreamRequest.start_callback: self.__signal_service.emit(Signal.job_started(jobRequest.id))
dreamRequest.start_callback()
self.__signal_service.emit(Signal.job_started(dreamRequest.id()))
def __done(self, dreamRequest: DreamRequest, image, seed, upscaled=False): def __on_image_result(self, jobRequest: JobRequest, image, seed, upscaled=False):
self.__imageStorage.save(image, dreamRequest, seed, upscaled) dreamResult = jobRequest.newDreamResult()
dreamResult.seed = seed
dreamResult.has_upscaled = upscaled
dreamResult.iterations = 1
jobRequest.results.append(dreamResult)
# TODO: Separate status of GFPGAN?
self.__imageStorage.save(image, dreamResult)
# TODO: handle upscaling logic better (this is appending data to log, but only on first generation) # TODO: handle upscaling logic better (this is appending data to log, but only on first generation)
if not upscaled: if not upscaled:
self.__log.log(dreamRequest, seed, upscaled) self.__log.log(dreamResult)
self.__signal_service.emit(Signal.image_result(dreamRequest.id(), dreamRequest.id(seed, upscaled), dreamRequest.clone_without_image(seed))) # Send result signal
self.__signal_service.emit(Signal.image_result(jobRequest.id, dreamResult.id, dreamResult))
upscaling_requested = dreamRequest.upscale or dreamRequest.gfpgan_strength>0 upscaling_requested = dreamResult.enable_upscale or dreamResult.enable_gfpgan
if upscaled: # Report upscaling status
dreamRequest.images_upscaled += 1 # TODO: this is very coupled to logic inside the generator. Fix that.
else: if upscaling_requested and any(result.has_upscaled for result in jobRequest.results):
dreamRequest.images_generated +=1 progressType = ProgressType.UPSCALING_STARTED if len(jobRequest.results) < 2 * jobRequest.iterations else ProgressType.UPSCALING_DONE
if upscaling_requested: upscale_count = sum(1 for i in jobRequest.results if i.has_upscaled)
# action = None self.__signal_service.emit(Signal.image_progress(jobRequest.id, dreamResult.id, upscale_count, jobRequest.iterations, progressType))
if dreamRequest.images_generated >= dreamRequest.iterations:
progressType = ProgressType.UPSCALING_DONE
if dreamRequest.images_upscaled < dreamRequest.iterations:
progressType = ProgressType.UPSCALING_STARTED
self.__signal_service.emit(Signal.image_progress(dreamRequest.id(), dreamRequest.id(seed), dreamRequest.images_upscaled+1, dreamRequest.iterations, progressType))
def __progress(self, dreamRequest, sample, step): def __on_progress(self, jobRequest: JobRequest, sample, step):
if self.__cancellationRequested: if self.__cancellationRequested:
self.__cancellationRequested = False self.__cancellationRequested = False
raise CanceledException raise CanceledException
# TODO: Progress per request will be easier once the seeds (and ids) can all be pre-generated
hasProgressImage = False hasProgressImage = False
if dreamRequest.progress_images and step % 5 == 0 and step < dreamRequest.steps - 1: s = str(len(jobRequest.results))
if jobRequest.progress_images and step % 5 == 0 and step < jobRequest.steps - 1:
image = self.__model._sample_to_image(sample) image = self.__model._sample_to_image(sample)
self.__intermediateStorage.save(image, dreamRequest, self.__model.seed, postfix=f'.{step}', metadataPostfix=f' [intermediate]')
# TODO: clean this up, use a pre-defined dream result
result = DreamResult()
result.parse_json(jobRequest.__dict__, new_instance=False)
self.__intermediateStorage.save(image, result, postfix=f'.{s}.{step}')
hasProgressImage = True hasProgressImage = True
self.__signal_service.emit(Signal.image_progress(dreamRequest.id(), dreamRequest.id(self.__model.seed), step, dreamRequest.steps, ProgressType.GENERATION, hasProgressImage)) self.__signal_service.emit(Signal.image_progress(jobRequest.id, f'{jobRequest.id}.{s}', step, jobRequest.steps, ProgressType.GENERATION, hasProgressImage))
def __generate(self, dreamRequest: DreamRequest): def __generate(self, jobRequest: JobRequest):
try: try:
initimgfile = None # TODO: handle this file a file service for init images
if dreamRequest.initimg is not None: initimgfile = None # TODO: support this on the model directly?
if (jobRequest.enable_init_image):
if jobRequest.initimg is not None:
with open("./img2img-tmp.png", "wb") as f: with open("./img2img-tmp.png", "wb") as f:
initimg = dreamRequest.initimg.split(",")[1] # Ignore mime type initimg = jobRequest.initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg)) f.write(base64.b64decode(initimg))
initimgfile = "./img2img-tmp.png" initimgfile = "./img2img-tmp.png"
# Get a random seed if we don't have one yet # Use previous seed if set to -1
# TODO: handle "previous" seed usage? initSeed = jobRequest.seed
if dreamRequest.seed == -1: if initSeed == -1:
dreamRequest.seed = self.__model.seed initSeed = self.__model.seed
# Zero gfpgan strength if the model doesn't exist # Zero gfpgan strength if the model doesn't exist
# TODO: determine if this could be at the top now? Used to cause circular import # TODO: determine if this could be at the top now? Used to cause circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
if not gfpgan_model_exists: if not gfpgan_model_exists:
dreamRequest.gfpgan_strength = 0 jobRequest.enable_gfpgan = False
# Signal start
self.__on_start(jobRequest)
# Generate in model
# TODO: Split job generation requests instead of fitting all parameters here
# TODO: Support no generation (just upscaling/gfpgan)
upscale = None if not jobRequest.enable_upscale else jobRequest.upscale
gfpgan_strength = 0 if not jobRequest.enable_gfpgan else jobRequest.gfpgan_strength
if not jobRequest.enable_generate:
# If not generating, check if we're upscaling or running gfpgan
if not upscale and not gfpgan_strength:
# Invalid settings (TODO: Add message to help user)
raise CanceledException()
image = Image.open(initimgfile)
# TODO: support progress for upscale?
self.__model.upscale_and_reconstruct(
image_list = [[image,0]],
upscale = upscale,
strength = gfpgan_strength,
save_original = False,
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
else:
# Generating - run the generation
init_img = None if (not jobRequest.enable_img2img or jobRequest.strength == 0) else initimgfile
self.__start(dreamRequest)
self.__model.prompt2image( self.__model.prompt2image(
prompt = dreamRequest.prompt, prompt = jobRequest.prompt,
init_img = initimgfile, # TODO: ensure this works init_img = init_img, # TODO: ensure this works
strength = None if initimgfile is None else dreamRequest.strength, strength = None if init_img is None else jobRequest.strength,
fit = None if initimgfile is None else dreamRequest.fit, fit = None if init_img is None else jobRequest.fit,
iterations = dreamRequest.iterations, iterations = jobRequest.iterations,
cfg_scale = dreamRequest.cfgscale, cfg_scale = jobRequest.cfg_scale,
width = dreamRequest.width, width = jobRequest.width,
height = dreamRequest.height, height = jobRequest.height,
seed = dreamRequest.seed, seed = jobRequest.seed,
steps = dreamRequest.steps, steps = jobRequest.steps,
variation_amount = dreamRequest.variation_amount, variation_amount = jobRequest.variation_amount,
with_variations = dreamRequest.with_variations, with_variations = jobRequest.with_variations,
gfpgan_strength = dreamRequest.gfpgan_strength, gfpgan_strength = gfpgan_strength,
upscale = dreamRequest.upscale, upscale = upscale,
sampler_name = dreamRequest.sampler_name, sampler_name = jobRequest.sampler_name,
seamless = dreamRequest.seamless, seamless = jobRequest.seamless,
step_callback = lambda sample, step: self.__progress(dreamRequest, sample, step), embiggen = jobRequest.embiggen,
image_callback = lambda image, seed, upscaled=False: self.__done(dreamRequest, image, seed, upscaled)) embiggen_tiles = jobRequest.embiggen_tiles,
step_callback = lambda sample, step: self.__on_progress(jobRequest, sample, step),
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
except CanceledException: except CanceledException:
if dreamRequest.cancelled_callback: self.__signal_service.emit(Signal.job_canceled(jobRequest.id))
dreamRequest.cancelled_callback()
self.__signal_service.emit(Signal.job_canceled(dreamRequest.id()))
finally: finally:
if dreamRequest.done_callback: self.__signal_service.emit(Signal.job_done(jobRequest.id))
dreamRequest.done_callback()
self.__signal_service.emit(Signal.job_done(dreamRequest.id()))
# Remove the temp file # Remove the temp file
if (initimgfile is not None): if (initimgfile is not None):


@@ -8,7 +8,7 @@ from flask import current_app, jsonify, request, Response, send_from_directory,
from flask.views import MethodView from flask.views import MethodView
from dependency_injector.wiring import inject, Provide from dependency_injector.wiring import inject, Provide
from server.models import DreamRequest from server.models import DreamResult, JobRequest
from server.services import GeneratorService, ImageStorageService, JobQueueService from server.services import GeneratorService, ImageStorageService, JobQueueService
from server.containers import Container from server.containers import Container
@@ -16,23 +16,14 @@ class ApiJobs(MethodView):
@inject @inject
def post(self, job_queue_service: JobQueueService = Provide[Container.generation_queue_service]): def post(self, job_queue_service: JobQueueService = Provide[Container.generation_queue_service]):
dreamRequest = DreamRequest.from_json(request.json, newTime = True) jobRequest = JobRequest.from_json(request.json)
#self.canceled.clear() print(f">> Request to generate with prompt: {jobRequest.prompt}")
print(f">> Request to generate with prompt: {dreamRequest.prompt}")
q = Queue()
dreamRequest.start_callback = None
dreamRequest.image_callback = None
dreamRequest.progress_callback = None
dreamRequest.cancelled_callback = None
dreamRequest.done_callback = None
# Push the request # Push the request
job_queue_service.push(dreamRequest) job_queue_service.push(jobRequest)
return { 'dreamId': dreamRequest.id() } return { 'jobId': jobRequest.id }
class WebIndex(MethodView): class WebIndex(MethodView):
@@ -68,6 +59,7 @@ class ApiCancel(MethodView):
return Response(status=204) return Response(status=204)
# TODO: Combine all image storage access
class ApiImages(MethodView): class ApiImages(MethodView):
init_every_request = False init_every_request = False
__pathRoot = None __pathRoot = None
@@ -83,6 +75,27 @@ class ApiImages(MethodView):
fullpath=os.path.join(self.__pathRoot, name) fullpath=os.path.join(self.__pathRoot, name)
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath)) return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))
def delete(self, dreamId):
result = self.__storage.delete(dreamId)
return Response(status=204) if result else Response(status=404)
class ApiImagesMetadata(MethodView):
init_every_request = False
__pathRoot = None
__storage: ImageStorageService
@inject
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_storage_service]):
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
self.__storage = storage
def get(self, dreamId):
meta = self.__storage.getMetadata(dreamId)
j = {} if meta is None else meta.__dict__
return j
class ApiIntermediates(MethodView): class ApiIntermediates(MethodView):
init_every_request = False init_every_request = False
__pathRoot = None __pathRoot = None
@@ -97,3 +110,23 @@ class ApiIntermediates(MethodView):
name = self.__storage.path(dreamId, postfix=f'.{step}') name = self.__storage.path(dreamId, postfix=f'.{step}')
fullpath=os.path.join(self.__pathRoot, name) fullpath=os.path.join(self.__pathRoot, name)
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath)) return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))
def delete(self, dreamId):
result = self.__storage.delete(dreamId)
return Response(status=204) if result else Response(status=404)
class ApiImagesList(MethodView):
init_every_request = False
__storage: ImageStorageService
@inject
def __init__(self, storage: ImageStorageService = Provide[Container.image_storage_service]):
self.__storage = storage
def get(self):
page = request.args.get("page", default=0, type=int)
perPage = request.args.get("per_page", default=10, type=int)
result = self.__storage.list_files(page, perPage)
return result.__dict__
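
A minimal sketch of the new metadata and delete routes, again with the standard library only (host, port, and the dream id are illustrative):

import json
from urllib.error import HTTPError
from urllib.request import Request, urlopen

base = 'http://localhost:9090'
meta = json.loads(urlopen(f'{base}/api/images/abc123/metadata').read())
print(meta.get('prompt'))        # an empty object comes back when no metadata could be recovered

try:
    urlopen(Request(f'{base}/api/images/abc123', method='DELETE'))    # 204 on success
except HTTPError as err:
    print(err.code)              # 404 when the dream id does not exist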


@@ -1,3 +1,8 @@
:root {
--fields-dark:#DCDCDC;
--fields-light:#F5F5F5;
}
* { * {
font-family: 'Arial'; font-family: 'Arial';
font-size: 100%; font-size: 100%;
@@ -18,15 +23,26 @@ fieldset {
border: none; border: none;
line-height: 2.2em; line-height: 2.2em;
} }
fieldset > legend {
width: auto;
margin-left: 0;
margin-right: auto;
font-weight:bold;
}
select, input { select, input {
margin-right: 10px; margin-right: 10px;
padding: 2px; padding: 2px;
} }
input:disabled {
cursor:auto;
}
input[type=submit] { input[type=submit] {
cursor: pointer;
background-color: #666; background-color: #666;
color: white; color: white;
} }
input[type=checkbox] { input[type=checkbox] {
cursor: pointer;
margin-right: 0px; margin-right: 0px;
width: 20px; width: 20px;
height: 20px; height: 20px;
@@ -87,11 +103,11 @@ header h1 {
} }
#results img { #results img {
border-radius: 5px; border-radius: 5px;
object-fit: cover; object-fit: contain;
background-color: var(--fields-dark);
} }
#fieldset-config { #fieldset-config {
line-height:2em; line-height:2em;
background-color: #F0F0F0;
} }
input[type="number"] { input[type="number"] {
width: 60px; width: 60px;
@@ -118,35 +134,46 @@ label {
#progress-image { #progress-image {
width: 30vh; width: 30vh;
height: 30vh; height: 30vh;
object-fit: contain;
background-color: var(--fields-dark);
} }
#cancel-button { #cancel-button {
cursor: pointer; cursor: pointer;
color: red; color: red;
} }
#basic-parameters {
background-color: #EEEEEE;
}
#txt2img { #txt2img {
background-color: #DCDCDC; background-color: var(--fields-dark);
} }
#variations { #variations {
background-color: #EEEEEE; background-color: var(--fields-light);
}
#initimg {
background-color: var(--fields-dark);
} }
#img2img { #img2img {
background-color: #DCDCDC; background-color: var(--fields-light);
} }
#gfpgan { #initimg > :not(legend) {
background-color: #EEEEEE; background-color: var(--fields-light);
margin: .5em;
}
#postprocess, #initimg {
display:flex;
flex-wrap:wrap;
padding: 0;
margin-top: 1em;
background-color: var(--fields-dark);
}
#postprocess > fieldset, #initimg > * {
flex-grow: 1;
}
#postprocess > fieldset {
background-color: var(--fields-dark);
} }
#progress-section { #progress-section {
background-color: #F5F5F5; background-color: var(--fields-light);
}
.section-header {
text-align: left;
font-weight: bold;
padding: 0 0 0 0;
} }
#no-results-message:not(:only-child) { #no-results-message:not(:only-child) {
display: none; display: none;
} }


@@ -1,41 +1,50 @@
<html lang="en"> <html lang="en">
<head>
<head>
<title>Stable Diffusion Dream Server</title> <title>Stable Diffusion Dream Server</title>
<meta charset="utf-8"> <meta charset="utf-8">
<link rel="icon" href="data:,"> <link rel="icon" type="image/x-icon" href="static/dream_web/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="index.css">
<script src="config.js"></script> <script src="config.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js" integrity="sha512-q/dWJ3kcmjBLU4Qc47E4A9kTB4m3wuTY7vkFJDTZKjTs8jhyGQnaUrxa0Ytd0ssMZhbNua9hE+E7Qv1j+DyZwA==" crossorigin="anonymous"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"
integrity="sha512-q/dWJ3kcmjBLU4Qc47E4A9kTB4m3wuTY7vkFJDTZKjTs8jhyGQnaUrxa0Ytd0ssMZhbNua9hE+E7Qv1j+DyZwA=="
crossorigin="anonymous"></script>
<link rel="stylesheet" href="index.css">
<script src="index.js"></script> <script src="index.js"></script>
</head> </head>
<body>
<body>
<header> <header>
<h1>Stable Diffusion Dream Server</h1> <h1>Stable Diffusion Dream Server</h1>
<div id="about"> <div id="about">
For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a> For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub
site</a>
</div> </div>
</header> </header>
<main> <main>
<!--
<div id="dropper" style="background-color:red;width:200px;height:200px;">
</div>
-->
<form id="generate-form" method="post" action="api/jobs"> <form id="generate-form" method="post" action="api/jobs">
<fieldset id="txt2img"> <fieldset id="txt2img">
<legend>
<input type="checkbox" name="enable_generate" id="enable_generate" checked>
<label for="enable_generate">Generate</label>
</legend>
<div id="search-box"> <div id="search-box">
<textarea rows="3" id="prompt" name="prompt"></textarea> <textarea rows="3" id="prompt" name="prompt"></textarea>
<input type="submit" id="submit" value="Generate">
</div> </div>
</fieldset>
<fieldset id="fieldset-config">
<div class="section-header">Basic options</div>
<label for="iterations">Images to generate:</label> <label for="iterations">Images to generate:</label>
<input value="1" type="number" id="iterations" name="iterations" size="4"> <input value="1" type="number" id="iterations" name="iterations" size="4">
<label for="steps">Steps:</label> <label for="steps">Steps:</label>
<input value="50" type="number" id="steps" name="steps"> <input value="50" type="number" id="steps" name="steps">
<label for="cfgscale">Cfg Scale:</label> <label for="cfg_scale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfgscale" name="cfgscale" step="any"> <input value="7.5" type="number" id="cfg_scale" name="cfg_scale" step="any">
<label for="sampler">Sampler:</label> <label for="sampler_name">Sampler:</label>
<select id="sampler" name="sampler" value="k_lms"> <select id="sampler_name" name="sampler_name" value="k_lms">
<option value="ddim">DDIM</option> <option value="ddim">DDIM</option>
<option value="plms">PLMS</option> <option value="plms">PLMS</option>
<option value="k_lms" selected>KLMS</option> <option value="k_lms" selected>KLMS</option>
@@ -50,25 +59,41 @@
<br> <br>
<label title="Set to multiple of 64" for="width">Width:</label> <label title="Set to multiple of 64" for="width">Width:</label>
<select id="width" name="width" value="512"> <select id="width" name="width" value="512">
<option value="64">64</option> <option value="128">128</option> <option value="64">64</option>
<option value="192">192</option> <option value="256">256</option> <option value="128">128</option>
<option value="320">320</option> <option value="384">384</option> <option value="192">192</option>
<option value="448">448</option> <option value="512" selected>512</option> <option value="256">256</option>
<option value="576">576</option> <option value="640">640</option> <option value="320">320</option>
<option value="704">704</option> <option value="768">768</option> <option value="384">384</option>
<option value="832">832</option> <option value="896">896</option> <option value="448">448</option>
<option value="960">960</option> <option value="1024">1024</option> <option value="512" selected>512</option>
<option value="576">576</option>
<option value="640">640</option>
<option value="704">704</option>
<option value="768">768</option>
<option value="832">832</option>
<option value="896">896</option>
<option value="960">960</option>
<option value="1024">1024</option>
</select> </select>
<label title="Set to multiple of 64" for="height">Height:</label> <label title="Set to multiple of 64" for="height">Height:</label>
<select id="height" name="height" value="512"> <select id="height" name="height" value="512">
<option value="64">64</option> <option value="128">128</option> <option value="64">64</option>
<option value="192">192</option> <option value="256">256</option> <option value="128">128</option>
<option value="320">320</option> <option value="384">384</option> <option value="192">192</option>
<option value="448">448</option> <option value="512" selected>512</option> <option value="256">256</option>
<option value="576">576</option> <option value="640">640</option> <option value="320">320</option>
<option value="704">704</option> <option value="768">768</option> <option value="384">384</option>
<option value="832">832</option> <option value="896">896</option> <option value="448">448</option>
<option value="960">960</option> <option value="1024">1024</option> <option value="512" selected>512</option>
<option value="576">576</option>
<option value="640">640</option>
<option value="704">704</option>
<option value="768">768</option>
<option value="832">832</option>
<option value="896">896</option>
<option value="960">960</option>
<option value="1024">1024</option>
</select> </select>
<label title="Set to 0 for random seed" for="seed">Seed:</label> <label title="Set to 0 for random seed" for="seed">Seed:</label>
<input value="0" type="number" id="seed" name="seed"> <input value="0" type="number" id="seed" name="seed">
@@ -76,29 +101,52 @@
<input type="checkbox" name="progress_images" id="progress_images"> <input type="checkbox" name="progress_images" id="progress_images">
<label for="progress_images">Display in-progress images (slower)</label> <label for="progress_images">Display in-progress images (slower)</label>
<button type="button" id="reset-all">Reset to Defaults</button> <button type="button" id="reset-all">Reset to Defaults</button>
<span id="variations"> <div id="variations">
<label title="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different." for="variation_amount">Variation amount (0 to disable):</label> <label
title="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different."
for="variation_amount">Variation amount (0 to disable):</label>
<input value="0" type="number" id="variation_amount" name="variation_amount" step="0.01" min="0" max="1"> <input value="0" type="number" id="variation_amount" name="variation_amount" step="0.01" min="0" max="1">
<label title="list of variations to apply, in the format `seed:weight,seed:weight,..." for="with_variations">With variations (seed:weight,seed:weight,...):</label> <label title="list of variations to apply, in the format `seed:weight,seed:weight,..."
for="with_variations">With variations (seed:weight,seed:weight,...):</label>
<input value="" type="text" id="with_variations" name="with_variations"> <input value="" type="text" id="with_variations" name="with_variations">
</span> </div>
</fieldset> </fieldset>
<fieldset id="img2img"> <fieldset id="initimg">
<div class="section-header">Image-to-image options</div> <legend>
<input type="checkbox" name="enable_init_image" id="enable_init_image" checked>
<label for="enable_init_image">Enable init image</label>
</legend>
<div>
<label title="Upload an image to use img2img" for="initimg">Initial image:</label> <label title="Upload an image to use img2img" for="initimg">Initial image:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png"> <input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<button type="button" id="remove-image">Remove Image</button> <button type="button" id="remove-image">Remove Image</button>
<br> </div>
<fieldset id="img2img">
<legend>
<input type="checkbox" name="enable_img2img" id="enable_img2img" checked>
<label for="enable_img2img">Enable Img2Img</label>
</legend>
<label for="strength">Img2Img Strength:</label> <label for="strength">Img2Img Strength:</label>
<input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1"> <input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
<input type="checkbox" id="fit" name="fit" checked> <input type="checkbox" id="fit" name="fit" checked>
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height</label> <label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height:</label>
</fieldset> </fieldset>
</fieldset>
<div id="postprocess">
<fieldset id="gfpgan"> <fieldset id="gfpgan">
<div class="section-header">Post-processing options</div> <legend>
<label title="Strength of the gfpgan (face fixing) algorithm." for="gfpgan_strength">GPFGAN Strength (0 to disable):</label> <input type="checkbox" name="enable_gfpgan" id="enable_gfpgan">
<input value="0.0" min="0" max="1" type="number" id="gfpgan_strength" name="gfpgan_strength" step="0.1"> <label for="enable_gfpgan">Enable gfpgan</label>
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level</label> </legend>
<label title="Strength of the gfpgan (face fixing) algorithm." for="gfpgan_strength">GPFGAN Strength:</label>
<input value="0.8" min="0" max="1" type="number" id="gfpgan_strength" name="gfpgan_strength" step="0.05">
</fieldset>
<fieldset id="upscale">
<legend>
<input type="checkbox" name="enable_upscale" id="enable_upscale">
<label for="enable_upscale">Enable Upscaling</label>
</legend>
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level:</label>
<select id="upscale_level" name="upscale_level" value=""> <select id="upscale_level" name="upscale_level" value="">
<option value="" selected>None</option> <option value="" selected>None</option>
<option value="2">2x</option> <option value="2">2x</option>
@@ -107,6 +155,8 @@
<label title="Strength of the esrgan (upscaling) algorithm." for="upscale_strength">Upscale Strength:</label> <label title="Strength of the esrgan (upscaling) algorithm." for="upscale_strength">Upscale Strength:</label>
<input value="0.75" min="0" max="1" type="number" id="upscale_strength" name="upscale_strength" step="0.05"> <input value="0.75" min="0" max="1" type="number" id="upscale_strength" name="upscale_strength" step="0.05">
</fieldset> </fieldset>
</div>
<input type="submit" id="submit" value="Generate">
</form> </form>
<br> <br>
<section id="progress-section"> <section id="progress-section">
@@ -118,14 +168,12 @@
<div id="scaling-inprocess-message"> <div id="scaling-inprocess-message">
<i><span>Postprocessing...</span><span id="processing_cnt">1</span>/<span id="processing_total">3</span></i> <i><span>Postprocessing...</span><span id="processing_cnt">1</span>/<span id="processing_total">3</span></i>
</div> </div>
</span> </div>
</section> </section>
<div id="results"> <div id="results">
<div id="no-results-message">
<i><p>No results...</p></i>
</div>
</div> </div>
</main> </main>
</body> </body>
</html> </html>


@@ -1,5 +1,73 @@
const socket = io(); const socket = io();
var priorResultsLoadState = {
page: 0,
pages: 1,
per_page: 10,
total: 20,
offset: 0, // number of items generated since last load
loading: false,
initialized: false
};
function loadPriorResults() {
// Fix next page by offset
let offsetPages = priorResultsLoadState.offset / priorResultsLoadState.per_page;
priorResultsLoadState.page += offsetPages;
priorResultsLoadState.pages += offsetPages;
priorResultsLoadState.total += priorResultsLoadState.offset;
priorResultsLoadState.offset = 0;
if (priorResultsLoadState.loading) {
return;
}
if (priorResultsLoadState.page >= priorResultsLoadState.pages) {
return; // Nothing more to load
}
// Load
priorResultsLoadState.loading = true
let url = new URL('/api/images', document.baseURI);
url.searchParams.append('page', priorResultsLoadState.initialized ? priorResultsLoadState.page + 1 : priorResultsLoadState.page);
url.searchParams.append('per_page', priorResultsLoadState.per_page);
fetch(url.href, {
method: 'GET',
headers: new Headers({'content-type': 'application/json'})
})
.then(response => response.json())
.then(data => {
priorResultsLoadState.page = data.page;
priorResultsLoadState.pages = data.pages;
priorResultsLoadState.per_page = data.per_page;
priorResultsLoadState.total = data.total;
data.items.forEach(function(dreamId, index) {
let src = 'api/images/' + dreamId;
fetch('/api/images/' + dreamId + '/metadata', {
method: 'GET',
headers: new Headers({'content-type': 'application/json'})
})
.then(response => response.json())
.then(metadata => {
let seed = metadata.seed || 0; // TODO: Parse old metadata
appendOutput(src, seed, metadata, true);
});
});
// Load until page is full
if (!priorResultsLoadState.initialized) {
if (document.body.scrollHeight <= window.innerHeight) {
loadPriorResults();
}
}
})
.finally(() => {
priorResultsLoadState.loading = false;
priorResultsLoadState.initialized = true;
});
}
function resetForm() { function resetForm() {
var form = document.getElementById('generate-form'); var form = document.getElementById('generate-form');
form.querySelector('fieldset').removeAttribute('disabled'); form.querySelector('fieldset').removeAttribute('disabled');
@@ -45,48 +113,64 @@ function toBase64(file) {
}); });
} }
function appendOutput(src, seed, config) { function ondragdream(event) {
let outputNode = document.createElement("figure"); let dream = event.target.dataset.dream;
let altText = seed.toString() + " | " + config.prompt; event.dataTransfer.setData("dream", dream);
}
const figureContents = ` function seedClick(event) {
<a href="${src}" target="_blank"> // Get element
<img src="${src}" alt="${altText}" title="${altText}"> var image = event.target.closest('figure').querySelector('img');
</a> var dream = JSON.parse(decodeURIComponent(image.dataset.dream));
<figcaption>${seed}</figcaption>
`;
outputNode.innerHTML = figureContents;
let figcaption = outputNode.querySelector('figcaption')
// Reload image config
figcaption.addEventListener('click', () => {
let form = document.querySelector("#generate-form"); let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) { for (const [k, v] of new FormData(form)) {
if (k == 'initimg') { continue; } if (k == 'initimg') { continue; }
form.querySelector(`*[name=${k}]`).value = config[k]; let formElem = form.querySelector(`*[name=${k}]`);
} formElem.value = dream[k] !== undefined ? dream[k] : formElem.defaultValue;
if (config.variation_amount > 0 || config.with_variations != '') {
document.querySelector("#seed").value = config.seed;
} else {
document.querySelector("#seed").value = seed;
} }
if (config.variation_amount > 0) { document.querySelector("#seed").value = dream.seed;
let oldVarAmt = document.querySelector("#variation_amount").value document.querySelector('#iterations').value = 1; // Reset to 1 iteration since we clicked a single image (not a full job)
let oldVariations = document.querySelector("#with_variations").value
let varSep = '' // NOTE: leaving this manual for the user for now - it was very confusing with this behavior
document.querySelector("#variation_amount").value = 0; // document.querySelector("#with_variations").value = variations || '';
if (document.querySelector("#with_variations").value != '') { // if (document.querySelector("#variation_amount").value <= 0) {
varSep = "," // document.querySelector("#variation_amount").value = 0.2;
} // }
document.querySelector("#with_variations").value = oldVariations + varSep + seed + ':' + config.variation_amount
}
saveFields(document.querySelector("#generate-form")); saveFields(document.querySelector("#generate-form"));
}); }
function appendOutput(src, seed, config, toEnd=false) {
let outputNode = document.createElement("figure");
let altText = seed.toString() + " | " + config.prompt;
// img needs width and height for lazy loading to work
// TODO: store the full config in a data attribute on the image?
const figureContents = `
<a href="${src}" target="_blank">
<img src="${src}"
alt="${altText}"
title="${altText}"
loading="lazy"
width="256"
height="256"
draggable="true"
ondragstart="ondragdream(event, this)"
data-dream="${encodeURIComponent(JSON.stringify(config))}"
data-dreamId="${encodeURIComponent(config.dreamId)}">
</a>
<figcaption onclick="seedClick(event, this)">${seed}</figcaption>
`;
outputNode.innerHTML = figureContents;
if (toEnd) {
document.querySelector("#results").append(outputNode);
} else {
document.querySelector("#results").prepend(outputNode); document.querySelector("#results").prepend(outputNode);
}
document.querySelector("#no-results-message")?.remove(); document.querySelector("#no-results-message")?.remove();
} }
@@ -119,14 +203,33 @@ async function generateSubmit(form) {
// Convert file data to base64 // Convert file data to base64
// TODO: Should probably uplaod files with formdata or something, and store them in the backend? // TODO: Should probably uplaod files with formdata or something, and store them in the backend?
let formData = Object.fromEntries(new FormData(form)); let formData = Object.fromEntries(new FormData(form));
if (!formData.enable_generate && !formData.enable_init_image) {
gen_label = document.querySelector("label[for=enable_generate]").innerHTML;
initimg_label = document.querySelector("label[for=enable_init_image]").innerHTML;
alert(`Error: one of "${gen_label}" or "${initimg_label}" must be set`);
}
formData.initimg_name = formData.initimg.name formData.initimg_name = formData.initimg.name
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null; formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
// Evaluate all checkboxes
let checkboxes = form.querySelectorAll('input[type=checkbox]');
checkboxes.forEach(function (checkbox) {
if (checkbox.checked) {
formData[checkbox.name] = 'true';
}
});
let strength = formData.strength; let strength = formData.strength;
let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps; let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
let showProgressImages = formData.progress_images;
// Set enabling flags
// Initialize the progress bar // Initialize the progress bar
initProgress(totalSteps); initProgress(totalSteps, showProgressImages);
// POST, use response to listen for events // POST, use response to listen for events
fetch(form.action, { fetch(form.action, {
@@ -136,13 +239,19 @@ async function generateSubmit(form) {
}) })
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
var dreamId = data.dreamId; var jobId = data.jobId;
socket.emit('join_room', { 'room': dreamId }); socket.emit('join_room', { 'room': jobId });
}); });
form.querySelector('fieldset').setAttribute('disabled',''); form.querySelector('fieldset').setAttribute('disabled','');
} }
function fieldSetEnableChecked(event) {
cb = event.target;
fields = cb.closest('fieldset');
fields.disabled = !cb.checked;
}
// Socket listeners // Socket listeners
socket.on('job_started', (data) => {}) socket.on('job_started', (data) => {})
@@ -152,6 +261,7 @@ socket.on('dream_result', (data) => {
var dreamRequest = data.dreamRequest; var dreamRequest = data.dreamRequest;
var src = 'api/images/' + dreamId; var src = 'api/images/' + dreamId;
priorResultsLoadState.offset += 1;
appendOutput(src, dreamRequest.seed, dreamRequest); appendOutput(src, dreamRequest.seed, dreamRequest);
resetProgress(false); resetProgress(false);
@@ -193,7 +303,13 @@ socket.on('job_done', (data) => {
resetProgress(); resetProgress();
}) })
window.onload = () => { window.onload = async () => {
document.querySelector("#prompt").addEventListener("keydown", (e) => {
if (e.key === "Enter" && !e.shiftKey) {
const form = e.target.form;
generateSubmit(form);
}
});
document.querySelector("#generate-form").addEventListener('submit', (e) => { document.querySelector("#generate-form").addEventListener('submit', (e) => {
e.preventDefault(); e.preventDefault();
const form = e.target; const form = e.target;
@@ -216,12 +332,65 @@ window.onload = () => {
loadFields(document.querySelector("#generate-form")); loadFields(document.querySelector("#generate-form"));
document.querySelector('#cancel-button').addEventListener('click', () => { document.querySelector('#cancel-button').addEventListener('click', () => {
fetch('/cancel').catch(e => { fetch('/api/cancel').catch(e => {
console.error(e); console.error(e);
}); });
}); });
document.documentElement.addEventListener('keydown', (e) => {
if (e.key === "Escape")
fetch('/api/cancel').catch(err => {
console.error(err);
});
});
if (!config.gfpgan_model_exists) { if (!config.gfpgan_model_exists) {
document.querySelector("#gfpgan").style.display = 'none'; document.querySelector("#gfpgan").style.display = 'none';
} }
window.addEventListener("scroll", () => {
if ((window.innerHeight + window.pageYOffset) >= document.body.offsetHeight) {
loadPriorResults();
}
});
// Enable/disable forms by checkboxes
document.querySelectorAll("legend > input[type=checkbox]").forEach(function(cb) {
cb.addEventListener('change', fieldSetEnableChecked);
fieldSetEnableChecked({ target: cb})
});
// Load some of the previous results
loadPriorResults();
// Image drop/upload WIP
/*
let drop = document.getElementById('dropper');
function ondrop(event) {
let dreamData = event.dataTransfer.getData('dream');
if (dreamData) {
var dream = JSON.parse(decodeURIComponent(dreamData));
alert(dream.dreamId);
}
};
function ondragenter(event) {
event.preventDefault();
};
function ondragover(event) {
event.preventDefault();
};
function ondragleave(event) {
}
drop.addEventListener('drop', ondrop);
drop.addEventListener('dragenter', ondragenter);
drop.addEventListener('dragover', ondragover);
drop.addEventListener('dragleave', ondragleave);
*/
}; };