import base64
import glob
import io
import json
import math
import mimetypes
import os
import shutil
import traceback
from threading import Event
from uuid import uuid4

import eventlet
from pathlib import Path
from PIL import Image
from PIL.Image import Image as ImageType
from flask import Flask, redirect, send_from_directory, request, make_response
from flask_socketio import SocketIO
from werkzeug.utils import secure_filename

from invokeai.backend.modules.get_canvas_generation_mode import (
    get_canvas_generation_mode,
)
from invokeai.backend.modules.parameters import parameters_to_command
import invokeai.frontend.dist as frontend
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
from ldm.invoke.globals import Globals, global_converted_ckpts_dir
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
from ldm.invoke.globals import global_models_dir
from ldm.invoke.merge_diffusers import merge_diffusion_models

# Loading Arguments
# Parsed once at import time; the InvokeAIWebServer methods below read the
# resulting module-level globals ``opt`` and ``args``.
opt = Args()
args = opt.parse_args()

# Set the root directory for static files and relative paths.
# A falsy root_dir falls back to ".." (relative to the process's working
# directory); a leading "~" is expanded.
args.root_dir = os.path.expanduser(args.root_dir or "..")
# A relative output directory is interpreted relative to root_dir.
if not os.path.isabs(args.outdir):
    args.outdir = os.path.join(args.root_dir, args.outdir)

# normalize the config directory relative to root
# NOTE(review): this resolves against Globals.root, whereas outdir above is
# resolved against args.root_dir — confirm the two are meant to coincide.
if not os.path.isabs(opt.conf):
    opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf))


