"""Flask + Socket.IO backend: serves the web frontend and runs image jobs."""

import glob
import json
import math
import mimetypes
import os
import shlex
import shutil
import traceback
from pathlib import Path
from threading import Event
from uuid import uuid4

import eventlet
import transformers
from flask import Flask, send_from_directory, url_for, jsonify
from flask_socketio import SocketIO
from PIL import Image
from pytorch_lightning import logging
from send2trash import send2trash

from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
from ldm.gfpgan.gfpgan_tools import run_gfpgan
from ldm.generate import Generate
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.dream.conditioning import split_weighted_subprompts
from modules.parameters import parameters_to_command

"""
USER CONFIG
"""

output_dir = "outputs/"  # Base output directory for images
host = "localhost"  # Web & socket.io host
port = 9090  # Web & socket.io port
verbose = False  # enables copious socket.io logging
additional_allowed_origins = [
    "http://localhost:5173"
]  # additional CORS allowed origins
model = "stable-diffusion-1.4"

"""
END USER CONFIG
"""


"""
SERVER SETUP
"""

# fix missing mimetypes on windows due to registry wonkiness
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")

app = Flask(__name__, static_url_path="", static_folder="../frontend/dist/")

app.config["OUTPUTS_FOLDER"] = "../outputs"


# The rule must declare the `filename` variable; without it Flask calls the
# view with no arguments and every request fails with a TypeError.
@app.route("/outputs/<path:filename>")
def outputs(filename):
    """Serve a file from the configured outputs folder."""
    return send_from_directory(app.config["OUTPUTS_FOLDER"], filename)


@app.route("/", defaults={"path": ""})
def serve(path):
    """Serve the frontend entry point (`index.html`) from the static folder."""
    return send_from_directory(app.static_folder, "index.html")


logger = verbose
engineio_logger = verbose

# default 1,000,000, needs to be higher for socketio to accept larger images
max_http_buffer_size = 10000000

cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins

socketio = SocketIO(
    app,
    logger=logger,
    engineio_logger=engineio_logger,
    max_http_buffer_size=max_http_buffer_size,
    cors_allowed_origins=cors_allowed_origins,
)

"""
END SERVER SETUP
"""


"""
APP SETUP
"""


class CanceledException(Exception):
    """Raised from the step callback to abort a generation that was canceled."""


# Set by the "cancel" listener; cleared when a new generation starts.
canceled = Event()

# reduce logging outputs to error
transformers.logging.set_verbosity_error()
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

# Initialize and load model
generate = Generate(model)
generate.load_model()

# location for "finished" images
result_path = os.path.join(output_dir, "img-samples/")

# temporary path for intermediates
intermediate_path = os.path.join(result_path, "intermediates/")

# path for user-uploaded init images and masks
init_image_path = os.path.join(result_path, "init-images/")
mask_image_path = os.path.join(result_path, "mask-images/")

# txt log
log_path = os.path.join(result_path, "dream_log.txt")

# make all output paths
for _output_path in (
    result_path,
    intermediate_path,
    init_image_path,
    mask_image_path,
):
    os.makedirs(_output_path, exist_ok=True)

"""
END APP SETUP
"""


"""
SOCKET.IO LISTENERS
"""


def _emit_progress(progress):
    """Send a `progressUpdate` event and yield so eventlet can flush it."""
    socketio.emit("progressUpdate", progress)
    eventlet.sleep(0)


def _new_postprocessing_progress():
    """Return the initial progress payload for a single-step postprocessing job."""
    return {
        "currentStep": 1,
        "totalSteps": 1,
        "currentIteration": 1,
        "totalIterations": 1,
        "currentStatus": "Preparing",
        "isProcessing": True,
        "currentStatusHasSteps": False,
    }


def _mark_progress_finished(progress):
    """Reset the counters in `progress` and flag the job as finished (in place)."""
    progress["currentStatus"] = "Finished"
    progress["currentStep"] = 0
    progress["totalSteps"] = 0
    progress["currentIteration"] = 0
    progress["totalIterations"] = 0
    progress["isProcessing"] = False


def _is_inside_output_dir(path):
    """Return True when `path` resolves to a location strictly inside `output_dir`."""
    if not isinstance(path, str) or not path:
        return False
    root = os.path.realpath(output_dir)
    target = os.path.realpath(path)
    try:
        return target != root and os.path.commonpath([root, target]) == root
    except ValueError:
        # commonpath raises when the paths cannot be compared
        # (e.g. different drives on Windows).
        return False


