diff --git a/.gitignore b/.gitignore index e8d7c1f189..7e4dd4bea9 100644 --- a/.gitignore +++ b/.gitignore @@ -191,3 +191,4 @@ checkpoints .scratch/ .vscode/ gfpgan/ +models/ldm/stable-diffusion-v1/model.sha256 diff --git a/docs/other/CONTRIBUTORS.md b/docs/other/CONTRIBUTORS.md index 78f33d93d1..948795d3f2 100644 --- a/docs/other/CONTRIBUTORS.md +++ b/docs/other/CONTRIBUTORS.md @@ -51,6 +51,7 @@ We thank them for all of their time and hard work. - [Any Winter](https://github.com/any-winter-4079) - [Doggettx](https://github.com/doggettx) - [Matthias Wild](https://github.com/mauwii) +- [Kyle Schouviller](https://github.com/kyle0654) ## __Original CompVis Authors:__ diff --git a/ldm/dream/pngwriter.py b/ldm/dream/pngwriter.py index ecbc3c0e15..5cda259357 100644 --- a/ldm/dream/pngwriter.py +++ b/ldm/dream/pngwriter.py @@ -33,11 +33,12 @@ class PngWriter: # saves image named _image_ to outdir/name, writing metadata from prompt # returns full path of output - def save_image_and_prompt_to_png(self, image, dream_prompt, metadata, name): + def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None): path = os.path.join(self.outdir, name) info = PngImagePlugin.PngInfo() info.add_text('Dream', dream_prompt) - info.add_text('sd-metadata', json.dumps(metadata)) + if metadata: # TODO: merge command line app's method of writing metadata and always just write metadata + info.add_text('sd-metadata', json.dumps(metadata)) image.save(path, 'PNG', pnginfo=info) return path diff --git a/ldm/dream/server.py b/ldm/dream/server.py index 003ec70533..cde3957a1f 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -230,7 +230,7 @@ class DreamServer(BaseHTTPRequestHandler): image = self.model.sample_to_image(sample) name = f'{prefix}.{opt.seed}.{step_index}.png' metadata = f'{opt.prompt} -S{opt.seed} [intermediate]' - path = step_writer.save_image_and_prompt_to_png(image, metadata, name) + path = step_writer.save_image_and_prompt_to_png(image, dream_prompt=metadata, name=name) step_index += 1 self.wfile.write(bytes(json.dumps( {'event': 'step', 'step': step + 1, 'url': path} diff --git a/ldm/generate.py b/ldm/generate.py index 6b5cb3d794..1b3c8544e0 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -181,7 +181,7 @@ class Generate: for image, seed in results: name = f'{prefix}.{seed}.png' path = pngwriter.save_image_and_prompt_to_png( - image, f'{prompt} -S{seed}', name) + image, dream_prompt=f'{prompt} -S{seed}', name=name) outputs.append([path, seed]) return outputs diff --git a/requirements.txt b/requirements.txt index 2007ca4caf..efc55b7971 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,11 @@ test-tube torch-fidelity torchmetrics transformers +flask==2.1.3 +flask_socketio==5.3.0 +flask_cors==3.0.10 +dependency_injector==4.40.0 +eventlet git+https://github.com/openai/CLIP.git@main#egg=clip git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion git+https://github.com/lstein/GFPGAN@fix-dark-cast-images#egg=gfpgan diff --git a/server/application.py b/server/application.py index 2e8d77ce0f..2501f4b63d 100644 --- a/server/application.py +++ b/server/application.py @@ -7,9 +7,10 @@ import os import sys from flask import Flask from flask_cors import CORS -from flask_socketio import SocketIO, join_room, leave_room +from flask_socketio import SocketIO from omegaconf import OmegaConf from dependency_injector.wiring import inject, Provide +from ldm.dream.args import Args from server import views from server.containers import Container from server.services import GeneratorService, SignalService @@ -58,6 +59,8 @@ def run_app(config, host, port) -> Flask: # TODO: Get storage root from config app.add_url_rule('/api/images/', view_func=views.ApiImages.as_view('api_images', '../')) + app.add_url_rule('/api/images//metadata', view_func=views.ApiImagesMetadata.as_view('api_images_metadata', '../')) + app.add_url_rule('/api/images', view_func=views.ApiImagesList.as_view('api_images_list')) app.add_url_rule('/api/intermediates//', view_func=views.ApiIntermediates.as_view('api_intermediates', '../')) app.static_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../static/dream_web/')) @@ -79,30 +82,28 @@ def run_app(config, host, port) -> Flask: def main(): """Initialize command-line parsers and the diffusion model""" - from scripts.dream import create_argv_parser - arg_parser = create_argv_parser() + arg_parser = Args() opt = arg_parser.parse_args() if opt.laion400m: - print('--laion400m flag has been deprecated. Please use --model laion400m instead.') - sys.exit(-1) - if opt.weights != 'model': - print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.') - sys.exit(-1) + print('--laion400m flag has been deprecated. Please use --model laion400m instead.') + sys.exit(-1) + if opt.weights: + print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.') + sys.exit(-1) - try: - models = OmegaConf.load(opt.config) - width = models[opt.model].width - height = models[opt.model].height - config = models[opt.model].config - weights = models[opt.model].weights - except (FileNotFoundError, IOError, KeyError) as e: - print(f'{e}. Aborting.') - sys.exit(-1) + # try: + # models = OmegaConf.load(opt.config) + # width = models[opt.model].width + # height = models[opt.model].height + # config = models[opt.model].config + # weights = models[opt.model].weights + # except (FileNotFoundError, IOError, KeyError) as e: + # print(f'{e}. Aborting.') + # sys.exit(-1) - print('* Initializing, be patient...\n') + #print('* Initializing, be patient...\n') sys.path.append('.') - from pytorch_lightning import logging # these two lines prevent a horrible warning message from appearing # when the frozen CLIP tokenizer is imported @@ -110,26 +111,28 @@ def main(): transformers.logging.set_verbosity_error() - appConfig = { - "model": { - "width": width, - "height": height, - "sampler_name": opt.sampler_name, - "weights": weights, - "full_precision": opt.full_precision, - "config": config, - "grid": opt.grid, - "latent_diffusion_weights": opt.laion400m, - "embedding_path": opt.embedding_path, - "device_type": opt.device - } - } + appConfig = opt.__dict__ + + # appConfig = { + # "model": { + # "width": width, + # "height": height, + # "sampler_name": opt.sampler_name, + # "weights": weights, + # "full_precision": opt.full_precision, + # "config": config, + # "grid": opt.grid, + # "latent_diffusion_weights": opt.laion400m, + # "embedding_path": opt.embedding_path + # } + # } # make sure the output directory exists if not os.path.exists(opt.outdir): os.makedirs(opt.outdir) # gets rid of annoying messages about random seed + from pytorch_lightning import logging logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) print('\n* starting api server...') diff --git a/server/containers.py b/server/containers.py index 08ef01c4b6..a3318c5ff0 100644 --- a/server/containers.py +++ b/server/containers.py @@ -17,18 +17,24 @@ class Container(containers.DeclarativeContainer): app = None ) + # TODO: Add a model provider service that provides model(s) dynamically model_singleton = providers.ThreadSafeSingleton( Generate, - width = config.model.width, - height = config.model.height, - sampler_name = config.model.sampler_name, - weights = config.model.weights, - full_precision = config.model.full_precision, - config = config.model.config, - grid = config.model.grid, - seamless = config.model.seamless, - embedding_path = config.model.embedding_path, - device_type = config.model.device_type + model = config.model, + sampler_name = config.sampler_name, + embedding_path = config.embedding_path, + full_precision = config.full_precision + # config = config.model.config, + + # width = config.model.width, + # height = config.model.height, + # sampler_name = config.model.sampler_name, + # weights = config.model.weights, + # full_precision = config.model.full_precision, + # grid = config.model.grid, + # seamless = config.model.seamless, + # embedding_path = config.model.embedding_path, + # device_type = config.model.device_type ) # TODO: get location from config diff --git a/server/models.py b/server/models.py index 17c6d0dfe4..fc4a5f41c4 100644 --- a/server/models.py +++ b/server/models.py @@ -1,77 +1,182 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from base64 import urlsafe_b64encode import json import string from copy import deepcopy from datetime import datetime, timezone from enum import Enum +from typing import Any, Dict, List, Union +from uuid import uuid4 -class DreamRequest(): - prompt: string - initimg: string - strength: float - iterations: int - steps: int - width: int - height: int - fit = None - cfgscale: float - sampler_name: string - gfpgan_strength: float - upscale_level: int - upscale_strength: float + +class DreamBase(): + # Id + id: str + + # Initial Image + enable_init_image: bool + initimg: string = None + + # Img2Img + enable_img2img: bool # TODO: support this better + strength: float = 0 # TODO: name this something related to img2img to make it clearer? + fit = None # Fit initial image dimensions + + # Generation + enable_generate: bool + prompt: string = "" + seed: int = 0 # 0 is random + steps: int = 10 + width: int = 512 + height: int = 512 + cfg_scale: float = 7.5 + sampler_name: string = 'klms' + seamless: bool = False + model: str = None # The model to use (currently unused) + embeddings = None # The embeddings to use (currently unused) + progress_images: bool = False + + # GFPGAN + enable_gfpgan: bool + gfpgan_strength: float = 0 + + # Upscale + enable_upscale: bool upscale: None - progress_images = None - seed: int + upscale_level: int = None + upscale_strength: float = 0.75 + + # Embiggen + enable_embiggen: bool + embiggen: Union[None, List[float]] = None + embiggen_tiles: Union[None, List[int]] = None + + # Metadata time: int + def __init__(self): + self.id = urlsafe_b64encode(uuid4().bytes).decode('ascii') + + def parse_json(self, j, new_instance=False): + # Id + if 'id' in j and not new_instance: + self.id = j.get('id') + + # Initial Image + self.enable_init_image = 'enable_init_image' in j and bool(j.get('enable_init_image')) + if self.enable_init_image: + self.initimg = j.get('initimg') + + # Img2Img + self.enable_img2img = 'enable_img2img' in j and bool(j.get('enable_img2img')) + if self.enable_img2img: + self.strength = float(j.get('strength')) + self.fit = 'fit' in j + + # Generation + self.enable_generate = 'enable_generate' in j and bool(j.get('enable_generate')) + if self.enable_generate: + self.prompt = j.get('prompt') + self.seed = int(j.get('seed')) + self.steps = int(j.get('steps')) + self.width = int(j.get('width')) + self.height = int(j.get('height')) + self.cfg_scale = float(j.get('cfgscale') or j.get('cfg_scale')) + self.sampler_name = j.get('sampler') or j.get('sampler_name') + # model: str = None # The model to use (currently unused) + # embeddings = None # The embeddings to use (currently unused) + self.seamless = 'seamless' in j + self.progress_images = 'progress_images' in j + + # GFPGAN + self.enable_gfpgan = 'enable_gfpgan' in j and bool(j.get('enable_gfpgan')) + if self.enable_gfpgan: + self.gfpgan_strength = float(j.get('gfpgan_strength')) + + # Upscale + self.enable_upscale = 'enable_upscale' in j and bool(j.get('enable_upscale')) + if self.enable_upscale: + self.upscale_level = j.get('upscale_level') + self.upscale_strength = j.get('upscale_strength') + self.upscale = None if self.upscale_level in {None,''} else [int(self.upscale_level),float(self.upscale_strength)] + + # Embiggen + self.enable_embiggen = 'enable_embiggen' in j and bool(j.get('enable_embiggen')) + if self.enable_embiggen: + self.embiggen = j.get('embiggen') + self.embiggen_tiles = j.get('embiggen_tiles') + + # Metadata + self.time = int(j.get('time')) if ('time' in j and not new_instance) else int(datetime.now(timezone.utc).timestamp()) + + +class DreamResult(DreamBase): + # Result + has_upscaled: False + has_gfpgan: False + # TODO: use something else for state tracking images_generated: int = 0 images_upscaled: int = 0 - def id(self, seed = None, upscaled = False) -> str: - return f"{self.time}.{seed or self.seed}{'.u' if upscaled else ''}" + def __init__(self): + super().__init__() - # TODO: handle this more cleanly (probably by splitting this into a Job and Result class) - # TODO: Set iterations to 1 or remove it from the dream result? And just keep it on the job? - def clone_without_image(self, seed = None): - data = deepcopy(self) - data.initimg = None - if seed: - data.seed = seed + def clone_without_img(self): + copy = deepcopy(self) + copy.initimg = None + return copy - return data - - def to_json(self, seed: int = None): - copy = self.clone_without_image(seed) - return json.dumps(copy.__dict__) + def to_json(self): + copy = deepcopy(self) + copy.initimg = None + j = json.dumps(copy.__dict__) + return j @staticmethod def from_json(j, newTime: bool = False): - d = DreamRequest() - d.prompt = j.get('prompt') - d.initimg = j.get('initimg') - d.strength = float(j.get('strength')) - d.iterations = int(j.get('iterations')) - d.steps = int(j.get('steps')) - d.width = int(j.get('width')) - d.height = int(j.get('height')) - d.fit = 'fit' in j - d.seamless = 'seamless' in j - d.cfgscale = float(j.get('cfgscale')) - d.sampler_name = j.get('sampler') - d.variation_amount = float(j.get('variation_amount')) - d.with_variations = j.get('with_variations') - d.gfpgan_strength = float(j.get('gfpgan_strength')) - d.upscale_level = j.get('upscale_level') - d.upscale_strength = j.get('upscale_strength') - d.upscale = [int(d.upscale_level),float(d.upscale_strength)] if d.upscale_level != '' else None - d.progress_images = 'progress_images' in j - d.seed = int(j.get('seed')) - d.time = int(datetime.now(timezone.utc).timestamp()) if newTime else int(j.get('time')) + d = DreamResult() + d.parse_json(j) return d +# TODO: switch this to a pipelined request, with pluggable steps +# Will likely require generator code changes to accomplish +class JobRequest(DreamBase): + # Iteration + iterations: int = 1 + variation_amount = None + with_variations = None + + # Results + results: List[DreamResult] = [] + + def __init__(self): + super().__init__() + + def newDreamResult(self) -> DreamResult: + result = DreamResult() + result.parse_json(self.__dict__, new_instance=True) + return result + + @staticmethod + def from_json(j): + job = JobRequest() + job.parse_json(j) + + # Metadata + job.time = int(j.get('time')) if ('time' in j) else int(datetime.now(timezone.utc).timestamp()) + + # Iteration + if job.enable_generate: + job.iterations = int(j.get('iterations')) + job.variation_amount = float(j.get('variation_amount')) + job.with_variations = j.get('with_variations') + + return job + + class ProgressType(Enum): GENERATION = 1 UPSCALING_STARTED = 2 @@ -102,11 +207,11 @@ class Signal(): # TODO: use a result id or something? Like a sub-job @staticmethod - def image_result(jobId: str, dreamId: str, dreamRequest: DreamRequest): + def image_result(jobId: str, dreamId: str, dreamResult: DreamResult): return Signal('dream_result', { 'jobId': jobId, 'dreamId': dreamId, - 'dreamRequest': dreamRequest.__dict__ + 'dreamRequest': dreamResult.clone_without_img().__dict__ }, room=jobId, broadcast=True) @staticmethod @@ -126,3 +231,21 @@ class Signal(): return Signal('job_canceled', { 'jobId': jobId }, room=jobId, broadcast=True) + + +class PaginatedItems(): + items: List[Any] + page: int # Current Page + pages: int # Total number of pages + per_page: int # Number of items per page + total: int # Total number of items in result + + def __init__(self, items: List[Any], page: int, pages: int, per_page: int, total: int): + self.items = items + self.page = page + self.pages = pages + self.per_page = per_page + self.total = total + + def to_json(self): + return json.dumps(self.__dict__) diff --git a/server/services.py b/server/services.py index 0b53cc9141..444f47cccf 100644 --- a/server/services.py +++ b/server/services.py @@ -1,25 +1,33 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from argparse import ArgumentParser import base64 +from datetime import datetime, timezone +import glob +import json import os +from pathlib import Path from queue import Empty, Queue +import shlex from threading import Thread import time -from flask import app, url_for from flask_socketio import SocketIO, join_room, leave_room +from ldm.dream.args import Args +from ldm.dream.generator import embiggen +from PIL import Image from ldm.dream.pngwriter import PngWriter from ldm.dream.server import CanceledException from ldm.generate import Generate -from server.models import DreamRequest, ProgressType, Signal +from server.models import DreamResult, JobRequest, PaginatedItems, ProgressType, Signal class JobQueueService: __queue: Queue = Queue() - def push(self, dreamRequest: DreamRequest): + def push(self, dreamRequest: DreamResult): self.__queue.put(dreamRequest) - def get(self, timeout: float = None) -> DreamRequest: + def get(self, timeout: float = None) -> DreamResult: return self.__queue.get(timeout= timeout) class SignalQueueService: @@ -85,31 +93,116 @@ class LogService: self.__location = location self.__logFile = file - def log(self, dreamRequest: DreamRequest, seed = None, upscaled = False): + def log(self, dreamResult: DreamResult, seed = None, upscaled = False): with open(os.path.join(self.__location, self.__logFile), "a") as log: - log.write(f"{dreamRequest.id(seed, upscaled)}: {dreamRequest.to_json(seed)}\n") + log.write(f"{dreamResult.id}: {dreamResult.to_json()}\n") class ImageStorageService: __location: str __pngWriter: PngWriter + __legacyParser: ArgumentParser def __init__(self, location): self.__location = location self.__pngWriter = PngWriter(self.__location) + self.__legacyParser = Args() # TODO: inject this? def __getName(self, dreamId: str, postfix: str = '') -> str: return f'{dreamId}{postfix}.png' - def save(self, image, dreamRequest, seed = None, upscaled = False, postfix: str = '', metadataPostfix: str = '') -> str: - name = self.__getName(dreamRequest.id(seed, upscaled), postfix) - path = self.__pngWriter.save_image_and_prompt_to_png(image, f'{dreamRequest.prompt} -S{seed or dreamRequest.seed}{metadataPostfix}', name) + def save(self, image, dreamResult: DreamResult, postfix: str = '') -> str: + name = self.__getName(dreamResult.id, postfix) + meta = dreamResult.to_json() # TODO: make all methods consistent with writing metadata. Standardize metadata. + path = self.__pngWriter.save_image_and_prompt_to_png(image, dream_prompt=meta, metadata=None, name=name) return path def path(self, dreamId: str, postfix: str = '') -> str: name = self.__getName(dreamId, postfix) path = os.path.join(self.__location, name) return path + + # Returns true if found, false if not found or error + def delete(self, dreamId: str, postfix: str = '') -> bool: + path = self.path(dreamId, postfix) + if (os.path.exists(path)): + os.remove(path) + return True + else: + return False + + def getMetadata(self, dreamId: str, postfix: str = '') -> DreamResult: + path = self.path(dreamId, postfix) + image = Image.open(path) + text = image.text + if text.__contains__('Dream'): + dreamMeta = text.get('Dream') + try: + j = json.loads(dreamMeta) + return DreamResult.from_json(j) + except ValueError: + # Try to parse command-line format (legacy metadata format) + try: + opt = self.__parseLegacyMetadata(dreamMeta) + optd = opt.__dict__ + if (not 'width' in optd) or (optd.get('width') is None): + optd['width'] = image.width + if (not 'height' in optd) or (optd.get('height') is None): + optd['height'] = image.height + if (not 'steps' in optd) or (optd.get('steps') is None): + optd['steps'] = 10 # No way around this unfortunately - seems like it wasn't storing this previously + + optd['time'] = os.path.getmtime(path) # Set timestamp manually (won't be exactly correct though) + + return DreamResult.from_json(optd) + + except: + return None + else: + return None + + def __parseLegacyMetadata(self, command: str) -> DreamResult: + # before splitting, escape single quotes so as not to mess + # up the parser + command = command.replace("'", "\\'") + + try: + elements = shlex.split(command) + except ValueError as e: + return None + + # rearrange the arguments to mimic how it works in the Dream bot. + switches = [''] + switches_started = False + + for el in elements: + if el[0] == '-' and not switches_started: + switches_started = True + if switches_started: + switches.append(el) + else: + switches[0] += el + switches[0] += ' ' + switches[0] = switches[0][: len(switches[0]) - 1] + + try: + opt = self.__legacyParser.parse_cmd(switches) + return opt + except SystemExit: + return None + + def list_files(self, page: int, perPage: int) -> PaginatedItems: + files = sorted(glob.glob(os.path.join(self.__location,'*.png')), key=os.path.getmtime, reverse=True) + count = len(files) + + startId = page * perPage + pageCount = int(count / perPage) + 1 + endId = min(startId + perPage, count) + items = [] if startId >= count else files[startId:endId] + + items = list(map(lambda f: Path(f).stem, items)) + + return PaginatedItems(items, page, pageCount, perPage, count) class GeneratorService: @@ -144,13 +237,11 @@ class GeneratorService: # TODO: Consider moving this to its own service if there's benefit in separating the generator def __process(self): # preload the model + # TODO: support multiple models print('Preloading model') - tic = time.time() self.__model.load_model() - print( - f'>> model loaded in', '%4.2fs' % (time.time() - tic) - ) + print(f'>> model loaded in', '%4.2fs' % (time.time() - tic)) print('Started generation queue processor') try: @@ -162,103 +253,136 @@ class GeneratorService: print('Generation queue processor stopped') - def __start(self, dreamRequest: DreamRequest): - if dreamRequest.start_callback: - dreamRequest.start_callback() - self.__signal_service.emit(Signal.job_started(dreamRequest.id())) + def __on_start(self, jobRequest: JobRequest): + self.__signal_service.emit(Signal.job_started(jobRequest.id)) - def __done(self, dreamRequest: DreamRequest, image, seed, upscaled=False): - self.__imageStorage.save(image, dreamRequest, seed, upscaled) + def __on_image_result(self, jobRequest: JobRequest, image, seed, upscaled=False): + dreamResult = jobRequest.newDreamResult() + dreamResult.seed = seed + dreamResult.has_upscaled = upscaled + dreamResult.iterations = 1 + jobRequest.results.append(dreamResult) + # TODO: Separate status of GFPGAN? + + self.__imageStorage.save(image, dreamResult) - # TODO: handle upscaling logic better (this is appending data to log, but only on first generation) if not upscaled: - self.__log.log(dreamRequest, seed, upscaled) + self.__log.log(dreamResult) - self.__signal_service.emit(Signal.image_result(dreamRequest.id(), dreamRequest.id(seed, upscaled), dreamRequest.clone_without_image(seed))) + # Send result signal + self.__signal_service.emit(Signal.image_result(jobRequest.id, dreamResult.id, dreamResult)) - upscaling_requested = dreamRequest.upscale or dreamRequest.gfpgan_strength>0 + upscaling_requested = dreamResult.enable_upscale or dreamResult.enable_gfpgan - if upscaled: - dreamRequest.images_upscaled += 1 - else: - dreamRequest.images_generated +=1 - if upscaling_requested: - # action = None - if dreamRequest.images_generated >= dreamRequest.iterations: - progressType = ProgressType.UPSCALING_DONE - if dreamRequest.images_upscaled < dreamRequest.iterations: - progressType = ProgressType.UPSCALING_STARTED - self.__signal_service.emit(Signal.image_progress(dreamRequest.id(), dreamRequest.id(seed), dreamRequest.images_upscaled+1, dreamRequest.iterations, progressType)) + # Report upscaling status + # TODO: this is very coupled to logic inside the generator. Fix that. + if upscaling_requested and any(result.has_upscaled for result in jobRequest.results): + progressType = ProgressType.UPSCALING_STARTED if len(jobRequest.results) < 2 * jobRequest.iterations else ProgressType.UPSCALING_DONE + upscale_count = sum(1 for i in jobRequest.results if i.has_upscaled) + self.__signal_service.emit(Signal.image_progress(jobRequest.id, dreamResult.id, upscale_count, jobRequest.iterations, progressType)) - def __progress(self, dreamRequest, sample, step): + def __on_progress(self, jobRequest: JobRequest, sample, step): if self.__cancellationRequested: self.__cancellationRequested = False raise CanceledException + # TODO: Progress per request will be easier once the seeds (and ids) can all be pre-generated hasProgressImage = False - if dreamRequest.progress_images and step % 5 == 0 and step < dreamRequest.steps - 1: + s = str(len(jobRequest.results)) + if jobRequest.progress_images and step % 5 == 0 and step < jobRequest.steps - 1: image = self.__model._sample_to_image(sample) - self.__intermediateStorage.save(image, dreamRequest, self.__model.seed, postfix=f'.{step}', metadataPostfix=f' [intermediate]') + + # TODO: clean this up, use a pre-defined dream result + result = DreamResult() + result.parse_json(jobRequest.__dict__, new_instance=False) + self.__intermediateStorage.save(image, result, postfix=f'.{s}.{step}') hasProgressImage = True - self.__signal_service.emit(Signal.image_progress(dreamRequest.id(), dreamRequest.id(self.__model.seed), step, dreamRequest.steps, ProgressType.GENERATION, hasProgressImage)) + self.__signal_service.emit(Signal.image_progress(jobRequest.id, f'{jobRequest.id}.{s}', step, jobRequest.steps, ProgressType.GENERATION, hasProgressImage)) - def __generate(self, dreamRequest: DreamRequest): + def __generate(self, jobRequest: JobRequest): try: - initimgfile = None - if dreamRequest.initimg is not None: - with open("./img2img-tmp.png", "wb") as f: - initimg = dreamRequest.initimg.split(",")[1] # Ignore mime type - f.write(base64.b64decode(initimg)) - initimgfile = "./img2img-tmp.png" + # TODO: handle this file a file service for init images + initimgfile = None # TODO: support this on the model directly? + if (jobRequest.enable_init_image): + if jobRequest.initimg is not None: + with open("./img2img-tmp.png", "wb") as f: + initimg = jobRequest.initimg.split(",")[1] # Ignore mime type + f.write(base64.b64decode(initimg)) + initimgfile = "./img2img-tmp.png" - # Get a random seed if we don't have one yet - # TODO: handle "previous" seed usage? - if dreamRequest.seed == -1: - dreamRequest.seed = self.__model.seed + # Use previous seed if set to -1 + initSeed = jobRequest.seed + if initSeed == -1: + initSeed = self.__model.seed # Zero gfpgan strength if the model doesn't exist # TODO: determine if this could be at the top now? Used to cause circular import from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists if not gfpgan_model_exists: - dreamRequest.gfpgan_strength = 0 + jobRequest.enable_gfpgan = False - self.__start(dreamRequest) + # Signal start + self.__on_start(jobRequest) - self.__model.prompt2image( - prompt = dreamRequest.prompt, - init_img = initimgfile, # TODO: ensure this works - strength = None if initimgfile is None else dreamRequest.strength, - fit = None if initimgfile is None else dreamRequest.fit, - iterations = dreamRequest.iterations, - cfg_scale = dreamRequest.cfgscale, - width = dreamRequest.width, - height = dreamRequest.height, - seed = dreamRequest.seed, - steps = dreamRequest.steps, - variation_amount = dreamRequest.variation_amount, - with_variations = dreamRequest.with_variations, - gfpgan_strength = dreamRequest.gfpgan_strength, - upscale = dreamRequest.upscale, - sampler_name = dreamRequest.sampler_name, - seamless = dreamRequest.seamless, - step_callback = lambda sample, step: self.__progress(dreamRequest, sample, step), - image_callback = lambda image, seed, upscaled=False: self.__done(dreamRequest, image, seed, upscaled)) + # Generate in model + # TODO: Split job generation requests instead of fitting all parameters here + # TODO: Support no generation (just upscaling/gfpgan) + + upscale = None if not jobRequest.enable_upscale else jobRequest.upscale + gfpgan_strength = 0 if not jobRequest.enable_gfpgan else jobRequest.gfpgan_strength + + if not jobRequest.enable_generate: + # If not generating, check if we're upscaling or running gfpgan + if not upscale and not gfpgan_strength: + # Invalid settings (TODO: Add message to help user) + raise CanceledException() + + image = Image.open(initimgfile) + # TODO: support progress for upscale? + self.__model.upscale_and_reconstruct( + image_list = [[image,0]], + upscale = upscale, + strength = gfpgan_strength, + save_original = False, + image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled)) + + else: + # Generating - run the generation + init_img = None if (not jobRequest.enable_img2img or jobRequest.strength == 0) else initimgfile + + + self.__model.prompt2image( + prompt = jobRequest.prompt, + init_img = init_img, # TODO: ensure this works + strength = None if init_img is None else jobRequest.strength, + fit = None if init_img is None else jobRequest.fit, + iterations = jobRequest.iterations, + cfg_scale = jobRequest.cfg_scale, + width = jobRequest.width, + height = jobRequest.height, + seed = jobRequest.seed, + steps = jobRequest.steps, + variation_amount = jobRequest.variation_amount, + with_variations = jobRequest.with_variations, + gfpgan_strength = gfpgan_strength, + upscale = upscale, + sampler_name = jobRequest.sampler_name, + seamless = jobRequest.seamless, + embiggen = jobRequest.embiggen, + embiggen_tiles = jobRequest.embiggen_tiles, + step_callback = lambda sample, step: self.__on_progress(jobRequest, sample, step), + image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled)) except CanceledException: - if dreamRequest.cancelled_callback: - dreamRequest.cancelled_callback() - - self.__signal_service.emit(Signal.job_canceled(dreamRequest.id())) + self.__signal_service.emit(Signal.job_canceled(jobRequest.id)) finally: - if dreamRequest.done_callback: - dreamRequest.done_callback() - self.__signal_service.emit(Signal.job_done(dreamRequest.id())) + self.__signal_service.emit(Signal.job_done(jobRequest.id)) # Remove the temp file if (initimgfile is not None): diff --git a/server/views.py b/server/views.py index 590adc6532..db4857d14f 100644 --- a/server/views.py +++ b/server/views.py @@ -8,7 +8,7 @@ from flask import current_app, jsonify, request, Response, send_from_directory, from flask.views import MethodView from dependency_injector.wiring import inject, Provide -from server.models import DreamRequest +from server.models import DreamResult, JobRequest from server.services import GeneratorService, ImageStorageService, JobQueueService from server.containers import Container @@ -16,23 +16,14 @@ class ApiJobs(MethodView): @inject def post(self, job_queue_service: JobQueueService = Provide[Container.generation_queue_service]): - dreamRequest = DreamRequest.from_json(request.json, newTime = True) + jobRequest = JobRequest.from_json(request.json) - #self.canceled.clear() - print(f">> Request to generate with prompt: {dreamRequest.prompt}") - - q = Queue() - - dreamRequest.start_callback = None - dreamRequest.image_callback = None - dreamRequest.progress_callback = None - dreamRequest.cancelled_callback = None - dreamRequest.done_callback = None + print(f">> Request to generate with prompt: {jobRequest.prompt}") # Push the request - job_queue_service.push(dreamRequest) + job_queue_service.push(jobRequest) - return { 'dreamId': dreamRequest.id() } + return { 'jobId': jobRequest.id } class WebIndex(MethodView): @@ -68,6 +59,7 @@ class ApiCancel(MethodView): return Response(status=204) +# TODO: Combine all image storage access class ApiImages(MethodView): init_every_request = False __pathRoot = None @@ -82,6 +74,27 @@ class ApiImages(MethodView): name = self.__storage.path(dreamId) fullpath=os.path.join(self.__pathRoot, name) return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath)) + + def delete(self, dreamId): + result = self.__storage.delete(dreamId) + return Response(status=204) if result else Response(status=404) + + +class ApiImagesMetadata(MethodView): + init_every_request = False + __pathRoot = None + __storage: ImageStorageService + + @inject + def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_storage_service]): + self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase)) + self.__storage = storage + + def get(self, dreamId): + meta = self.__storage.getMetadata(dreamId) + j = {} if meta is None else meta.__dict__ + return j + class ApiIntermediates(MethodView): init_every_request = False @@ -97,3 +110,23 @@ class ApiIntermediates(MethodView): name = self.__storage.path(dreamId, postfix=f'.{step}') fullpath=os.path.join(self.__pathRoot, name) return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath)) + + def delete(self, dreamId): + result = self.__storage.delete(dreamId) + return Response(status=204) if result else Response(status=404) + + +class ApiImagesList(MethodView): + init_every_request = False + __storage: ImageStorageService + + @inject + def __init__(self, storage: ImageStorageService = Provide[Container.image_storage_service]): + self.__storage = storage + + def get(self): + page = request.args.get("page", default=0, type=int) + perPage = request.args.get("per_page", default=10, type=int) + + result = self.__storage.list_files(page, perPage) + return result.__dict__ diff --git a/static/dream_web/index.css b/static/dream_web/index.css index 51f0f267c3..25a0994a3d 100644 --- a/static/dream_web/index.css +++ b/static/dream_web/index.css @@ -1,3 +1,8 @@ +:root { + --fields-dark:#DCDCDC; + --fields-light:#F5F5F5; +} + * { font-family: 'Arial'; font-size: 100%; @@ -18,15 +23,26 @@ fieldset { border: none; line-height: 2.2em; } +fieldset > legend { + width: auto; + margin-left: 0; + margin-right: auto; + font-weight:bold; +} select, input { margin-right: 10px; padding: 2px; } +input:disabled { + cursor:auto; +} input[type=submit] { + cursor: pointer; background-color: #666; color: white; } input[type=checkbox] { + cursor: pointer; margin-right: 0px; width: 20px; height: 20px; @@ -87,11 +103,11 @@ header h1 { } #results img { border-radius: 5px; - object-fit: cover; + object-fit: contain; + background-color: var(--fields-dark); } #fieldset-config { line-height:2em; - background-color: #F0F0F0; } input[type="number"] { width: 60px; @@ -118,35 +134,46 @@ label { #progress-image { width: 30vh; height: 30vh; + object-fit: contain; + background-color: var(--fields-dark); } #cancel-button { cursor: pointer; color: red; } -#basic-parameters { - background-color: #EEEEEE; -} #txt2img { - background-color: #DCDCDC; + background-color: var(--fields-dark); } #variations { - background-color: #EEEEEE; + background-color: var(--fields-light); +} +#initimg { + background-color: var(--fields-dark); } #img2img { - background-color: #DCDCDC; + background-color: var(--fields-light); } -#gfpgan { - background-color: #EEEEEE; +#initimg > :not(legend) { + background-color: var(--fields-light); + margin: .5em; +} + +#postprocess, #initimg { + display:flex; + flex-wrap:wrap; + padding: 0; + margin-top: 1em; + background-color: var(--fields-dark); +} +#postprocess > fieldset, #initimg > * { + flex-grow: 1; +} +#postprocess > fieldset { + background-color: var(--fields-dark); } #progress-section { - background-color: #F5F5F5; -} -.section-header { - text-align: left; - font-weight: bold; - padding: 0 0 0 0; + background-color: var(--fields-light); } #no-results-message:not(:only-child) { display: none; } - diff --git a/static/dream_web/index.html b/static/dream_web/index.html index b8c80fe838..9dbd213669 100644 --- a/static/dream_web/index.html +++ b/static/dream_web/index.html @@ -1,104 +1,152 @@ - - Stable Diffusion Dream Server - - - - - - - - - -
-

Stable Diffusion Dream Server

-
- For news and support for this web service, visit our GitHub site + + Stable Diffusion Dream Server + + + + + + + + + + + +
+

Stable Diffusion Dream Server

+
+ For news and support for this web service, visit our GitHub + site +
+
+ +
+ +
+
+ + + + + + + + + + + + + + + +
+ + + + + + + + + + +
+ + + + +
+
+
+ + + + +
-
- - - - -
+
+
+ + + + + + + + +
+ +
-
Post-processing options
- - - + + + + + + +
+
+ + + + +
- -
-
-
- - -
- -
- Postprocessing...1/3 -
- -
- -
-
-

No results...

+
+ + +
+
+
+ + +
+ +
+ Postprocessing...1/3
- - +
+ +
+
+ + + diff --git a/static/dream_web/index.js b/static/dream_web/index.js index 3af0308fb5..5de690297d 100644 --- a/static/dream_web/index.js +++ b/static/dream_web/index.js @@ -1,5 +1,73 @@ const socket = io(); +var priorResultsLoadState = { + page: 0, + pages: 1, + per_page: 10, + total: 20, + offset: 0, // number of items generated since last load + loading: false, + initialized: false +}; + +function loadPriorResults() { + // Fix next page by offset + let offsetPages = priorResultsLoadState.offset / priorResultsLoadState.per_page; + priorResultsLoadState.page += offsetPages; + priorResultsLoadState.pages += offsetPages; + priorResultsLoadState.total += priorResultsLoadState.offset; + priorResultsLoadState.offset = 0; + + if (priorResultsLoadState.loading) { + return; + } + + if (priorResultsLoadState.page >= priorResultsLoadState.pages) { + return; // Nothing more to load + } + + // Load + priorResultsLoadState.loading = true + let url = new URL('/api/images', document.baseURI); + url.searchParams.append('page', priorResultsLoadState.initialized ? priorResultsLoadState.page + 1 : priorResultsLoadState.page); + url.searchParams.append('per_page', priorResultsLoadState.per_page); + fetch(url.href, { + method: 'GET', + headers: new Headers({'content-type': 'application/json'}) + }) + .then(response => response.json()) + .then(data => { + priorResultsLoadState.page = data.page; + priorResultsLoadState.pages = data.pages; + priorResultsLoadState.per_page = data.per_page; + priorResultsLoadState.total = data.total; + + data.items.forEach(function(dreamId, index) { + let src = 'api/images/' + dreamId; + fetch('/api/images/' + dreamId + '/metadata', { + method: 'GET', + headers: new Headers({'content-type': 'application/json'}) + }) + .then(response => response.json()) + .then(metadata => { + let seed = metadata.seed || 0; // TODO: Parse old metadata + appendOutput(src, seed, metadata, true); + }); + }); + + // Load until page is full + if (!priorResultsLoadState.initialized) { + if (document.body.scrollHeight <= window.innerHeight) { + loadPriorResults(); + } + } + }) + .finally(() => { + priorResultsLoadState.loading = false; + priorResultsLoadState.initialized = true; + }); +} + function resetForm() { var form = document.getElementById('generate-form'); form.querySelector('fieldset').removeAttribute('disabled'); @@ -45,48 +113,64 @@ function toBase64(file) { }); } -function appendOutput(src, seed, config) { +function ondragdream(event) { + let dream = event.target.dataset.dream; + event.dataTransfer.setData("dream", dream); +} + +function seedClick(event) { + // Get element + var image = event.target.closest('figure').querySelector('img'); + var dream = JSON.parse(decodeURIComponent(image.dataset.dream)); + + let form = document.querySelector("#generate-form"); + for (const [k, v] of new FormData(form)) { + if (k == 'initimg') { continue; } + let formElem = form.querySelector(`*[name=${k}]`); + formElem.value = dream[k] !== undefined ? dream[k] : formElem.defaultValue; + } + + document.querySelector("#seed").value = dream.seed; + document.querySelector('#iterations').value = 1; // Reset to 1 iteration since we clicked a single image (not a full job) + + // NOTE: leaving this manual for the user for now - it was very confusing with this behavior + // document.querySelector("#with_variations").value = variations || ''; + // if (document.querySelector("#variation_amount").value <= 0) { + // document.querySelector("#variation_amount").value = 0.2; + // } + + saveFields(document.querySelector("#generate-form")); +} + +function appendOutput(src, seed, config, toEnd=false) { let outputNode = document.createElement("figure"); let altText = seed.toString() + " | " + config.prompt; + // img needs width and height for lazy loading to work + // TODO: store the full config in a data attribute on the image? const figureContents = ` - ${altText} + ${altText} -
${seed}
+
${seed}
`; outputNode.innerHTML = figureContents; - let figcaption = outputNode.querySelector('figcaption') - // Reload image config - figcaption.addEventListener('click', () => { - let form = document.querySelector("#generate-form"); - for (const [k, v] of new FormData(form)) { - if (k == 'initimg') { continue; } - form.querySelector(`*[name=${k}]`).value = config[k]; - } - if (config.variation_amount > 0 || config.with_variations != '') { - document.querySelector("#seed").value = config.seed; - } else { - document.querySelector("#seed").value = seed; - } - - if (config.variation_amount > 0) { - let oldVarAmt = document.querySelector("#variation_amount").value - let oldVariations = document.querySelector("#with_variations").value - let varSep = '' - document.querySelector("#variation_amount").value = 0; - if (document.querySelector("#with_variations").value != '') { - varSep = "," - } - document.querySelector("#with_variations").value = oldVariations + varSep + seed + ':' + config.variation_amount - } - - saveFields(document.querySelector("#generate-form")); - }); - - document.querySelector("#results").prepend(outputNode); + if (toEnd) { + document.querySelector("#results").append(outputNode); + } else { + document.querySelector("#results").prepend(outputNode); + } document.querySelector("#no-results-message")?.remove(); } @@ -119,14 +203,33 @@ async function generateSubmit(form) { // Convert file data to base64 // TODO: Should probably uplaod files with formdata or something, and store them in the backend? let formData = Object.fromEntries(new FormData(form)); + if (!formData.enable_generate && !formData.enable_init_image) { + gen_label = document.querySelector("label[for=enable_generate]").innerHTML; + initimg_label = document.querySelector("label[for=enable_init_image]").innerHTML; + alert(`Error: one of "${gen_label}" or "${initimg_label}" must be set`); + } + + formData.initimg_name = formData.initimg.name formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null; + // Evaluate all checkboxes + let checkboxes = form.querySelectorAll('input[type=checkbox]'); + checkboxes.forEach(function (checkbox) { + if (checkbox.checked) { + formData[checkbox.name] = 'true'; + } + }); + let strength = formData.strength; let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps; + let showProgressImages = formData.progress_images; + + // Set enabling flags + // Initialize the progress bar - initProgress(totalSteps); + initProgress(totalSteps, showProgressImages); // POST, use response to listen for events fetch(form.action, { @@ -136,13 +239,19 @@ async function generateSubmit(form) { }) .then(response => response.json()) .then(data => { - var dreamId = data.dreamId; - socket.emit('join_room', { 'room': dreamId }); + var jobId = data.jobId; + socket.emit('join_room', { 'room': jobId }); }); form.querySelector('fieldset').setAttribute('disabled',''); } +function fieldSetEnableChecked(event) { + cb = event.target; + fields = cb.closest('fieldset'); + fields.disabled = !cb.checked; +} + // Socket listeners socket.on('job_started', (data) => {}) @@ -152,6 +261,7 @@ socket.on('dream_result', (data) => { var dreamRequest = data.dreamRequest; var src = 'api/images/' + dreamId; + priorResultsLoadState.offset += 1; appendOutput(src, dreamRequest.seed, dreamRequest); resetProgress(false); @@ -193,7 +303,13 @@ socket.on('job_done', (data) => { resetProgress(); }) -window.onload = () => { +window.onload = async () => { + document.querySelector("#prompt").addEventListener("keydown", (e) => { + if (e.key === "Enter" && !e.shiftKey) { + const form = e.target.form; + generateSubmit(form); + } + }); document.querySelector("#generate-form").addEventListener('submit', (e) => { e.preventDefault(); const form = e.target; @@ -216,12 +332,65 @@ window.onload = () => { loadFields(document.querySelector("#generate-form")); document.querySelector('#cancel-button').addEventListener('click', () => { - fetch('/cancel').catch(e => { + fetch('/api/cancel').catch(e => { console.error(e); }); }); + document.documentElement.addEventListener('keydown', (e) => { + if (e.key === "Escape") + fetch('/api/cancel').catch(err => { + console.error(err); + }); + }); if (!config.gfpgan_model_exists) { document.querySelector("#gfpgan").style.display = 'none'; } + + window.addEventListener("scroll", () => { + if ((window.innerHeight + window.pageYOffset) >= document.body.offsetHeight) { + loadPriorResults(); + } + }); + + + + // Enable/disable forms by checkboxes + document.querySelectorAll("legend > input[type=checkbox]").forEach(function(cb) { + cb.addEventListener('change', fieldSetEnableChecked); + fieldSetEnableChecked({ target: cb}) + }); + + + // Load some of the previous results + loadPriorResults(); + + // Image drop/upload WIP + /* + let drop = document.getElementById('dropper'); + function ondrop(event) { + let dreamData = event.dataTransfer.getData('dream'); + if (dreamData) { + var dream = JSON.parse(decodeURIComponent(dreamData)); + alert(dream.dreamId); + } + }; + + function ondragenter(event) { + event.preventDefault(); + }; + + function ondragover(event) { + event.preventDefault(); + }; + + function ondragleave(event) { + + } + + drop.addEventListener('drop', ondrop); + drop.addEventListener('dragenter', ondragenter); + drop.addEventListener('dragover', ondragover); + drop.addEventListener('dragleave', ondragleave); + */ };