# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from argparse import ArgumentParser
import base64
from datetime import datetime, timezone
import glob
import json
import os
from pathlib import Path
from queue import Empty, Queue
import shlex
from threading import Thread
import time
from flask_socketio import SocketIO, join_room, leave_room
from ldm.invoke.args import Args
from ldm.invoke.generator import embiggen
from PIL import Image

from ldm.invoke.pngwriter import PngWriter
from ldm.invoke.server import CanceledException
from ldm.generate import Generate
from server.models import DreamResult, JobRequest, PaginatedItems, ProgressType, Signal

class JobQueueService:
  __queue: Queue = Queue()

  def push(self, dreamRequest: DreamResult):
    self.__queue.put(dreamRequest)

  def get(self, timeout: float = None) -> DreamResult:
    return self.__queue.get(timeout= timeout)

class SignalQueueService:
  __queue: Queue = Queue()

  def push(self, signal: Signal):
    self.__queue.put(signal)

  def get(self) -> Signal:
    return self.__queue.get(block=False)


class SignalService:
  __socketio: SocketIO
  __queue: SignalQueueService

  def __init__(self, socketio: SocketIO, queue: SignalQueueService):
    self.__socketio = socketio
    self.__queue = queue

    def on_join(data):
      room = data['room']
      join_room(room)
      self.__socketio.emit("test", "something", room=room)
      
    def on_leave(data):
      room = data['room']
      leave_room(room)

    self.__socketio.on_event('join_room', on_join)
    self.__socketio.on_event('leave_room', on_leave)

    self.__socketio.start_background_task(self.__process)

  def __process(self):
    # preload the model
    print('Started signal queue processor')
    try:
      while True:
        try:
          signal = self.__queue.get()
          self.__socketio.emit(signal.event, signal.data, room=signal.room, broadcast=signal.broadcast)
        except Empty:
          pass
        finally:
          self.__socketio.sleep(0.001)

    except KeyboardInterrupt:
        print('Signal queue processor stopped')


  def emit(self, signal: Signal):
    self.__queue.push(signal)


# TODO: Name this better?
# TODO: Logging and signals should probably be event based (multiple listeners for an event)
class LogService:
  __location: str
  __logFile: str

  def __init__(self, location:str, file:str):
    self.__location = location
    self.__logFile = file

  def log(self, dreamResult: DreamResult, seed = None, upscaled = False):
    with open(os.path.join(self.__location, self.__logFile), "a") as log:
      log.write(f"{dreamResult.id}: {dreamResult.to_json()}\n")


class ImageStorageService:
  __location: str
  __pngWriter: PngWriter
  __legacyParser: ArgumentParser

  def __init__(self, location):
    self.__location = location
    self.__pngWriter = PngWriter(self.__location)
    self.__legacyParser = Args() # TODO: inject this?

  def __getName(self, dreamId: str, postfix: str = '') -> str:
    return f'{dreamId}{postfix}.png'

  def save(self, image, dreamResult: DreamResult, postfix: str = '') -> str:
    name = self.__getName(dreamResult.id, postfix)
    meta = dreamResult.to_json() # TODO: make all methods consistent with writing metadata. Standardize metadata.
    path = self.__pngWriter.save_image_and_prompt_to_png(image, dream_prompt=meta, metadata=None, name=name)
    return path

  def path(self, dreamId: str, postfix: str = '') -> str:
    name = self.__getName(dreamId, postfix)
    path = os.path.join(self.__location, name)
    return path
  
  # Returns true if found, false if not found or error
  def delete(self, dreamId: str, postfix: str = '') -> bool:
    path = self.path(dreamId, postfix)
    if (os.path.exists(path)):
      os.remove(path)
      return True
    else:
      return False
  
  def getMetadata(self, dreamId: str, postfix: str = '') -> DreamResult:
    path = self.path(dreamId, postfix)
    image = Image.open(path)
    text = image.text
    if text.__contains__('Dream'):
      dreamMeta = text.get('Dream')
      try:
        j = json.loads(dreamMeta)
        return DreamResult.from_json(j)
      except ValueError:
        # Try to parse command-line format (legacy metadata format)
        try:
          opt = self.__parseLegacyMetadata(dreamMeta)
          optd = opt.__dict__
          if (not 'width' in optd) or (optd.get('width') is None):
            optd['width'] = image.width
          if (not 'height' in optd) or (optd.get('height') is None):
            optd['height'] = image.height
          if (not 'steps' in optd) or (optd.get('steps') is None):
            optd['steps'] = 10 # No way around this unfortunately - seems like it wasn't storing this previously

          optd['time'] = os.path.getmtime(path) # Set timestamp manually (won't be exactly correct though)

          return DreamResult.from_json(optd)

        except:
          return None
    else:
      return None

  def __parseLegacyMetadata(self, command: str) -> DreamResult:
    # before splitting, escape single quotes so as not to mess
    # up the parser
    command = command.replace("'", "\\'")

    try:
        elements = shlex.split(command)
    except ValueError as e:
        return None

    # rearrange the arguments to mimic how it works in the Dream bot.
    switches = ['']
    switches_started = False

    for el in elements:
        if el[0] == '-' and not switches_started:
            switches_started = True
        if switches_started:
            switches.append(el)
        else:
            switches[0] += el
            switches[0] += ' '
    switches[0] = switches[0][: len(switches[0]) - 1]

    try:
        opt = self.__legacyParser.parse_cmd(switches)
        return opt
    except SystemExit:
        return None

  def list_files(self, page: int, perPage: int) -> PaginatedItems:
    files = sorted(glob.glob(os.path.join(self.__location,'*.png')), key=os.path.getmtime, reverse=True)
    count = len(files)

    startId = page * perPage
    pageCount = int(count / perPage) + 1
    endId = min(startId + perPage, count)
    items = [] if startId >= count else files[startId:endId]

    items = list(map(lambda f: Path(f).stem, items))

    return PaginatedItems(items, page, pageCount, perPage, count)