class InvokeAIWebServer:
    def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
        """Keep references to the generator and postprocessors.

        The bind host/port come from the parsed command-line ``args``.
        """
        self.host, self.port = args.host, args.port

        # image generator followed by the three postprocessing backends
        self.generate = generate
        self.gfpgan = gfpgan
        self.codeformer = codeformer
        self.esrgan = esrgan

        # set by the "cancel" socket handler
        self.canceled = Event()
        # lowercase file extensions accepted by the /upload route
        self.ALLOWED_EXTENSIONS = set(("png", "jpg", "jpeg"))

    def allowed_file(self, filename: str) -> bool:
        return (
            "." in filename
            and filename.rsplit(".", 1)[1].lower() in self.ALLOWED_EXTENSIONS
        )

    def run(self):
        self.setup_app()
        self.setup_flask()

    def setup_flask(self):
        """Build the Flask app and Socket.IO server, register routes, and serve.

        Registers the keep-alive, ``/outputs/<path>``, ``/`` and ``/upload``
        HTTP routes plus every Socket.IO listener. It then either launches the
        ``flaskwebgui`` desktop wrapper (``--gui``) or runs the Socket.IO server
        on ``self.host:self.port``, using TLS when a certificate or key file was
        configured. ``setup_app`` must have run first, because the upload route
        writes into the directories it defines.
        """
        # Fix missing mimetypes on Windows
        mimetypes.add_type("application/javascript", ".js")
        mimetypes.add_type("text/css", ".css")

        # Socket IO
        verbose = bool(args.web_verbose)
        socketio_args = {
            "logger": verbose,
            "engineio_logger": verbose,
            # ~10 MB; presumably sized for base64 images sent over the socket
            "max_http_buffer_size": 10000000,
            "ping_interval": (50, 50),
            "ping_timeout": 60,
        }

        if opt.cors:
            _cors = opt.cors
            # convert list back into comma-separated string,
            # be defensive here, not sure in what form this arrives
            if isinstance(_cors, list):
                _cors = ",".join(_cors)
            if "," in _cors:
                _cors = _cors.split(",")
            socketio_args["cors_allowed_origins"] = _cors

        self.app = Flask(
            __name__, static_url_path="", static_folder=frontend.__path__[0]
        )

        self.socketio = SocketIO(self.app, **socketio_args)

        # Keep Server Alive Route
        @self.app.route("/flaskwebgui-keep-server-alive")
        def keep_alive():
            return {"message": "Server Running"}

        # Outputs Route
        self.app.config["OUTPUTS_FOLDER"] = os.path.abspath(args.outdir)

        @self.app.route("/outputs/<path:file_path>")
        def outputs(file_path):
            return send_from_directory(self.app.config["OUTPUTS_FOLDER"], file_path)

        # Base Route
        @self.app.route("/")
        def serve():
            if args.web_develop:
                return redirect("http://127.0.0.1:5173")
            else:
                return send_from_directory(self.app.static_folder, "index.html")

        @self.app.route("/upload", methods=["POST"])
        def upload():
            """Store an uploaded image and reply with its URLs, mtime and size.

            The form field ``data`` is JSON with ``kind`` (init, temp, result or
            mask) and an optional ``cropVisible`` flag. The image comes from the
            multipart ``file`` part or, failing that, from ``data["dataURL"]``
            plus ``data["filename"]``.
            """
            try:
                data = json.loads(request.form["data"])
                # Exactly one source is used. Remember which one so the save
                # step writes the right object even if both were sent.
                uploaded_file = None
                file_bytes = None
                # check if the post request has the file part
                if "file" in request.files:
                    uploaded_file = request.files["file"]
                    # If the user does not select a file, the browser submits an
                    # empty file without a filename.
                    if uploaded_file.filename == "":
                        return make_response("No file selected", 400)
                    filename = uploaded_file.filename
                elif "dataURL" in data:
                    file_bytes = dataURL_to_bytes(data["dataURL"])
                    if "filename" not in data or data["filename"] == "":
                        return make_response("No filename provided", 400)
                    filename = data["filename"]
                else:
                    return make_response("No file or dataURL", 400)

                kind = data["kind"]

                if kind == "init":
                    path = self.init_image_path
                elif kind == "temp":
                    path = self.temp_image_path
                elif kind == "result":
                    path = self.result_path
                elif kind == "mask":
                    path = self.mask_image_path
                else:
                    return make_response(f"Invalid upload kind: {kind}", 400)

                if not self.allowed_file(filename):
                    return make_response(
                        f'Invalid file type, must be one of: {", ".join(self.ALLOWED_EXTENSIONS)}',
                        400,
                    )

                secured_filename = secure_filename(filename)

                # a short random suffix keeps repeated uploads of one name distinct
                truncated_uuid = uuid4().hex[:8]

                stem, extension = os.path.splitext(secured_filename)
                name = f"{stem}.{truncated_uuid}{extension}"

                file_path = os.path.join(path, name)

                if uploaded_file is None:
                    with open(file_path, "wb") as f:
                        f.write(file_bytes)
                else:
                    uploaded_file.save(file_path)

                # The context manager releases the file handle Image.open keeps.
                with Image.open(file_path) as pil_image:
                    if "cropVisible" in data and data["cropVisible"] == True:
                        visible_image_bbox = pil_image.getbbox()
                        pil_image = pil_image.crop(visible_image_bbox)
                        pil_image.save(file_path)

                    (width, height) = pil_image.size

                    thumbnail_path = save_thumbnail(
                        pil_image, os.path.basename(file_path), self.thumbnail_image_path
                    )

                # Read mtime after the optional crop re-save so it matches the
                # file actually on disk.
                mtime = os.path.getmtime(file_path)

                response = {
                    "url": self.get_url_from_image_path(file_path),
                    "thumbnail": self.get_url_from_image_path(thumbnail_path),
                    "mtime": mtime,
                    "width": width,
                    "height": height,
                }

                return make_response(response, 200)

            except Exception as e:
                self.handle_exceptions(e)
                return make_response("Error uploading file", 500)

        self.load_socketio_listeners(self.socketio)

        if args.gui:
            print(">> Launching Invoke AI GUI")
            try:
                from flaskwebgui import FlaskUI

                FlaskUI(
                    app=self.app,
                    socketio=self.socketio,
                    server="flask_socketio",
                    width=1600,
                    height=1000,
                    port=self.port
                ).run()
            except KeyboardInterrupt:
                import sys

                sys.exit(0)
        else:
            useSSL = args.certfile or args.keyfile
            print(">> Started Invoke AI Web Server!")
            if self.host == "0.0.0.0":
                print(
                    f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
                )
            else:
                print(
                    ">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
                )
                print(
                    f">> Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
                )
            if not useSSL:
                self.socketio.run(app=self.app, host=self.host, port=self.port)
            else:
                self.socketio.run(
                    app=self.app,
                    host=self.host,
                    port=self.port,
                    certfile=args.certfile,
                    keyfile=args.keyfile,
                )

    def setup_app(self):
        """Define output URLs and paths, then create the output folders.

        All paths derive from ``args.outdir`` (``self.result_path``). The
        sub-folders hold intermediates, uploaded init images and masks,
        uncommitted temp images, and thumbnails. Folders that already exist
        are left as they are.
        """
        # URL prefixes under the /outputs route registered in setup_flask
        self.result_url = "outputs/"
        self.init_image_url = "outputs/init-images/"
        self.mask_image_url = "outputs/mask-images/"
        self.intermediate_url = "outputs/intermediates/"
        self.temp_image_url = "outputs/temp-images/"
        self.thumbnail_image_url = "outputs/thumbnails/"
        # location for "finished" images
        self.result_path = args.outdir
        # temporary path for intermediates
        self.intermediate_path = os.path.join(self.result_path, "intermediates/")
        # path for user-uploaded init images and masks
        self.init_image_path = os.path.join(self.result_path, "init-images/")
        self.mask_image_path = os.path.join(self.result_path, "mask-images/")
        # path for temp images e.g. gallery generations which are not committed
        self.temp_image_path = os.path.join(self.result_path, "temp-images/")
        # path for thumbnail images
        self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
        # txt log
        self.log_path = os.path.join(self.result_path, "invoke_log.txt")
        # make all output paths; a plain loop, since makedirs is called only
        # for its side effect
        for path in (
            self.result_path,
            self.intermediate_path,
            self.init_image_path,
            self.mask_image_path,
            self.temp_image_path,
            self.thumbnail_image_path,
        ):
            os.makedirs(path, exist_ok=True)

    def load_socketio_listeners(self, socketio):
        @socketio.on("requestSystemConfig")
        def handle_request_capabilities():
            """Reply with the system config plus the model list and infill methods."""
            print(">> System config requested")
            config = self.get_system_config()
            config.update(
                model_list=self.generate.model_manager.list_models(),
                infill_methods=infill_methods(),
            )
            socketio.emit("systemConfig", config)

        @socketio.on('searchForModels')
        def handle_search_models(search_folder: str):
            """Search *search_folder* for models and emit ``foundModels``.

            A falsy folder produces a payload whose two fields are both None.
            """
            try:
                if search_folder:
                    manager = self.generate.model_manager
                    search_folder, found_models = manager.search_models(search_folder)
                    payload = {'search_folder': search_folder,
                               'found_models': found_models}
                else:
                    payload = {'search_folder': None, 'found_models': None}
                socketio.emit("foundModels", payload)
            except Exception as e:
                self.handle_exceptions(e)
                print("\n")

        @socketio.on("addNewModel")
        def handle_add_model(new_model_config: dict):
            """Add or overwrite a model entry, commit it, and emit ``newModelAdded``.

            ``new_model_config`` holds ``name`` plus the model attributes; it is
            modified in place. An empty, None, or absent ``vae`` entry is dropped
            rather than aborting the add. The ``update`` flag in the emitted
            payload is True when a model of that name already existed.
            """
            try:
                model_name = new_model_config.pop('name')
                model_attributes = new_model_config
                # len() on a missing/None vae used to raise and abort the add
                if not model_attributes.get('vae'):
                    model_attributes.pop('vae', None)
                current_model_list = self.generate.model_manager.list_models()
                update = model_name in current_model_list

                print(f">> Adding New Model: {model_name}")

                self.generate.model_manager.add_model(
                    model_name=model_name, model_attributes=model_attributes, clobber=True)
                self.generate.model_manager.commit(opt.conf)

                new_model_list = self.generate.model_manager.list_models()
                socketio.emit(
                    "newModelAdded",
                    {"new_model_name": model_name,
                     "model_list": new_model_list, 'update': update},
                )
                print(f">> New Model Added: {model_name}")
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("deleteModel")
        def handle_delete_model(model_name: str):
            """Remove *model_name*, persist the config, and emit ``modelDeleted``."""
            try:
                print(f">> Deleting Model: {model_name}")
                manager = self.generate.model_manager
                manager.del_model(model_name)
                manager.commit(opt.conf)
                payload = {
                    "deleted_model_name": model_name,
                    "model_list": manager.list_models(),
                }
                socketio.emit("modelDeleted", payload)
                print(f">> Model Deleted: {model_name}")
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("requestModelChange")
        def handle_set_model(model_name: str):
            """Switch the active model.

            Emits ``modelChanged`` on success, or ``modelChangeFailed`` when
            ``set_model`` returns None.
            """
            try:
                print(f">> Model change requested: {model_name}")
                loaded = self.generate.set_model(model_name)
                models = self.generate.model_manager.list_models()
                event_name = "modelChangeFailed" if loaded is None else "modelChanged"
                socketio.emit(
                    event_name,
                    {"model_name": model_name, "model_list": models},
                )
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on('convertToDiffusers')
        def convert_to_diffusers(model_to_convert: dict):
            """Convert a checkpoint model to diffusers format and import it.

            ``model_to_convert`` supplies ``model_name`` and ``save_location``:
            'root' saves under the converted-checkpoints dir, 'custom' saves under
            ``custom_location`` (when given), and anything else saves next to the
            checkpoint. An existing output folder is deleted first. On success it
            emits ``modelConverted``; when the model info is unusable it emits
            ``error`` and returns without converting.
            """
            try:
                model_name = model_to_convert['model_name']
                model_info = self.generate.model_manager.model_info(
                    model_name=model_name)
                # Previously execution fell through after these errors and then
                # crashed on an unbound ckpt_path.
                if not model_info:
                    socketio.emit(
                        "error", {"message": "Could not retrieve model info."})
                    return
                if 'weights' not in model_info:
                    socketio.emit(
                        "error", {"message": "Model is not a valid checkpoint file"})
                    return

                ckpt_path = Path(model_info['weights'])
                # The config entry may be absent; keep None instead of Path(None).
                # NOTE(review): assumes convert_and_import accepts
                # original_config_file=None — confirm against the model manager.
                config = model_info.get('config')
                original_config_file = Path(config) if config else None
                model_description = model_info['description']

                if not ckpt_path.is_absolute():
                    ckpt_path = Path(Globals.root, ckpt_path)

                if original_config_file and not original_config_file.is_absolute():
                    original_config_file = Path(
                        Globals.root, original_config_file)

                diffusers_dir_name = f'{model_name}_diffusers'
                save_location = model_to_convert['save_location']
                if save_location == 'root':
                    diffusers_path = Path(
                        global_converted_ckpts_dir(), diffusers_dir_name)
                elif save_location == 'custom' and model_to_convert.get('custom_location') is not None:
                    diffusers_path = Path(
                        model_to_convert['custom_location'], diffusers_dir_name)
                else:
                    # default: alongside the original checkpoint
                    diffusers_path = Path(
                        ckpt_path.parent.absolute(), diffusers_dir_name)

                # overwrite any previous conversion of this model
                if diffusers_path.exists():
                    shutil.rmtree(diffusers_path)

                self.generate.model_manager.convert_and_import(
                    ckpt_path,
                    diffusers_path,
                    model_name=model_name,
                    model_description=model_description,
                    vae=None,
                    original_config_file=original_config_file,
                    commit_to_conf=opt.conf,
                )

                new_model_list = self.generate.model_manager.list_models()
                socketio.emit(
                    "modelConverted",
                    {"new_model_name": model_name,
                     "model_list": new_model_list, 'update': True},
                )
                print(f">> Model Converted: {model_name}")
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on('mergeDiffusersModels')
        def merge_diffusers_models(model_merge_info: dict):
            """Merge diffusers models, save the result, and register it.

            Required keys are ``models_to_merge``, ``alpha``, ``interp``, ``force``
            and ``merged_model_name``. ``model_merge_save_path`` is optional; when
            it is absent or None, the models dir's ``merged_models`` folder is used.
            The first source model's configured VAE, if any, is reused. On success
            it emits ``modelsMerged`` with the refreshed model list.
            """
            try:
                models_to_merge = model_merge_info['models_to_merge']
                # Read the required name before the potentially slow merge so a
                # malformed request fails before any work is done.
                merged_model_name = model_merge_info['merged_model_name']
                model_ids_or_paths = [
                    self.generate.model_manager.model_name_or_path(x) for x in models_to_merge]
                merged_pipe = merge_diffusion_models(
                    model_ids_or_paths, model_merge_info['alpha'], model_merge_info['interp'], model_merge_info['force'])

                save_path = model_merge_info.get('model_merge_save_path')
                if save_path is not None:
                    dump_path = Path(save_path)
                else:
                    dump_path = global_models_dir() / 'merged_models'

                os.makedirs(dump_path, exist_ok=True)
                dump_path = dump_path / merged_model_name
                merged_pipe.save_pretrained(dump_path, safe_serialization=True)

                merged_model_config = dict(
                    model_name=merged_model_name,
                    description=f'Merge of models {", ".join(models_to_merge)}',
                    commit_to_conf=opt.conf
                )

                if vae := self.generate.model_manager.config[models_to_merge[0]].get("vae", None):
                    print(
                        f">> Using configured VAE assigned to {models_to_merge[0]}")
                    merged_model_config.update(vae=vae)

                self.generate.model_manager.import_diffuser_model(
                    dump_path, **merged_model_config)
                new_model_list = self.generate.model_manager.list_models()

                socketio.emit(
                    "modelsMerged",
                    {"merged_models": models_to_merge,
                     "merged_model_name": merged_model_name,
                     "model_list": new_model_list, 'update': True},
                )
                print(f">> Models Merged: {models_to_merge}")
                print(
                    f">> New Model Added: {merged_model_name}")
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("requestEmptyTempFolder")
        def empty_temp_folder():
            """Delete every entry in the temp-image folder and its ``.webp`` thumbnail.

            A failure on one file is reported to the client as an ``error``
            event and does not stop the sweep. A missing thumbnail is not an
            error. ``tempFolderEmptied`` is emitted at the end.
            """
            try:
                temp_files = glob.glob(os.path.join(self.temp_image_path, "*"))
                for f in temp_files:
                    try:
                        os.remove(f)
                    except Exception as e:
                        socketio.emit(
                            "error", {"message": f"Unable to delete {f}: {str(e)}"})
                        continue

                    thumbnail_path = os.path.join(
                        self.thumbnail_image_path,
                        os.path.splitext(os.path.basename(f))[0] + ".webp",
                    )
                    try:
                        os.remove(thumbnail_path)
                    except FileNotFoundError:
                        # No thumbnail was generated for this file; nothing to do.
                        pass
                    except Exception as e:
                        socketio.emit(
                            "error", {"message": f"Unable to delete {thumbnail_path}: {str(e)}"})

                socketio.emit("tempFolderEmptied")
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("requestSaveStagingAreaImageToGallery")
        def save_temp_image_to_gallery(url):
            """Copy a staged (temp) image into the results folder and announce it.

            The copy keeps the original file's timestamps (``shutil.copy2``). It
            is emitted as a one-item ``galleryImages`` result entry with the same
            fields as the other gallery handlers: ``metadata`` is the
            ``sd-metadata`` part and ``dreamPrompt`` is the ``Dream`` part.
            """
            try:
                image_path = self.get_image_path_from_url(url)
                new_path = os.path.join(
                    self.result_path, os.path.basename(image_path))
                shutil.copy2(image_path, new_path)

                if os.path.splitext(new_path)[1] == ".png":
                    metadata = retrieve_metadata(new_path)
                else:
                    metadata = {}

                # release the file handle once size and thumbnail are taken
                with Image.open(new_path) as pil_image:
                    (width, height) = pil_image.size

                    thumbnail_path = save_thumbnail(
                        pil_image, os.path.basename(
                            new_path), self.thumbnail_image_path
                    )

                image_array = [
                    {
                        "url": self.get_url_from_image_path(new_path),
                        "thumbnail": self.get_url_from_image_path(thumbnail_path),
                        "mtime": os.path.getmtime(new_path),
                        # Same shape as requestImages/requestLatestImages entries;
                        # previously the whole retrieve_metadata dict was sent.
                        "metadata": metadata.get("sd-metadata"),
                        "dreamPrompt": metadata.get("Dream"),
                        "width": width,
                        "height": height,
                        "category": "result",
                    }
                ]

                socketio.emit(
                    "galleryImages",
                    {"images": image_array, "category": "result"},
                )

            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("requestLatestImages")
        def handle_request_latest_images(category, latest_mtime):
            """Emit ``galleryImages`` with every image modified after *latest_mtime*.

            ``category == "result"`` scans the results folder; any other value
            scans the init-image folder. Images are ordered newest first. An image
            that cannot be read is reported via an ``error`` event and skipped.
            """
            try:
                base_path = (
                    self.result_path if category == "result" else self.init_image_path
                )

                paths = []
                for ext in ("*.png", "*.jpg", "*.jpeg"):
                    paths.extend(glob.glob(os.path.join(base_path, ext)))

                # stat each file once instead of in the sort, filter and loop
                mtimes = {path: os.path.getmtime(path) for path in paths}
                image_paths = sorted(
                    (path for path in paths if mtimes[path] > latest_mtime),
                    key=mtimes.__getitem__,
                    reverse=True,
                )

                image_array = []

                for path in image_paths:
                    try:
                        if os.path.splitext(path)[1] == ".png":
                            metadata = retrieve_metadata(path)
                        else:
                            metadata = {}

                        # close the file handle once size and thumbnail are done
                        with Image.open(path) as pil_image:
                            (width, height) = pil_image.size

                            thumbnail_path = save_thumbnail(
                                pil_image, os.path.basename(
                                    path), self.thumbnail_image_path
                            )

                        image_array.append(
                            {
                                "url": self.get_url_from_image_path(path),
                                "thumbnail": self.get_url_from_image_path(
                                    thumbnail_path
                                ),
                                "mtime": mtimes[path],
                                "metadata": metadata.get("sd-metadata"),
                                "dreamPrompt": metadata.get("Dream"),
                                "width": width,
                                "height": height,
                                "category": category,
                            }
                        )
                    except Exception as e:
                        socketio.emit(
                            "error", {"message": f"Unable to load {path}: {str(e)}"})

                socketio.emit(
                    "galleryImages",
                    {"images": image_array, "category": category},
                )
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("requestImages")
        def handle_request_images(category, earliest_mtime=None):
            """Emit one page (50) of gallery images, newest first.

            ``category == "result"`` scans the results folder; any other value
            scans the init-image folder. When *earliest_mtime* is truthy, only
            images modified before it are considered, which pages backwards
            through older images. ``areMoreImagesAvailable`` is True only when
            images remain beyond this page. Unreadable images are reported via
            ``error`` events and skipped.
            """
            try:
                page_size = 50

                base_path = (
                    self.result_path if category == "result" else self.init_image_path
                )

                paths = []
                for ext in ("*.png", "*.jpg", "*.jpeg"):
                    paths.extend(glob.glob(os.path.join(base_path, ext)))

                # stat each file once instead of in the sort, filter and loop
                mtimes = {path: os.path.getmtime(path) for path in paths}
                image_paths = sorted(
                    paths, key=mtimes.__getitem__, reverse=True
                )

                if earliest_mtime:
                    image_paths = [
                        path for path in image_paths if mtimes[path] < earliest_mtime
                    ]

                # '>' rather than '>=': exactly page_size leftovers fill this page
                # completely and leave nothing for a next one.
                areMoreImagesAvailable = len(image_paths) > page_size
                image_paths = image_paths[:page_size]

                image_array = []
                for path in image_paths:
                    try:
                        if os.path.splitext(path)[1] == ".png":
                            metadata = retrieve_metadata(path)
                        else:
                            metadata = {}

                        # close the file handle once size and thumbnail are done
                        with Image.open(path) as pil_image:
                            (width, height) = pil_image.size

                            thumbnail_path = save_thumbnail(
                                pil_image, os.path.basename(
                                    path), self.thumbnail_image_path
                            )

                        image_array.append(
                            {
                                "url": self.get_url_from_image_path(path),
                                "thumbnail": self.get_url_from_image_path(
                                    thumbnail_path
                                ),
                                "mtime": mtimes[path],
                                "metadata": metadata.get("sd-metadata"),
                                "dreamPrompt": metadata.get("Dream"),
                                "width": width,
                                "height": height,
                                "category": category,
                            }
                        )
                    except Exception as e:
                        print(f">> Unable to load {path}")
                        socketio.emit(
                            "error", {"message": f"Unable to load {path}: {str(e)}"})

                socketio.emit(
                    "galleryImages",
                    {
                        "images": image_array,
                        "areMoreImagesAvailable": areMoreImagesAvailable,
                        "category": category,
                    },
                )
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("generateImage")
        def handle_generate_image_event(
            generation_parameters, esrgan_parameters, facetool_parameters
        ):
            """Log the request and run generation.

            Long ``init_img`` / ``init_mask`` values are shortened in the log
            only; the unmodified parameters are passed to ``generate_images``.
            """
            try:
                # shallow copy, so shortening never touches the real parameters
                printable_parameters = dict(generation_parameters)
                for key in ("init_img", "init_mask"):
                    if key in generation_parameters:
                        printable_parameters[key] = (
                            printable_parameters[key][:64] + "..."
                        )

                print(
                    f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n')
                print(f'>> ESRGAN Parameters: {esrgan_parameters}')
                print(f'>> Facetool Parameters: {facetool_parameters}')

                self.generate_images(
                    generation_parameters,
                    esrgan_parameters,
                    facetool_parameters,
                )
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("runPostprocessing")
        def handle_run_postprocessing(original_image, postprocessing_parameters):
            """Run ESRGAN upscaling or GFPGAN/CodeFormer face restoration on an image.

            ``postprocessing_parameters["type"]`` selects the processor; any other
            type raises TypeError, which is then handled by ``handle_exceptions``.
            Progress is reported via ``progressUpdate`` events. The result is saved
            to the results folder with a thumbnail and announced via
            ``postprocessingResult``.
            """
            try:
                print(
                    f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
                )

                progress = Progress()

                socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                original_image_path = self.get_image_path_from_url(
                    original_image["url"]
                )

                image = Image.open(original_image_path)

                # Gallery entries for images without PNG metadata carry
                # metadata=None, so indexing raises TypeError, not KeyError.
                try:
                    seed = original_image["metadata"]["image"]["seed"]
                except (KeyError, TypeError):
                    seed = "unknown_seed"

                if postprocessing_parameters["type"] == "esrgan":
                    progress.set_current_status("common.statusUpscalingESRGAN")
                elif postprocessing_parameters["type"] == "gfpgan":
                    progress.set_current_status(
                        "common.statusRestoringFacesGFPGAN")
                elif postprocessing_parameters["type"] == "codeformer":
                    progress.set_current_status(
                        "common.statusRestoringFacesCodeFormer")

                socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                if postprocessing_parameters["type"] == "esrgan":
                    image = self.esrgan.process(
                        image=image,
                        upsampler_scale=postprocessing_parameters["upscale"][0],
                        denoise_str=postprocessing_parameters["upscale"][1],
                        strength=postprocessing_parameters["upscale"][2],
                        seed=seed,
                    )
                elif postprocessing_parameters["type"] == "gfpgan":
                    image = self.gfpgan.process(
                        image=image,
                        strength=postprocessing_parameters["facetool_strength"],
                        seed=seed,
                    )
                elif postprocessing_parameters["type"] == "codeformer":
                    image = self.codeformer.process(
                        image=image,
                        strength=postprocessing_parameters["facetool_strength"],
                        fidelity=postprocessing_parameters["codeformer_fidelity"],
                        seed=seed,
                        # CodeFormer is run on CPU when the generator uses MPS
                        device="cpu"
                        if str(self.generate.device) == "mps"
                        else self.generate.device,
                    )
                else:
                    raise TypeError(
                        f'{postprocessing_parameters["type"]} is not a valid postprocessing type'
                    )

                progress.set_current_status("common.statusSavingImage")
                socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                postprocessing_parameters["seed"] = seed
                metadata = self.parameters_to_post_processed_image_metadata(
                    parameters=postprocessing_parameters,
                    original_image_path=original_image_path,
                )

                command = parameters_to_command(postprocessing_parameters)

                (width, height) = image.size

                path = self.save_result_image(
                    image,
                    command,
                    metadata,
                    self.result_path,
                    postprocessing=postprocessing_parameters["type"],
                )

                thumbnail_path = save_thumbnail(
                    image, os.path.basename(path), self.thumbnail_image_path
                )

                self.write_log_message(
                    f'[Postprocessed] "{original_image_path}" > "{path}": {postprocessing_parameters}'
                )

                progress.mark_complete()
                socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                socketio.emit(
                    "postprocessingResult",
                    {
                        "url": self.get_url_from_image_path(path),
                        "thumbnail": self.get_url_from_image_path(thumbnail_path),
                        "mtime": os.path.getmtime(path),
                        "metadata": metadata,
                        "dreamPrompt": command,
                        "width": width,
                        "height": height,
                    },
                )
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("cancel")
        def handle_cancel():
            """Request cancellation of the running job.

            Sets ``self.canceled``; the generation callbacks check this event
            and raise ``CanceledException`` when it is set.
            """
            print(">> Cancel processing requested")
            cancel_flag = self.canceled
            cancel_flag.set()

        # TODO: I think this needs a safety mechanism.
        # NOTE(review): the URLs come from the client; get_image_path_from_url
        # keeps only their basename, but any file in the mapped directories
        # can still be trashed by name.
        @socketio.on("deleteImage")
        def handle_delete_image(url, thumbnail, uuid, category):
            """Move an image and its thumbnail to the trash, then notify clients.

            Both URLs are resolved to files with ``get_image_path_from_url``,
            each file is passed to ``send2trash``, and an ``imageDeleted``
            event echoing ``url``, ``uuid`` and ``category`` is emitted.
            Errors are reported via ``handle_exceptions``.
            """
            try:
                print(f'>> Delete requested "{url}"')
                from send2trash import send2trash

                image_file, thumbnail_file = (
                    self.get_image_path_from_url(target)
                    for target in (url, thumbnail)
                )

                for doomed_file in (image_file, thumbnail_file):
                    send2trash(doomed_file)

                deletion_notice = {"url": url, "uuid": uuid, "category": category}
                socketio.emit("imageDeleted", deletion_notice)
            except Exception as e:
                self.handle_exceptions(e)

    # App Functions
    def get_system_config(self):
        """Describe the running app and model for embedding in image metadata.

        :returns: dict with the model family, ``model_weights`` (name of the
            last model whose status is ``"active"``, or None), the current
            model hash, and the app id/version.
        """
        models: dict = self.generate.model_manager.list_models()
        active_names = [
            name
            for name, details in models.items()
            if details["status"] == "active"
        ]

        return {
            "model": "stable diffusion",
            "model_weights": active_names[-1] if active_names else None,
            "model_hash": self.generate.model_hash,
            "app_id": APP_ID,
            "app_version": APP_VERSION,
        }

    def generate_images(
        self, generation_parameters, esrgan_parameters, facetool_parameters
    ):
        """Run a generation request end to end, streaming progress over socketio.

        Prepares ``generation_parameters`` according to ``generation_mode``:
        decoding the canvas data URLs for ``"unifiedCanvas"``, or opening the
        init image for ``"img2img"``. It then calls
        ``self.generate.prompt2image`` with two callbacks:

        * ``image_progress``: per step, emits progress and, optionally,
          intermediate images.
        * ``image_done``: per finished image, applies optional ESRGAN and
          face restoration, saves the image and thumbnail, and emits
          ``generationResult``.

        :param generation_parameters: dict of generation settings from the
            client. It is mutated in place, both during preparation and in
            ``image_done``.
        :param esrgan_parameters: falsy, or a dict with ``level``,
            ``denoise_str`` and ``strength`` used for upscaling.
        :param facetool_parameters: falsy, or a dict with ``type``
            (``"gfpgan"`` or ``"codeformer"``), ``strength`` and, for
            codeformer, ``codeformer_fidelity``.

        The callbacks raise ``CanceledException`` when ``self.canceled`` is
        set. If that reaches this method, the client is sent
        ``processingCanceled``. Other exceptions are reported through
        ``handle_exceptions``. ``KeyboardInterrupt`` is re-raised after
        notifying the client. The CUDA cache is emptied on every error path.
        """
        try:
            self.canceled.clear()

            # 1-based counter used in intermediate image filenames.
            step_index = 1
            prior_variations = (
                generation_parameters["with_variations"]
                if "with_variations" in generation_parameters
                else []
            )

            # For the unified canvas this is replaced below by the mode that
            # get_canvas_generation_mode infers from the image and mask.
            actual_generation_mode = generation_parameters["generation_mode"]
            original_bounding_box = None

            progress = Progress(generation_parameters=generation_parameters)

            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            # TODO: If a result image is used as an init image, and then
            # deleted, we will want to be able to use it as an init image in
            # the future. Need to handle this case.

            # Prepare for generation based on generation_mode.
            if generation_parameters["generation_mode"] == "unifiedCanvas":
                # generation_parameters["init_img"] and ["init_mask"] arrive as
                # base64 data URLs, so convert each into a PIL Image.
                init_img_url = generation_parameters["init_img"]

                original_bounding_box = generation_parameters["bounding_box"].copy()

                initial_image = dataURL_to_image(
                    generation_parameters["init_img"]
                ).convert("RGBA")

                # The outpaint image and mask are pre-cropped by the UI, so
                # the bounding box passed to the generator is anchored at
                # (0, 0) and keeps the original width and height.
                generation_parameters["bounding_box"]["x"] = 0
                generation_parameters["bounding_box"]["y"] = 0

                # Convert mask dataURL to an image and convert to greyscale
                mask_image = dataURL_to_image(
                    generation_parameters["init_mask"]
                ).convert("L")

                actual_generation_mode = get_canvas_generation_mode(
                    initial_image, mask_image
                )

                # Apply the mask to the init image, creating a "mask" image
                # with transparency where inpainting should occur. This is the
                # kind of mask that prompt2image() needs.
                alpha_mask = initial_image.copy()
                alpha_mask.putalpha(mask_image)

                generation_parameters["init_img"] = initial_image
                generation_parameters["init_mask"] = alpha_mask

                # Remove the parameters the detected mode does not use. The
                # seam/tile/force_outpaint settings are kept only for modes
                # other than the three handled here (e.g. "outpainting").
                seam_keys = (
                    "seam_size",
                    "seam_blur",
                    "seam_strength",
                    "seam_steps",
                    "tile_size",
                    "force_outpaint",
                )
                if actual_generation_mode == "inpainting":
                    keys_to_drop = seam_keys
                elif actual_generation_mode == "img2img":
                    generation_parameters["height"] = original_bounding_box["height"]
                    generation_parameters["width"] = original_bounding_box["width"]
                    keys_to_drop = ("init_mask", *seam_keys, "infill_method")
                elif actual_generation_mode == "txt2img":
                    generation_parameters["height"] = original_bounding_box["height"]
                    generation_parameters["width"] = original_bounding_box["width"]
                    keys_to_drop = (
                        "strength",
                        "fit",
                        "init_img",
                        "init_mask",
                        *seam_keys,
                        "infill_method",
                    )
                else:
                    keys_to_drop = ()
                for key in keys_to_drop:
                    generation_parameters.pop(key, None)

            elif generation_parameters["generation_mode"] == "img2img":
                init_img_url = generation_parameters["init_img"]
                init_img_path = self.get_image_path_from_url(init_img_url)
                generation_parameters["init_img"] = Image.open(
                    init_img_path).convert('RGB')

            def image_progress(sample, step):
                # Per-step callback: update progress and optionally emit
                # intermediate images (saved files and/or latent previews).
                nonlocal step_index

                if self.canceled.is_set():
                    raise CanceledException

                generation_messages = {
                    "txt2img": "common.statusGeneratingTextToImage",
                    "img2img": "common.statusGeneratingImageToImage",
                    "inpainting": "common.statusGeneratingInpainting",
                    "outpainting": "common.statusGeneratingOutpainting",
                }

                progress.set_current_step(step + 1)
                progress.set_current_status(
                    f"{generation_messages[actual_generation_mode]}"
                )
                progress.set_current_status_has_steps(True)

                if (
                    generation_parameters["progress_images"]
                    and step % generation_parameters["save_intermediates"] == 0
                    and step < generation_parameters["steps"] - 1
                ):
                    image = self.generate.sample_to_image(sample)

                    # The metadata/command builders expect init_img to be the
                    # client URL: for img2img,
                    # parameters_to_generated_image_metadata resolves it to a
                    # file path and hashes it. During img2img generation,
                    # generation_parameters["init_img"] holds the opened PIL
                    # image instead, so build from a shallow copy with the URL
                    # put back.
                    intermediate_parameters = dict(generation_parameters)
                    if generation_parameters["generation_mode"] == "img2img":
                        intermediate_parameters["init_img"] = init_img_url

                    metadata = self.parameters_to_generated_image_metadata(
                        intermediate_parameters
                    )
                    command = parameters_to_command(intermediate_parameters)

                    (width, height) = image.size

                    path = self.save_result_image(
                        image,
                        command,
                        metadata,
                        self.intermediate_path,
                        step_index=step_index,
                        postprocessing=False,
                    )

                    step_index += 1
                    self.socketio.emit(
                        "intermediateResult",
                        {
                            "url": self.get_url_from_image_path(path),
                            "mtime": os.path.getmtime(path),
                            "metadata": metadata,
                            "width": width,
                            "height": height,
                            "generationMode": generation_parameters["generation_mode"],
                            "boundingBox": original_bounding_box,
                        },
                    )

                if generation_parameters["progress_latents"]:
                    image = self.generate.sample_to_lowres_estimated_image(
                        sample)
                    (width, height) = image.size
                    # Report the size scaled up by 8 from the low-res
                    # estimate's dimensions.
                    width *= 8
                    height *= 8
                    img_base64 = image_to_dataURL(image)
                    self.socketio.emit(
                        "intermediateResult",
                        {
                            "url": img_base64,
                            "isBase64": True,
                            "mtime": 0,
                            "metadata": {},
                            "width": width,
                            "height": height,
                            "generationMode": generation_parameters["generation_mode"],
                            "boundingBox": original_bounding_box,
                        },
                    )

                self.socketio.emit(
                    "progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

            def image_done(image, seed, first_seed, attention_maps_image=None):
                # Per-image callback: postprocess, save, and emit the result.
                if self.canceled.is_set():
                    raise CanceledException

                # Tidy up after generation based on generation_mode: paste the
                # inpainting image back onto the original.
                # NOTE(review): init_img_path is only bound by the "img2img"
                # preparation branch above, so a request whose generation_mode
                # is literally "inpainting" would raise NameError here. Canvas
                # inpainting arrives as "unifiedCanvas" and skips this branch.
                if generation_parameters["generation_mode"] == "inpainting":
                    image = paste_image_into_bounding_box(
                        Image.open(init_img_path),
                        image,
                        **generation_parameters["bounding_box"],
                    )

                progress.set_current_status("common.statusGenerationComplete")

                self.socketio.emit(
                    "progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                # Alias, not a copy: edits below also update
                # generation_parameters.
                all_parameters = generation_parameters
                postprocessing = False

                if (
                    "variation_amount" in all_parameters
                    and all_parameters["variation_amount"] > 0
                ):
                    first_seed = first_seed or seed
                    this_variation = [
                        [seed, all_parameters["variation_amount"]]]
                    all_parameters["with_variations"] = (
                        prior_variations + this_variation
                    )
                    all_parameters["seed"] = first_seed
                elif "with_variations" in all_parameters:
                    all_parameters["seed"] = first_seed
                else:
                    all_parameters["seed"] = seed

                if self.canceled.is_set():
                    raise CanceledException

                if esrgan_parameters:
                    progress.set_current_status("common.statusUpscaling")
                    progress.set_current_status_has_steps(False)
                    self.socketio.emit(
                        "progressUpdate", progress.to_formatted_dict())
                    eventlet.sleep(0)

                    image = self.esrgan.process(
                        image=image,
                        upsampler_scale=esrgan_parameters["level"],
                        denoise_str=esrgan_parameters['denoise_str'],
                        strength=esrgan_parameters["strength"],
                        seed=seed,
                    )

                    postprocessing = True
                    all_parameters["upscale"] = [
                        esrgan_parameters["level"],
                        esrgan_parameters['denoise_str'],
                        esrgan_parameters["strength"],
                    ]

                if self.canceled.is_set():
                    raise CanceledException

                if facetool_parameters:
                    if facetool_parameters["type"] == "gfpgan":
                        progress.set_current_status(
                            "common.statusRestoringFacesGFPGAN")
                    elif facetool_parameters["type"] == "codeformer":
                        progress.set_current_status(
                            "common.statusRestoringFacesCodeFormer")

                    progress.set_current_status_has_steps(False)
                    self.socketio.emit(
                        "progressUpdate", progress.to_formatted_dict())
                    eventlet.sleep(0)

                    if facetool_parameters["type"] == "gfpgan":
                        image = self.gfpgan.process(
                            image=image,
                            strength=facetool_parameters["strength"],
                            seed=seed,
                        )
                    elif facetool_parameters["type"] == "codeformer":
                        # CodeFormer is run on the CPU when the generator
                        # device is "mps".
                        image = self.codeformer.process(
                            image=image,
                            strength=facetool_parameters["strength"],
                            fidelity=facetool_parameters["codeformer_fidelity"],
                            seed=seed,
                            device="cpu"
                            if str(self.generate.device) == "mps"
                            else self.generate.device,
                        )
                        all_parameters["codeformer_fidelity"] = facetool_parameters[
                            "codeformer_fidelity"
                        ]

                    postprocessing = True
                    all_parameters["facetool_strength"] = facetool_parameters[
                        "strength"
                    ]
                    all_parameters["facetool_type"] = facetool_parameters["type"]

                progress.set_current_status("common.statusSavingImage")
                self.socketio.emit(
                    "progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                # restore the stashed URLS and discard the paths, we are about to send the result to client
                all_parameters["init_img"] = (
                    init_img_url
                    if generation_parameters["generation_mode"] == "img2img"
                    else ""
                )

                if "init_mask" in all_parameters:
                    # TODO: store the mask in metadata
                    all_parameters["init_mask"] = ""

                if generation_parameters["generation_mode"] == "unifiedCanvas":
                    all_parameters["bounding_box"] = original_bounding_box

                metadata = self.parameters_to_generated_image_metadata(
                    all_parameters)

                command = parameters_to_command(all_parameters)

                (width, height) = image.size

                # txt2img/img2img results go to the gallery; other modes
                # (e.g. unifiedCanvas) are saved to the temp directory.
                generated_image_outdir = (
                    self.result_path
                    if generation_parameters["generation_mode"]
                    in ["txt2img", "img2img"]
                    else self.temp_image_path
                )

                path = self.save_result_image(
                    image,
                    command,
                    metadata,
                    generated_image_outdir,
                    postprocessing=postprocessing,
                )

                thumbnail_path = save_thumbnail(
                    image, os.path.basename(path), self.thumbnail_image_path
                )

                print(f'\n\n>> Image generated: "{path}"\n')
                self.write_log_message(f'[Generated] "{path}": {command}')

                if progress.total_iterations > progress.current_iteration:
                    progress.set_current_step(1)
                    progress.set_current_status(
                        "common.statusIterationComplete")
                    progress.set_current_status_has_steps(False)
                else:
                    progress.mark_complete()

                self.socketio.emit(
                    "progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                parsed_prompt, _ = get_prompt_structure(
                    generation_parameters["prompt"])
                # Token lists are not produced for blended prompts.
                tokens = None if type(parsed_prompt) is Blend else \
                    get_tokens_for_prompt(self.generate.model, parsed_prompt)
                attention_maps_image_base64_url = None if attention_maps_image is None \
                    else image_to_dataURL(attention_maps_image)

                self.socketio.emit(
                    "generationResult",
                    {
                        "url": self.get_url_from_image_path(path),
                        "thumbnail": self.get_url_from_image_path(thumbnail_path),
                        "mtime": os.path.getmtime(path),
                        "metadata": metadata,
                        "dreamPrompt": command,
                        "width": width,
                        "height": height,
                        "boundingBox": original_bounding_box,
                        "generationMode": generation_parameters["generation_mode"],
                        "attentionMaps": attention_maps_image_base64_url,
                        "tokens": tokens,
                    },
                )
                eventlet.sleep(0)

                progress.set_current_iteration(progress.current_iteration + 1)

            def diffusers_step_callback_adapter(*cb_args, **kwargs):
                # Accept either a PipelineIntermediateState or the plain
                # (sample, step) callback signature.
                if isinstance(cb_args[0], PipelineIntermediateState):
                    progress_state: PipelineIntermediateState = cb_args[0]
                    return image_progress(progress_state.latents, progress_state.step)
                else:
                    return image_progress(*cb_args, **kwargs)

            self.generate.prompt2image(
                **generation_parameters,
                step_callback=diffusers_step_callback_adapter,
                image_callback=image_done
            )

        except KeyboardInterrupt:
            # Clear the CUDA cache on an exception
            self.empty_cuda_cache()
            self.socketio.emit("processingCanceled")
            raise
        except CanceledException:
            # Clear the CUDA cache on an exception
            self.empty_cuda_cache()
            self.socketio.emit("processingCanceled")
        except Exception as e:
            # Clear the CUDA cache on an exception
            self.empty_cuda_cache()
            print(e)
            self.handle_exceptions(e)

    def empty_cuda_cache(self):
        """Call ``torch.cuda.empty_cache()`` when the generator runs on CUDA.

        Does nothing (and does not import torch) for any other device type.
        """
        if self.generate.device.type != "cuda":
            return

        import torch.cuda

        torch.cuda.empty_cache()

    def parameters_to_generated_image_metadata(self, parameters):
        """Build RFC #266-style metadata for a freshly generated image.

        Starts from ``get_system_config()`` and adds an ``image`` entry. That
        entry keeps only the RFC #266 image fields present in ``parameters``,
        then fills in type, postprocessing, sampler, weighted sub-prompts,
        variations and, for img2img, init-image details.

        :param parameters: generation parameters dict. Always reads
            ``generation_mode``, ``sampler_name`` and ``prompt``. Reads the
            facetool, upscale, variation and img2img keys when present or
            applicable. For img2img, ``init_img`` must be a client URL.
        :returns: the metadata dict, or None if an exception was reported
            via ``handle_exceptions``.
        """
        try:
            # top-level metadata minus `image` or `images`
            metadata = self.get_system_config()

            # Keep only the image keys mentioned in RFC #266.
            rfc266_img_fields = {
                "type",
                "postprocessing",
                "sampler",
                "prompt",
                "seed",
                "variations",
                "steps",
                "cfg_scale",
                "threshold",
                "perlin",
                "step_number",
                "width",
                "height",
                "extra",
                "seamless",
                "hires_fix",
            }
            rfc_dict = {
                key: value
                for key, value in parameters.items()
                if key in rfc266_img_fields
            }

            rfc_dict["type"] = parameters["generation_mode"]

            # 'postprocessing' is either null or a non-empty list of
            # postprocessing step descriptions.
            postprocessing = []

            if "facetool_strength" in parameters:
                facetool_parameters = {
                    "type": str(parameters["facetool_type"]),
                    "strength": float(parameters["facetool_strength"]),
                }

                if parameters["facetool_type"] == "codeformer":
                    facetool_parameters["fidelity"] = float(
                        parameters["codeformer_fidelity"]
                    )

                postprocessing.append(facetool_parameters)

            if "upscale" in parameters:
                # upscale is [scale, denoise_str, strength]. denoise_str is a
                # fractional strength like `strength`, so it must not be
                # truncated with int().
                postprocessing.append(
                    {
                        "type": "esrgan",
                        "scale": int(parameters["upscale"][0]),
                        "denoise_str": float(parameters["upscale"][1]),
                        "strength": float(parameters["upscale"][2]),
                    }
                )

            rfc_dict["postprocessing"] = postprocessing or None

            # semantic drift
            rfc_dict["sampler"] = parameters["sampler_name"]

            # display weighted subprompts (liable to change)
            subprompts = split_weighted_subprompts(
                parameters["prompt"], skip_normalize=True
            )
            rfc_dict["prompt"] = [
                {"prompt": subprompt[0], "weight": subprompt[1]}
                for subprompt in subprompts
            ]

            # 'variations' should always exist and be an array, empty or
            # consisting of {'seed': seed, 'weight': weight} pairs
            rfc_dict["variations"] = [
                {"seed": variation[0], "weight": variation[1]}
                for variation in parameters.get("with_variations", [])
            ]

            if rfc_dict["type"] == "img2img":
                rfc_dict["strength"] = parameters["strength"]
                rfc_dict["fit"] = parameters["fit"]  # TODO: Noncompliant
                rfc_dict["orig_hash"] = calculate_init_img_hash(
                    self.get_image_path_from_url(parameters["init_img"])
                )
                rfc_dict["init_image_path"] = parameters[
                    "init_img"
                ]  # TODO: Noncompliant

            metadata["image"] = rfc_dict

            return metadata

        except Exception as e:
            self.handle_exceptions(e)

    def parameters_to_post_processed_image_metadata(
        self, parameters, original_image_path
    ):
        """Build metadata for an image produced by a standalone postprocessing run.

        Loads the ``sd-metadata`` stored in the original image and appends a
        description of this postprocessing step to
        ``metadata["image"]["postprocessing"]``, creating the list if needed.
        When the original metadata has no ``image`` entry, the step also
        records the original image's path and hash.

        :param parameters: dict with ``type`` (``"esrgan"``, ``"gfpgan"`` or
            ``"codeformer"``) and the matching ``upscale`` /
            ``facetool_strength`` / ``codeformer_fidelity`` values.
        :param original_image_path: path of the image being postprocessed.
        :returns: the updated metadata dict, or None if an exception was
            reported via ``handle_exceptions``. This includes the
            ``TypeError`` raised for an unknown ``type``.
        """
        try:
            current_metadata = retrieve_metadata(
                original_image_path)["sd-metadata"]
            postprocessing_metadata = {}

            # If we don't have original image metadata to reconstruct from,
            # record the original image and its hash instead.
            if "image" not in current_metadata:
                current_metadata["image"] = {}

                orig_hash = calculate_init_img_hash(
                    self.get_image_path_from_url(original_image_path)
                )

                # Store the path itself, not wrapped in a tuple.
                postprocessing_metadata["orig_path"] = original_image_path
                postprocessing_metadata["orig_hash"] = orig_hash

            if parameters["type"] == "esrgan":
                postprocessing_metadata["type"] = "esrgan"
                postprocessing_metadata["scale"] = parameters["upscale"][0]
                postprocessing_metadata["denoise_str"] = parameters["upscale"][1]
                postprocessing_metadata["strength"] = parameters["upscale"][2]
            elif parameters["type"] == "gfpgan":
                postprocessing_metadata["type"] = "gfpgan"
                postprocessing_metadata["strength"] = parameters["facetool_strength"]
            elif parameters["type"] == "codeformer":
                postprocessing_metadata["type"] = "codeformer"
                postprocessing_metadata["strength"] = parameters["facetool_strength"]
                postprocessing_metadata["fidelity"] = parameters["codeformer_fidelity"]
            else:
                raise TypeError(f"Invalid type: {parameters['type']}")

            # Append to an existing postprocessing list, otherwise start one.
            image_metadata = current_metadata["image"]
            if isinstance(image_metadata.get("postprocessing"), list):
                image_metadata["postprocessing"].append(postprocessing_metadata)
            else:
                image_metadata["postprocessing"] = [postprocessing_metadata]

            return current_metadata

        except Exception as e:
            self.handle_exceptions(e)

    def save_result_image(
        self,
        image,
        command,
        metadata,
        output_dir,
        step_index=None,
        postprocessing=False,
    ):
        """Save ``image`` as a PNG, with its prompt and metadata, in ``output_dir``.

        The file is named
        ``<prefix>.<8 uuid hex chars>.<seed>[.<step_index>][.postprocessed].png``.
        ``<prefix>`` comes from ``PngWriter.unique_prefix()``. The seed is
        ``metadata["image"]["seed"]`` when present, otherwise
        ``unknown_seed``. The step suffix is added only when ``step_index``
        is truthy.

        :returns: absolute path of the saved file, or None if an exception was
            reported via ``handle_exceptions``.
        """
        try:
            writer = PngWriter(output_dir)
            prefix = writer.unique_prefix()
            short_id = uuid4().hex[:8]

            image_info = metadata["image"] if "image" in metadata else {}
            seed = image_info["seed"] if "seed" in image_info else "unknown_seed"

            suffixes = []
            if step_index:
                suffixes.append(f".{step_index}")
            if postprocessing:
                suffixes.append(".postprocessed")

            filename = f"{prefix}.{short_id}.{seed}" + "".join(suffixes) + ".png"

            saved_path = writer.save_image_and_prompt_to_png(
                image=image,
                dream_prompt=command,
                metadata=metadata,
                name=filename,
            )
            return os.path.abspath(saved_path)

        except Exception as e:
            self.handle_exceptions(e)

    def make_unique_init_image_filename(self, name):
        """Insert a full 32-char uuid4 hex token before ``name``'s extension.

        ``"foo.png"`` becomes ``"foo.<hex>.png"``. Errors are reported via
        ``handle_exceptions`` and yield None.
        """
        try:
            token = uuid4().hex
            stem, extension = os.path.splitext(name)
            return f"{stem}.{token}{extension}"
        except Exception as e:
            self.handle_exceptions(e)

    def calculate_real_steps(self, steps, strength, has_init_image):
        """Return the number of denoising steps that will actually run.

        With an init image only ``floor(strength * steps)`` steps run;
        otherwise all ``steps`` do (and ``strength`` is not used).
        """
        if not has_init_image:
            return steps
        return math.floor(strength * steps)

    def write_log_message(self, message):
        """Append ``message`` plus a newline to the UTF-8 log at ``self.log_path``.

        Used to log the filename and parameters used to generate or process
        an image. Errors are reported via ``handle_exceptions`` rather than
        raised.
        """
        try:
            with open(self.log_path, "a", encoding="utf-8") as file:
                # write() takes the whole string; writelines() on a str would
                # iterate it character by character.
                file.write(f"{message}\n")

        except Exception as e:
            self.handle_exceptions(e)

    def get_image_path_from_url(self, url):
        """Map a client image URL to the absolute path of that image on disk.

        The directory is chosen by the first marker found in ``url``, checked
        in this order: init-images, mask-images, intermediates, temp-images,
        thumbnails. Anything else resolves into the results directory. Only
        the URL's basename is kept. Errors are reported via
        ``handle_exceptions`` and yield None.
        """
        try:
            directory_attr = "result_path"
            for marker, attr in (
                ("init-images", "init_image_path"),
                ("mask-images", "mask_image_path"),
                ("intermediates", "intermediate_path"),
                ("temp-images", "temp_image_path"),
                ("thumbnails", "thumbnail_image_path"),
            ):
                if marker in url:
                    directory_attr = attr
                    break

            return os.path.abspath(
                os.path.join(getattr(self, directory_attr), os.path.basename(url))
            )
        except Exception as e:
            self.handle_exceptions(e)

    def get_url_from_image_path(self, path):
        """Map an image file path to the URL the client can load it from.

        The URL prefix is chosen by the first directory marker found in
        ``path``, checked in this order: init-images, mask-images,
        intermediates, temp-images, thumbnails. Otherwise the results URL is
        used, and the file's basename is appended to it. URLs are joined with
        ``posixpath`` so they always use ``/``; ``os.path.join`` would insert
        backslashes on Windows. Errors are reported via ``handle_exceptions``
        and yield None.
        """
        import posixpath

        try:
            if "init-images" in path:
                base_url = self.init_image_url
            elif "mask-images" in path:
                base_url = self.mask_image_url
            elif "intermediates" in path:
                base_url = self.intermediate_url
            elif "temp-images" in path:
                base_url = self.temp_image_url
            elif "thumbnails" in path:
                base_url = self.thumbnail_image_url
            else:
                base_url = self.result_url
            # `path` is a filesystem path, so os.path.basename is still right
            # for extracting the filename.
            return posixpath.join(base_url, os.path.basename(path))
        except Exception as e:
            self.handle_exceptions(e)

    def save_file_unique_uuid_name(self, bytes, name, path):
        """Write ``bytes`` into ``path`` under a name made unique with a short uuid.

        ``name`` becomes ``<stem>.<first 8 uuid4 hex chars><ext>``. The target
        directory is created if missing. The file is written inside a
        ``with`` block so the handle is always closed.

        Note: the ``bytes`` parameter shadows the builtin. The name is kept
        for caller compatibility.

        :returns: the written file's path, or None if an exception was
            reported via ``handle_exceptions``.
        """
        try:
            truncated_uuid = uuid4().hex[:8]

            stem, extension = os.path.splitext(name)
            file_path = os.path.join(path, f"{stem}.{truncated_uuid}{extension}")

            os.makedirs(os.path.dirname(file_path), exist_ok=True)

            with open(file_path, "wb") as new_file:
                new_file.write(bytes)

            return file_path
        except Exception as e:
            self.handle_exceptions(e)

    def handle_exceptions(self, exception, emit_key: str = 'error'):
        """Report ``exception`` to the client and print the current traceback.

        Emits ``{"message": str(exception)}`` under ``emit_key`` via
        ``self.socketio``. It then calls ``traceback.print_exc()``, with a
        ``print("\\n")`` before and after.
        """
        payload = {"message": f"{exception}"}
        self.socketio.emit(emit_key, payload)

        spacer = "\n"
        print(spacer)
        traceback.print_exc()
        print(spacer)


class Progress:
    """Mutable progress state for one generation request.

    ``to_formatted_dict`` serializes the state with camelCase keys for the
    client. Status strings look like translation keys (e.g.
    ``"common.statusPreparing"``) — NOTE(review): confirm they are resolved
    client-side.
    """

    def __init__(self, generation_parameters=None):
        """Initialise progress, optionally sized from a generation-parameters dict.

        Args:
            generation_parameters: mapping read for ``"steps"`` and
                ``"iterations"`` (required when given) and optionally
                ``"strength"`` and ``"init_img"``. When falsy (``None`` or
                empty), both the step total and the iteration total default to 1.
        """
        self.current_step = 1
        if generation_parameters:
            self.total_steps = self._calculate_real_steps(
                steps=generation_parameters["steps"],
                strength=generation_parameters.get("strength"),
                has_init_image="init_img" in generation_parameters,
            )
            self.total_iterations = generation_parameters["iterations"]
        else:
            self.total_steps = 1
            self.total_iterations = 1
        self.current_iteration = 1
        self.current_status = "common.statusPreparing"
        self.is_processing = True
        self.current_status_has_steps = False
        self.has_error = False

    # Plain setters, kept for API compatibility with existing callers.
    def set_current_step(self, current_step):
        self.current_step = current_step

    def set_total_steps(self, total_steps):
        self.total_steps = total_steps

    def set_current_iteration(self, current_iteration):
        self.current_iteration = current_iteration

    def set_total_iterations(self, total_iterations):
        self.total_iterations = total_iterations

    def set_current_status(self, current_status):
        self.current_status = current_status

    def set_is_processing(self, is_processing):
        self.is_processing = is_processing

    def set_current_status_has_steps(self, current_status_has_steps):
        self.current_status_has_steps = current_status_has_steps

    def set_has_error(self, has_error):
        self.has_error = has_error

    def mark_complete(self):
        """Switch to the completed status and zero all counters.

        ``current_status_has_steps`` and ``has_error`` are left unchanged.
        """
        self.current_status = "common.statusProcessingComplete"
        self.current_step = 0
        self.total_steps = 0
        self.current_iteration = 0
        self.total_iterations = 0
        self.is_processing = False

    def to_formatted_dict(self):
        """Return the progress state as a camelCase dict for the client."""
        return {
            "currentStep": self.current_step,
            "totalSteps": self.total_steps,
            "currentIteration": self.current_iteration,
            "totalIterations": self.total_iterations,
            "currentStatus": self.current_status,
            "isProcessing": self.is_processing,
            "currentStatusHasSteps": self.current_status_has_steps,
            "hasError": self.has_error,
        }

    def _calculate_real_steps(self, steps, strength, has_init_image):
        """Return the step count to display for this generation.

        With an init image and a known ``strength``, the count is
        ``floor(strength * steps)``. Otherwise it is ``steps``.
        """
        # NOTE(review): when an init image is sent without "strength", the
        # generator may apply its own default. Falling back to the full step
        # count avoids the former TypeError (None * steps) but may overestimate.
        if has_init_image and strength is not None:
            return math.floor(strength * steps)
        return steps


class CanceledException(Exception):
    """Marker exception for an aborted operation; carries no extra state.

    NOTE(review): presumably raised to unwind an in-progress generation when
    the user cancels it — confirm against the raise sites.
    """


"""
Returns a copy of an image, cropped to a bounding box.
"""


def copy_image_from_bounding_box(
    image: ImageType, x: int, y: int, width: int, height: int
) -> ImageType:
    """Return a new image holding the ``width`` x ``height`` region of ``image``
    whose top-left corner is at (``x``, ``y``).

    ``image`` is left untouched. It is deliberately not entered as a context
    manager. Pillow's ``Image.__exit__`` releases the image's underlying file,
    and that side effect does not belong on an image this helper does not own.
    For example, it would break later frame seeks on multi-frame images.
    """
    # Pillow crop boxes are (left, upper, right, lower).
    bounds = (x, y, x + width, y + height)
    return image.crop(bounds)


"""
Converts a base64 image dataURL into an image.
The dataURL is split on the first comma.
"""


def dataURL_to_image(dataURL: str) -> ImageType:
    """Decode a base64 data URL into a PIL image.

    Everything up to and including the first comma (the ``data:...;base64``
    header) is discarded, and the remainder is base64-decoded and opened with
    ``Image.open`` over an in-memory buffer. A URL without a comma raises
    ``IndexError``.
    """
    encoded_payload = dataURL.split(",", 1)[1].encode("utf-8")
    raw_bytes = base64.decodebytes(encoded_payload)
    return Image.open(io.BytesIO(raw_bytes))


"""
Converts an image into a base64 image dataURL.
"""


def image_to_dataURL(image: ImageType) -> str:
    """Encode ``image`` as PNG and return it as a ``data:image/png;base64,...`` URL."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("UTF-8")
    return f"data:image/png;base64,{encoded}"


"""
Converts a base64 image dataURL into bytes.
The dataURL is split on the first comma.
"""


def dataURL_to_bytes(dataURL: str) -> bytes:
    """Decode the base64 payload of a data URL into raw bytes.

    Everything up to and including the first comma is discarded. A URL
    without a comma raises ``IndexError``.
    """
    encoded_payload = dataURL.split(",", 1)[1]
    return base64.decodebytes(encoded_payload.encode("utf-8"))


"""
Pastes an image onto another with a bounding box.
"""


def paste_image_into_bounding_box(
    recipient_image: ImageType,
    donor_image: ImageType,
    x: int,
    y: int,
    width: int,
    height: int,
) -> ImageType:
    """Paste ``donor_image`` into ``recipient_image`` with its top-left at (``x``, ``y``).

    ``recipient_image`` is modified in place and also returned for
    convenience. Because a 4-tuple box is passed, Pillow's ``paste`` contract
    requires ``width``/``height`` to match ``donor_image``'s size.

    The recipient is deliberately not entered as a context manager. Pillow's
    ``Image.__exit__`` releases the image's underlying file, and the caller
    still owns and uses this image.
    """
    # Pillow boxes are (left, upper, right, lower).
    bounds = (x, y, x + width, y + height)
    recipient_image.paste(donor_image, bounds)
    return recipient_image


"""
Saves a thumbnail of an image, returning its path.
"""


def save_thumbnail(
    image: ImageType,
    filename: str,
    path: str,
    size: int = 256,
) -> str:
    """Save a WEBP thumbnail of ``image`` under ``path`` and return its file path.

    The thumbnail is named after ``filename`` with its extension replaced by
    ``.webp``. If that file already exists it is returned as-is and nothing is
    written. Otherwise a copy of ``image`` is shrunk with ``Image.thumbnail`` to
    fit within ``size`` pixels of width, with the height bound scaled by the
    image's aspect ratio. ``image`` itself is not modified.

    Args:
        image: source image.
        filename: name the thumbnail file is derived from.
        path: output directory; created if it does not exist.
        size: maximum thumbnail width in pixels.

    Returns:
        The path of the (possibly pre-existing) thumbnail file.
    """
    base_filename = os.path.splitext(filename)[0]
    thumbnail_path = os.path.join(path, base_filename + ".webp")

    if os.path.exists(thumbnail_path):
        return thumbnail_path

    thumbnail_width = size
    # Clamp to >= 1: very wide images would otherwise round to a height bound
    # of 0, which is not a valid size for Image.thumbnail.
    thumbnail_height = max(1, round(size * (image.height / image.width)))

    image_copy = image.copy()
    image_copy.thumbnail(size=(thumbnail_width, thumbnail_height))

    # save() fails if the target directory is missing; create it on demand
    # (os.makedirs("") would raise, hence the guard).
    directory = os.path.dirname(thumbnail_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    image_copy.save(thumbnail_path, "WEBP")

    return thumbnail_path