@socketio.on("requestSystemConfig")
def handle_request_capabilities():
    """Emit the current system config as a `systemConfig` event."""
    print(">> System config requested")
    config = get_system_config()
    socketio.emit("systemConfig", config)


@socketio.on("requestAllImages")
def handle_request_all_images():
    """Emit every PNG in `result_path`, oldest first, with its stored metadata."""
    print(">> All images requested")
    pattern = os.path.join(result_path, "*.png")
    paths = sorted(filter(os.path.isfile, glob.glob(pattern)), key=os.path.getmtime)
    image_array = []
    for path in paths:
        metadata = retrieve_metadata(path)
        image_array.append({"url": path, "metadata": metadata["sd-metadata"]})
    socketio.emit("galleryImages", {"images": image_array})
    eventlet.sleep(0)


@socketio.on("generateImage")
def handle_generate_image_event(
    generation_parameters, esrgan_parameters, gfpgan_parameters
):
    """Log the request and hand all three parameter sets to `generate_images`."""
    print(
        f">> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}"
    )
    generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)


@socketio.on("runESRGAN")
def handle_run_esrgan_event(original_image, esrgan_parameters):
    """Upscale an existing image, save the result and emit `esrganResult`.

    `original_image` is read for its "url" and "metadata" entries;
    `esrgan_parameters["upscale"]` is read as a (scale, strength) pair.
    The seed is written back into `esrgan_parameters`.
    """
    print(
        f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}'
    )
    progress = _new_postprocessing_progress()
    _emit_progress(progress)

    image = Image.open(original_image["url"])

    seed = (
        original_image["metadata"]["seed"]
        if "seed" in original_image["metadata"]
        else "unknown_seed"
    )

    progress["currentStatus"] = "Upscaling"
    _emit_progress(progress)

    image = real_esrgan_upscale(
        image=image,
        upsampler_scale=esrgan_parameters["upscale"][0],
        strength=esrgan_parameters["upscale"][1],
        seed=seed,
    )

    progress["currentStatus"] = "Saving image"
    _emit_progress(progress)

    esrgan_parameters["seed"] = seed
    metadata = parameters_to_post_processed_image_metadata(
        parameters=esrgan_parameters,
        original_image_path=original_image["url"],
        type="esrgan",
    )
    command = parameters_to_command(esrgan_parameters)

    path = save_image(image, command, metadata, result_path, postprocessing="esrgan")

    write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')

    _mark_progress_finished(progress)
    _emit_progress(progress)

    socketio.emit(
        "esrganResult",
        {
            "url": os.path.relpath(path),
            "metadata": metadata,
        },
    )


@socketio.on("runGFPGAN")
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
    """Run face restoration on an existing image, save it and emit `gfpganResult`.

    `original_image` is read for its "url" and "metadata" entries;
    `gfpgan_parameters["gfpgan_strength"]` supplies the strength.
    The seed is written back into `gfpgan_parameters`.
    """
    print(
        f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}'
    )
    progress = _new_postprocessing_progress()
    _emit_progress(progress)

    image = Image.open(original_image["url"])

    seed = (
        original_image["metadata"]["seed"]
        if "seed" in original_image["metadata"]
        else "unknown_seed"
    )

    progress["currentStatus"] = "Fixing faces"
    _emit_progress(progress)

    image = run_gfpgan(
        image=image,
        strength=gfpgan_parameters["gfpgan_strength"],
        seed=seed,
        upsampler_scale=1,
    )

    progress["currentStatus"] = "Saving image"
    _emit_progress(progress)

    gfpgan_parameters["seed"] = seed
    metadata = parameters_to_post_processed_image_metadata(
        parameters=gfpgan_parameters,
        original_image_path=original_image["url"],
        type="gfpgan",
    )
    command = parameters_to_command(gfpgan_parameters)

    path = save_image(image, command, metadata, result_path, postprocessing="gfpgan")

    write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')

    _mark_progress_finished(progress)
    _emit_progress(progress)

    socketio.emit(
        "gfpganResult",
        {
            "url": os.path.relpath(path),
            "metadata": metadata,
        },
    )


@socketio.on("cancel")
def handle_cancel():
    """Flag the running job as canceled and notify clients."""
    print(">> Cancel processing requested")
    canceled.set()
    socketio.emit("processingCanceled")