class GeneratorService:
  __model: Generate
  __queue: JobQueueService
  __imageStorage: ImageStorageService
  __intermediateStorage: ImageStorageService
  __log: LogService
  __thread: Thread
  __cancellationRequested: bool = False
  __signal_service: SignalService

  def __init__(self, model: Generate, queue: JobQueueService, imageStorage: ImageStorageService, intermediateStorage: ImageStorageService, log: LogService, signal_service: SignalService):
    self.__model = model
    self.__queue = queue
    self.__imageStorage = imageStorage
    self.__intermediateStorage = intermediateStorage
    self.__log = log
    self.__signal_service = signal_service

    # Create the background thread
    self.__thread = Thread(target=self.__process, name = "GeneratorService")
    self.__thread.daemon = True
    self.__thread.start()


  # Request cancellation of the current job
  def cancel(self):
    self.__cancellationRequested = True


  # TODO: Consider moving this to its own service if there's benefit in separating the generator
  def __process(self):
    # preload the model
    # TODO: support multiple models
    print('Preloading model')
    tic = time.time()
    self.__model.load_model()
    print(f'>> model loaded in', '%4.2fs' % (time.time() - tic))

    print('Started generation queue processor')
    try:
      while True:
        dreamRequest = self.__queue.get()
        self.__generate(dreamRequest)

    except KeyboardInterrupt:
        print('Generation queue processor stopped')


  def __on_start(self, jobRequest: JobRequest):
    self.__signal_service.emit(Signal.job_started(jobRequest.id))


  def __on_image_result(self, jobRequest: JobRequest, image, seed, upscaled=False):
    dreamResult = jobRequest.newDreamResult()
    dreamResult.seed = seed
    dreamResult.has_upscaled = upscaled
    dreamResult.iterations = 1
    jobRequest.results.append(dreamResult)
    # TODO: Separate status of GFPGAN?

    self.__imageStorage.save(image, dreamResult)
    
    # TODO: handle upscaling logic better (this is appending data to log, but only on first generation)
    if not upscaled:
      self.__log.log(dreamResult)

    # Send result signal
    self.__signal_service.emit(Signal.image_result(jobRequest.id, dreamResult.id, dreamResult))

    upscaling_requested = dreamResult.enable_upscale or dreamResult.enable_gfpgan
    
    # Report upscaling status
    # TODO: this is very coupled to logic inside the generator. Fix that.
    if upscaling_requested and any(result.has_upscaled for result in jobRequest.results):
      progressType = ProgressType.UPSCALING_STARTED if len(jobRequest.results) < 2 * jobRequest.iterations else ProgressType.UPSCALING_DONE
      upscale_count = sum(1 for i in jobRequest.results if i.has_upscaled)
      self.__signal_service.emit(Signal.image_progress(jobRequest.id, dreamResult.id, upscale_count, jobRequest.iterations, progressType))


  def __on_progress(self, jobRequest: JobRequest, sample, step):
    if self.__cancellationRequested:
      self.__cancellationRequested = False
      raise CanceledException

    # TODO: Progress per request will be easier once the seeds (and ids) can all be pre-generated
    hasProgressImage = False
    s = str(len(jobRequest.results))
    if jobRequest.progress_images and step % 5 == 0 and step < jobRequest.steps - 1:
      image = self.__model._sample_to_image(sample)

      # TODO: clean this up, use a pre-defined dream result
      result = DreamResult()
      result.parse_json(jobRequest.__dict__, new_instance=False)
      self.__intermediateStorage.save(image, result, postfix=f'.{s}.{step}')
      hasProgressImage = True

    self.__signal_service.emit(Signal.image_progress(jobRequest.id, f'{jobRequest.id}.{s}', step, jobRequest.steps, ProgressType.GENERATION, hasProgressImage))


  def __generate(self, jobRequest: JobRequest):
    try:
      # TODO: handle this file a file service for init images
      initimgfile = None # TODO: support this on the model directly?
      if (jobRequest.enable_init_image):
        if jobRequest.initimg is not None:
          with open("./img2img-tmp.png", "wb") as f:
            initimg = jobRequest.initimg.split(",")[1] # Ignore mime type
            f.write(base64.b64decode(initimg))
            initimgfile = "./img2img-tmp.png"

      # Use previous seed if set to -1
      initSeed = jobRequest.seed
      if initSeed == -1:
        initSeed = self.__model.seed

      # Zero gfpgan strength if the model doesn't exist
      # TODO: determine if this could be at the top now? Used to cause circular import
      from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
      if not gfpgan_model_exists:
        jobRequest.enable_gfpgan = False

      # Signal start
      self.__on_start(jobRequest)

      # Generate in model
      # TODO: Split job generation requests instead of fitting all parameters here
      # TODO: Support no generation (just upscaling/gfpgan)

      upscale = None if not jobRequest.enable_upscale else jobRequest.upscale
      facetool_strength = 0 if not jobRequest.enable_gfpgan else jobRequest.facetool_strength

      if not jobRequest.enable_generate:
        # If not generating, check if we're upscaling or running gfpgan
        if not upscale and not facetool_strength:
          # Invalid settings (TODO: Add message to help user)
          raise CanceledException()

        image = Image.open(initimgfile)
        # TODO: support progress for upscale?
        self.__model.upscale_and_reconstruct(
          image_list = [[image,0]],
          upscale = upscale,
          strength = facetool_strength,
          save_original = False,
          image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))

      else:
        # Generating - run the generation
        init_img = None if (not jobRequest.enable_img2img or jobRequest.strength == 0) else initimgfile


        self.__model.prompt2image(
          prompt           = jobRequest.prompt,
          init_img         = init_img, # TODO: ensure this works
          strength         = None if init_img is None else jobRequest.strength,
          fit              = None if init_img is None else jobRequest.fit,
          iterations       = jobRequest.iterations,
          cfg_scale        = jobRequest.cfg_scale,
          threshold        = jobRequest.threshold,
          perlin           = jobRequest.perlin,
          width            = jobRequest.width,
          height           = jobRequest.height,
          seed             = jobRequest.seed,
          steps            = jobRequest.steps,
          variation_amount = jobRequest.variation_amount,
          with_variations  = jobRequest.with_variations,
          facetool_strength = facetool_strength,
          upscale          = upscale,
          sampler_name     = jobRequest.sampler_name,
          seamless         = jobRequest.seamless,
          hires_fix        = jobRequest.hires_fix,
          embiggen         = jobRequest.embiggen,
          embiggen_tiles   = jobRequest.embiggen_tiles,
          step_callback    = lambda sample, step: self.__on_progress(jobRequest, sample, step),
          image_callback   = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))

    except CanceledException:
      self.__signal_service.emit(Signal.job_canceled(jobRequest.id))

    finally:
      self.__signal_service.emit(Signal.job_done(jobRequest.id))

      # Remove the temp file
      if (initimgfile is not None):
        os.remove("./img2img-tmp.png")