@socketio.on("deleteImage")
def handle_delete_image(path, uuid):
    """Move an image to the trash and emit `imageDeleted`.

    The path comes from the client, so it is only honoured when it resolves
    to a location inside `output_dir`; anything else is refused with an
    `error` event instead of being trashed.
    """
    print(f'>> Delete requested "{path}"')
    if not _is_inside_output_dir(path):
        print(f'>> Delete refused, "{path}" is outside "{output_dir}"')
        socketio.emit(
            "error",
            {"message": f'Refusing to delete "{path}": not inside the output directory'},
        )
        return
    send2trash(path)
    socketio.emit("imageDeleted", {"url": path, "uuid": uuid})


# TODO: I think this needs a safety mechanism.
@socketio.on("uploadInitialImage")
def handle_upload_initial_image(bytes, name):
    """Store an uploaded init image and emit `initialImageUploaded` with its path."""
    print(f'>> Init image upload requested "{name}"')
    file_path = _store_uploaded_image(bytes, name, init_image_path)
    socketio.emit("initialImageUploaded", {"url": file_path, "uuid": ""})


# NOTE: the client-supplied name is reduced to its basename before use;
# the upload's contents are still not validated.
@socketio.on("uploadMaskImage")
def handle_upload_mask_image(bytes, name):
    """Store an uploaded mask image and emit `maskImageUploaded` with its path."""
    print(f'>> Mask image upload requested "{name}"')
    file_path = _store_uploaded_image(bytes, name, mask_image_path)
    socketio.emit("maskImageUploaded", {"url": file_path, "uuid": ""})


"""
END SOCKET.IO LISTENERS
"""


"""
ADDITIONAL FUNCTIONS
"""


def _store_uploaded_image(data, name, directory):
    """Write `data` into `directory` under a unique name and return the path.

    Only the basename of `name` is used, so a client cannot direct the write
    outside `directory` with path separators or `..` components.
    """
    unique_name = make_unique_init_image_filename(os.path.basename(name))
    file_path = os.path.join(directory, unique_name)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    # Context manager guarantees the handle is flushed and closed.
    with open(file_path, "wb") as new_file:
        new_file.write(data)
    return file_path


def _preserve_source_image(parameters, key, directory):
    """Copy the image at `parameters[key]` into `directory` if it is not there yet.

    When no file with the same basename exists in `directory`, the image is
    copied under a unique name and `parameters[key]` is repointed at the copy.
    Does nothing when `key` is absent.
    """
    if key not in parameters:
        return
    source = parameters[key]
    filename = os.path.basename(source)
    if os.path.exists(os.path.join(directory, filename)):
        return
    new_path = os.path.join(directory, make_unique_init_image_filename(filename))
    shutil.copy(source, new_path)
    parameters[key] = new_path


def get_system_config():
    """Return the top-level metadata describing the model and the app."""
    return {
        "model": "stable diffusion",
        "model_id": model,
        "model_hash": generate.model_hash,
        "app_id": APP_ID,
        "app_version": APP_VERSION,
    }


def parameters_to_post_processed_image_metadata(parameters, original_image_path, type):
    """Build metadata for an image produced by postprocessing an existing image.

    `type` must be "esrgan" (reads `parameters["upscale"]` as scale, strength)
    or "gfpgan" (reads `parameters["gfpgan_strength"]`).

    Raises:
        TypeError: if `type` is neither "esrgan" nor "gfpgan".
    """
    # top-level metadata minus `image` or `images`
    metadata = get_system_config()

    orig_hash = calculate_init_img_hash(original_image_path)

    image = {"orig_path": original_image_path, "orig_hash": orig_hash}

    if type == "esrgan":
        image["type"] = "esrgan"
        image["scale"] = parameters["upscale"][0]
        image["strength"] = parameters["upscale"][1]
    elif type == "gfpgan":
        image["type"] = "gfpgan"
        image["strength"] = parameters["gfpgan_strength"]
    else:
        raise TypeError(f"Invalid type: {type}")

    metadata["image"] = image
    return metadata


def parameters_to_generated_image_metadata(parameters):
    """Build metadata for a freshly generated image from its parameters.

    Requires the "sampler_name" and "prompt" keys; when "init_img" is present
    the "strength" and "fit" keys are read as well.
    """
    # top-level metadata minus `image` or `images`
    metadata = get_system_config()

    # remove any image keys not mentioned in RFC #266
    rfc266_img_fields = (
        "type",
        "postprocessing",
        "sampler",
        "prompt",
        "seed",
        "variations",
        "steps",
        "cfg_scale",
        "step_number",
        "width",
        "height",
        "extra",
        "seamless",
    )

    rfc_dict = {
        key: value for key, value in parameters.items() if key in rfc266_img_fields
    }

    # 'postprocessing' is either null or a list of applied postprocessing steps
    postprocessing = []
    if "gfpgan_strength" in parameters:
        postprocessing.append(
            {"type": "gfpgan", "strength": float(parameters["gfpgan_strength"])}
        )

    if "upscale" in parameters:
        postprocessing.append(
            {
                "type": "esrgan",
                "scale": int(parameters["upscale"][0]),
                "strength": float(parameters["upscale"][1]),
            }
        )

    rfc_dict["postprocessing"] = postprocessing if postprocessing else None

    # semantic drift
    rfc_dict["sampler"] = parameters["sampler_name"]

    # display weighted subprompts (liable to change)
    subprompts = split_weighted_subprompts(parameters["prompt"])
    rfc_dict["prompt"] = [{"prompt": x[0], "weight": x[1]} for x in subprompts]

    # 'variations' should always exist and be an array, empty or consisting of
    # {'seed': seed, 'weight': weight} pairs
    variations = []
    if "with_variations" in parameters:
        variations = [
            {"seed": x[0], "weight": x[1]} for x in parameters["with_variations"]
        ]
    rfc_dict["variations"] = variations

    if "init_img" in parameters:
        rfc_dict["type"] = "img2img"
        rfc_dict["strength"] = parameters["strength"]
        rfc_dict["fit"] = parameters["fit"]  # TODO: Noncompliant
        rfc_dict["orig_hash"] = calculate_init_img_hash(parameters["init_img"])
        rfc_dict["init_image_path"] = parameters["init_img"]  # TODO: Noncompliant
        rfc_dict["sampler"] = "ddim"  # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
        if "init_mask" in parameters:
            rfc_dict["mask_hash"] = calculate_init_img_hash(
                parameters["init_mask"]
            )  # TODO: Noncompliant
            rfc_dict["mask_image_path"] = parameters["init_mask"]  # TODO: Noncompliant
    else:
        rfc_dict["type"] = "txt2img"

    metadata["image"] = rfc_dict

    return metadata


def make_unique_init_image_filename(name):
    """Return `name` with a random hex UUID inserted before its extension."""
    stem, extension = os.path.splitext(name)
    return f"{stem}.{uuid4().hex}{extension}"


def write_log_message(message, log_path=log_path):
    """Logs the filename and parameters used to generate or process that image to log file"""
    with open(log_path, "a", encoding="utf-8") as file:
        file.write(f"{message}\n")


def save_image(
    image, command, metadata, output_dir, step_index=None, postprocessing=False
):
    """Save `image` as a PNG in `output_dir` and return the written path.

    The filename is `<prefix>.<seed>[.<step_index>][.postprocessed].png`,
    where the seed comes from `metadata["image"]["seed"]` when present and
    is "unknown_seed" otherwise.
    """
    pngwriter = PngWriter(output_dir)
    prefix = pngwriter.unique_prefix()

    seed = "unknown_seed"
    if "image" in metadata and "seed" in metadata["image"]:
        seed = metadata["image"]["seed"]

    filename = f"{prefix}.{seed}"

    if step_index:
        filename += f".{step_index}"
    if postprocessing:
        filename += ".postprocessed"

    filename += ".png"

    return pngwriter.save_image_and_prompt_to_png(
        image=image, dream_prompt=command, metadata=metadata, name=filename
    )


def calculate_real_steps(steps, strength, has_init_image):
    """Return the number of steps that will run.

    With an init image the count is `floor(strength * steps)`; otherwise it
    is `steps`. A missing (`None`) strength also falls back to `steps`
    instead of raising.
    """
    if has_init_image and strength is not None:
        return math.floor(strength * steps)
    return steps


def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
    """Run a generation job, streaming progress and results over socket.io.

    Emits `progressUpdate` throughout, `intermediateResult` for saved
    intermediate images, `generationResult` for each finished image and
    `error` if generation raises. `esrgan_parameters` / `gfpgan_parameters`
    enable the respective postprocessing step when truthy.
    `generation_parameters["init_img"]` / `["init_mask"]` may be rewritten to
    point at preserved copies of those images.
    """
    canceled.clear()

    step_index = 1

    # If a result image is used as an init image, and then deleted, we will
    # want to be able to use it as an init image in the future. Need to copy it.
    #
    # If the init/mask image doesn't exist in the init_image_path/mask_image_path,
    # make a unique filename for it and copy it there.
    _preserve_source_image(generation_parameters, "init_img", init_image_path)
    _preserve_source_image(generation_parameters, "init_mask", mask_image_path)

    total_steps = calculate_real_steps(
        steps=generation_parameters["steps"],
        strength=generation_parameters.get("strength"),
        has_init_image="init_img" in generation_parameters,
    )

    progress = {
        "currentStep": 1,
        "totalSteps": total_steps,
        "currentIteration": 1,
        "totalIterations": generation_parameters["iterations"],
        "currentStatus": "Preparing",
        "isProcessing": True,
        "currentStatusHasSteps": False,
    }

    def publish_progress():
        # Yield after emitting so eventlet can deliver the update.
        socketio.emit("progressUpdate", progress)
        eventlet.sleep(0)

    publish_progress()

    def image_progress(sample, step):
        nonlocal step_index

        if canceled.is_set():
            raise CanceledException

        progress["currentStep"] = step + 1
        progress["currentStatus"] = "Generating"
        progress["currentStatusHasSteps"] = True

        if (
            generation_parameters["progress_images"]
            and step % 5 == 0
            and step < generation_parameters["steps"] - 1
        ):
            image = generate.sample_to_image(sample)
            # save_image expects (image, command, metadata, output_dir, ...);
            # build the command and metadata the same way image_done does.
            command = parameters_to_command(generation_parameters)
            metadata = parameters_to_generated_image_metadata(generation_parameters)
            path = save_image(
                image, command, metadata, intermediate_path, step_index=step_index
            )

            step_index += 1
            socketio.emit(
                "intermediateResult",
                {"url": os.path.relpath(path), "metadata": generation_parameters},
            )
        publish_progress()

    def image_done(image, seed):
        # nonlocal is required here: a plain assignment would only create a
        # local and the intermediate counter would never restart.
        nonlocal step_index
        step_index = 1

        progress["currentStatus"] = "Generation complete"
        publish_progress()

        # Work on a copy so the caller's parameters are not polluted with
        # per-image values (seed, postprocessing settings).
        all_parameters = dict(generation_parameters)
        postprocessing = False

        if esrgan_parameters:
            progress["currentStatus"] = "Upscaling"
            progress["currentStatusHasSteps"] = False
            publish_progress()

            image = real_esrgan_upscale(
                image=image,
                strength=esrgan_parameters["strength"],
                upsampler_scale=esrgan_parameters["level"],
                seed=seed,
            )
            postprocessing = True
            all_parameters["upscale"] = [
                esrgan_parameters["level"],
                esrgan_parameters["strength"],
            ]

        if gfpgan_parameters:
            progress["currentStatus"] = "Fixing faces"
            progress["currentStatusHasSteps"] = False
            publish_progress()

            image = run_gfpgan(
                image=image,
                strength=gfpgan_parameters["strength"],
                seed=seed,
                upsampler_scale=1,
            )
            postprocessing = True
            all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"]

        all_parameters["seed"] = seed

        progress["currentStatus"] = "Saving image"
        publish_progress()

        metadata = parameters_to_generated_image_metadata(all_parameters)
        command = parameters_to_command(all_parameters)

        path = save_image(
            image, command, metadata, result_path, postprocessing=postprocessing
        )

        print(f'>> Image generated: "{path}"')
        write_log_message(f'[Generated] "{path}": {command}')

        if progress["totalIterations"] > progress["currentIteration"]:
            progress["currentStep"] = 1
            progress["currentIteration"] += 1
            progress["currentStatus"] = "Iteration finished"
            progress["currentStatusHasSteps"] = False
        else:
            progress["currentStep"] = 0
            progress["totalSteps"] = 0
            progress["currentIteration"] = 0
            progress["totalIterations"] = 0
            progress["currentStatus"] = "Finished"
            progress["isProcessing"] = False

        publish_progress()

        socketio.emit(
            "generationResult",
            {"url": os.path.relpath(path), "metadata": metadata},
        )
        eventlet.sleep(0)

    try:
        generate.prompt2image(
            **generation_parameters,
            step_callback=image_progress,
            image_callback=image_done,
        )

    except KeyboardInterrupt:
        raise
    except CanceledException:
        # Cancellation is an expected way to stop; nothing to report.
        pass
    except Exception as e:
        socketio.emit("error", {"message": (str(e))})
        print("\n")
        traceback.print_exc()
        print("\n")


"""
END ADDITIONAL FUNCTIONS
"""


if __name__ == "__main__":
    print(f">> Starting server at http://{host}:{port}")
    socketio.run(app, host=host, port=port)