diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py index 42c4fc9d99..adc423532e 100644 --- a/backend/invoke_ai_web_server.py +++ b/backend/invoke_ai_web_server.py @@ -174,9 +174,12 @@ class InvokeAIWebServer: ) @socketio.on("requestLatestImages") - def handle_request_latest_images(latest_mtime): + def handle_request_latest_images(category, latest_mtime): try: - paths = glob.glob(os.path.join(self.result_path, "*.png")) + base_path = ( + self.result_path if category == "result" else self.init_image_path + ) + paths = glob.glob(os.path.join(base_path, "*.png")) image_paths = sorted( paths, key=lambda x: os.path.getmtime(x), reverse=True @@ -201,14 +204,13 @@ class InvokeAIWebServer: "metadata": metadata["sd-metadata"], "width": width, "height": height, + "category": category, } ) socketio.emit( "galleryImages", - { - "images": image_array, - }, + {"images": image_array, "category": category}, ) except Exception as e: self.socketio.emit("error", {"message": (str(e))}) @@ -218,11 +220,15 @@ class InvokeAIWebServer: print("\n") @socketio.on("requestImages") - def handle_request_images(earliest_mtime=None): + def handle_request_images(category, earliest_mtime=None): try: page_size = 50 - paths = glob.glob(os.path.join(self.result_path, "*.png")) + base_path = ( + self.result_path if category == "result" else self.init_image_path + ) + + paths = glob.glob(os.path.join(base_path, "*.png")) image_paths = sorted( paths, key=lambda x: os.path.getmtime(x), reverse=True @@ -253,6 +259,7 @@ class InvokeAIWebServer: "metadata": metadata["sd-metadata"], "width": width, "height": height, + "category": category, } ) @@ -261,6 +268,7 @@ class InvokeAIWebServer: { "images": image_array, "areMoreImagesAvailable": areMoreImagesAvailable, + "category": category, }, ) except Exception as e: @@ -416,14 +424,17 @@ class InvokeAIWebServer: # TODO: I think this needs a safety mechanism. @socketio.on("deleteImage") - def handle_delete_image(url, uuid): + def handle_delete_image(url, uuid, category): try: print(f'>> Delete requested "{url}"') from send2trash import send2trash path = self.get_image_path_from_url(url) + print(path) send2trash(path) - socketio.emit("imageDeleted", {"url": url, "uuid": uuid}) + socketio.emit( + "imageDeleted", {"url": url, "uuid": uuid, "category": category} + ) except Exception as e: self.socketio.emit("error", {"message": (str(e))}) print("\n") @@ -439,11 +450,17 @@ class InvokeAIWebServer: file_path = self.save_file_unique_uuid_name( bytes=bytes, name=name, path=self.init_image_path ) - + mtime = os.path.getmtime(file_path) + (width, height) = Image.open(file_path).size + print(file_path) socketio.emit( "initialImageUploaded", { "url": self.get_url_from_image_path(file_path), + "mtime": mtime, + "width": width, + "height": height, + "category": "user", }, ) except Exception as e: diff --git a/backend/server.py b/backend/server.py deleted file mode 100644 index 8ad861356c..0000000000 --- a/backend/server.py +++ /dev/null @@ -1,822 +0,0 @@ -import mimetypes -import transformers -import json -import os -import traceback -import eventlet -import glob -import shlex -import math -import shutil -import sys - -sys.path.append(".") - -from argparse import ArgumentTypeError -from modules.create_cmd_parser import create_cmd_parser - -parser = create_cmd_parser() -opt = parser.parse_args() - - -from flask_socketio import SocketIO -from flask import Flask, send_from_directory, url_for, jsonify -from pathlib import Path -from PIL import Image -from pytorch_lightning import logging -from threading import Event -from uuid import uuid4 -from send2trash import send2trash - - -from ldm.generate import Generate -from ldm.invoke.restoration import Restoration -from ldm.invoke.pngwriter import PngWriter, retrieve_metadata -from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash -from ldm.invoke.prompt_parser import split_weighted_subprompts - -from modules.parameters import parameters_to_command - - -""" -USER CONFIG -""" -if opt.cors and "*" in opt.cors: - raise ArgumentTypeError('"*" is not an allowed CORS origin') - - -output_dir = "outputs/" # Base output directory for images -host = opt.host # Web & socket.io host -port = opt.port # Web & socket.io port -verbose = opt.verbose # enables copious socket.io logging -precision = opt.precision -free_gpu_mem = opt.free_gpu_mem -embedding_path = opt.embedding_path -additional_allowed_origins = ( - opt.cors if opt.cors else [] -) # additional CORS allowed origins -model = "stable-diffusion-1.4" - -""" -END USER CONFIG -""" - - -print("* Initializing, be patient...\n") - - -""" -SERVER SETUP -""" - - -# fix missing mimetypes on windows due to registry wonkiness -mimetypes.add_type("application/javascript", ".js") -mimetypes.add_type("text/css", ".css") - -app = Flask(__name__, static_url_path="", static_folder="../frontend/dist/") - - -app.config["OUTPUTS_FOLDER"] = "../outputs" - - -@app.route("/outputs/") -def outputs(filename): - return send_from_directory(app.config["OUTPUTS_FOLDER"], filename) - - -@app.route("/", defaults={"path": ""}) -def serve(path): - return send_from_directory(app.static_folder, "index.html") - - -logger = True if verbose else False -engineio_logger = True if verbose else False - -# default 1,000,000, needs to be higher for socketio to accept larger images -max_http_buffer_size = 10000000 - -cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins - -socketio = SocketIO( - app, - logger=logger, - engineio_logger=engineio_logger, - max_http_buffer_size=max_http_buffer_size, - cors_allowed_origins=cors_allowed_origins, - ping_interval=(50, 50), - ping_timeout=60, -) - - -""" -END SERVER SETUP -""" - - -""" -APP SETUP -""" - - -class CanceledException(Exception): - pass - - -try: - gfpgan, codeformer, esrgan = None, None, None - from ldm.invoke.restoration.base import Restoration - - restoration = Restoration() - gfpgan, codeformer = restoration.load_face_restore_models() - esrgan = restoration.load_esrgan() - - # coreformer.process(self, image, strength, device, seed=None, fidelity=0.75) - -except (ModuleNotFoundError, ImportError): - print(traceback.format_exc(), file=sys.stderr) - print(">> You may need to install the ESRGAN and/or GFPGAN modules") - -canceled = Event() - -# reduce logging outputs to error -transformers.logging.set_verbosity_error() -logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) - -# Initialize and load model -generate = Generate( - model, - precision=precision, - embedding_path=embedding_path, -) -generate.free_gpu_mem = free_gpu_mem -generate.load_model() - - -# location for "finished" images -result_path = os.path.join(output_dir, "img-samples/") - -# temporary path for intermediates -intermediate_path = os.path.join(result_path, "intermediates/") - -# path for user-uploaded init images and masks -init_image_path = os.path.join(result_path, "init-images/") -mask_image_path = os.path.join(result_path, "mask-images/") - -# txt log -log_path = os.path.join(result_path, "invoke_log.txt") - -# make all output paths -[ - os.makedirs(path, exist_ok=True) - for path in [result_path, intermediate_path, init_image_path, mask_image_path] -] - - -""" -END APP SETUP -""" - - -""" -SOCKET.IO LISTENERS -""" - - -@socketio.on("requestSystemConfig") -def handle_request_capabilities(): - print(f">> System config requested") - config = get_system_config() - socketio.emit("systemConfig", config) - - -@socketio.on("requestImages") -def handle_request_images(page=1, offset=0, last_mtime=None): - chunk_size = 50 - - if last_mtime: - print(f">> Latest images requested") - else: - print( - f">> Page {page} of images requested (page size {chunk_size} offset {offset})" - ) - - paths = glob.glob(os.path.join(result_path, "*.png")) - sorted_paths = sorted(paths, key=lambda x: os.path.getmtime(x), reverse=True) - - if last_mtime: - image_paths = filter(lambda x: os.path.getmtime(x) > last_mtime, sorted_paths) - else: - - image_paths = sorted_paths[ - slice(chunk_size * (page - 1) + offset, chunk_size * page + offset) - ] - page = page + 1 - - image_array = [] - - for path in image_paths: - metadata = retrieve_metadata(path) - image_array.append( - { - "url": path, - "mtime": os.path.getmtime(path), - "metadata": metadata["sd-metadata"], - } - ) - - socketio.emit( - "galleryImages", - { - "images": image_array, - "nextPage": page, - "offset": offset, - "onlyNewImages": True if last_mtime else False, - }, - ) - - -@socketio.on("generateImage") -def handle_generate_image_event( - generation_parameters, esrgan_parameters, gfpgan_parameters -): - print( - f">> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}" - ) - generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters) - - -@socketio.on("runESRGAN") -def handle_run_esrgan_event(original_image, esrgan_parameters): - print( - f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}' - ) - progress = { - "currentStep": 1, - "totalSteps": 1, - "currentIteration": 1, - "totalIterations": 1, - "currentStatus": "Preparing", - "isProcessing": True, - "currentStatusHasSteps": False, - } - - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - image = Image.open(original_image["url"]) - - seed = ( - original_image["metadata"]["seed"] - if "seed" in original_image["metadata"] - else "unknown_seed" - ) - - progress["currentStatus"] = "Upscaling" - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - image = esrgan.process( - image=image, - upsampler_scale=esrgan_parameters["upscale"][0], - strength=esrgan_parameters["upscale"][1], - seed=seed, - ) - - progress["currentStatus"] = "Saving image" - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - esrgan_parameters["seed"] = seed - metadata = parameters_to_post_processed_image_metadata( - parameters=esrgan_parameters, - original_image_path=original_image["url"], - type="esrgan", - ) - command = parameters_to_command(esrgan_parameters) - - path = save_image(image, command, metadata, result_path, postprocessing="esrgan") - - write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}') - - progress["currentStatus"] = "Finished" - progress["currentStep"] = 0 - progress["totalSteps"] = 0 - progress["currentIteration"] = 0 - progress["totalIterations"] = 0 - progress["isProcessing"] = False - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - socketio.emit( - "esrganResult", - { - "url": os.path.relpath(path), - "mtime": os.path.getmtime(path), - "metadata": metadata, - }, - ) - - -@socketio.on("runGFPGAN") -def handle_run_gfpgan_event(original_image, gfpgan_parameters): - print( - f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}' - ) - progress = { - "currentStep": 1, - "totalSteps": 1, - "currentIteration": 1, - "totalIterations": 1, - "currentStatus": "Preparing", - "isProcessing": True, - "currentStatusHasSteps": False, - } - - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - image = Image.open(original_image["url"]) - - seed = ( - original_image["metadata"]["seed"] - if "seed" in original_image["metadata"] - else "unknown_seed" - ) - - progress["currentStatus"] = "Fixing faces" - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - image = gfpgan.process( - image=image, strength=gfpgan_parameters["facetool_strength"], seed=seed - ) - - progress["currentStatus"] = "Saving image" - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - gfpgan_parameters["seed"] = seed - metadata = parameters_to_post_processed_image_metadata( - parameters=gfpgan_parameters, - original_image_path=original_image["url"], - type="gfpgan", - ) - command = parameters_to_command(gfpgan_parameters) - - path = save_image(image, command, metadata, result_path, postprocessing="gfpgan") - - write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}') - - progress["currentStatus"] = "Finished" - progress["currentStep"] = 0 - progress["totalSteps"] = 0 - progress["currentIteration"] = 0 - progress["totalIterations"] = 0 - progress["isProcessing"] = False - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - socketio.emit( - "gfpganResult", - { - "url": os.path.relpath(path), - "mtime": os.path.mtime(path), - "metadata": metadata, - }, - ) - - -@socketio.on("cancel") -def handle_cancel(): - print(f">> Cancel processing requested") - canceled.set() - socketio.emit("processingCanceled") - - -# TODO: I think this needs a safety mechanism. -@socketio.on("deleteImage") -def handle_delete_image(path, uuid): - print(f'>> Delete requested "{path}"') - send2trash(path) - socketio.emit("imageDeleted", {"url": path, "uuid": uuid}) - - -# TODO: I think this needs a safety mechanism. -@socketio.on("uploadInitialImage") -def handle_upload_initial_image(bytes, name): - print(f'>> Init image upload requested "{name}"') - uuid = uuid4().hex - split = os.path.splitext(name) - name = f"{split[0]}.{uuid}{split[1]}" - file_path = os.path.join(init_image_path, name) - os.makedirs(os.path.dirname(file_path), exist_ok=True) - newFile = open(file_path, "wb") - newFile.write(bytes) - socketio.emit("initialImageUploaded", {"url": file_path, "uuid": ""}) - - -# TODO: I think this needs a safety mechanism. -@socketio.on("uploadMaskImage") -def handle_upload_mask_image(bytes, name): - print(f'>> Mask image upload requested "{name}"') - uuid = uuid4().hex - split = os.path.splitext(name) - name = f"{split[0]}.{uuid}{split[1]}" - file_path = os.path.join(mask_image_path, name) - os.makedirs(os.path.dirname(file_path), exist_ok=True) - newFile = open(file_path, "wb") - newFile.write(bytes) - socketio.emit("maskImageUploaded", {"url": file_path, "uuid": ""}) - - -""" -END SOCKET.IO LISTENERS -""" - - -""" -ADDITIONAL FUNCTIONS -""" - - -def get_system_config(): - return { - "model": "stable diffusion", - "model_id": model, - "model_hash": generate.model_hash, - "app_id": APP_ID, - "app_version": APP_VERSION, - } - - -def parameters_to_post_processed_image_metadata(parameters, original_image_path, type): - # top-level metadata minus `image` or `images` - metadata = get_system_config() - - orig_hash = calculate_init_img_hash(original_image_path) - - image = {"orig_path": original_image_path, "orig_hash": orig_hash} - - if type == "esrgan": - image["type"] = "esrgan" - image["scale"] = parameters["upscale"][0] - image["strength"] = parameters["upscale"][1] - elif type == "gfpgan": - image["type"] = "gfpgan" - image["strength"] = parameters["facetool_strength"] - else: - raise TypeError(f"Invalid type: {type}") - - metadata["image"] = image - return metadata - - -def parameters_to_generated_image_metadata(parameters): - # top-level metadata minus `image` or `images` - - metadata = get_system_config() - # remove any image keys not mentioned in RFC #266 - rfc266_img_fields = [ - "type", - "postprocessing", - "sampler", - "prompt", - "seed", - "variations", - "steps", - "cfg_scale", - "threshold", - "perlin", - "step_number", - "width", - "height", - "extra", - "seamless", - "hires_fix", - ] - - rfc_dict = {} - - for item in parameters.items(): - key, value = item - if key in rfc266_img_fields: - rfc_dict[key] = value - - postprocessing = [] - - # 'postprocessing' is either null or an - if "facetool_strength" in parameters: - - postprocessing.append( - {"type": "gfpgan", "strength": float(parameters["facetool_strength"])} - ) - - if "upscale" in parameters: - postprocessing.append( - { - "type": "esrgan", - "scale": int(parameters["upscale"][0]), - "strength": float(parameters["upscale"][1]), - } - ) - - rfc_dict["postprocessing"] = postprocessing if len(postprocessing) > 0 else None - - # semantic drift - rfc_dict["sampler"] = parameters["sampler_name"] - - # display weighted subprompts (liable to change) - subprompts = split_weighted_subprompts(parameters["prompt"]) - subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts] - rfc_dict["prompt"] = subprompts - - # 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs - variations = [] - - if "with_variations" in parameters: - variations = [ - {"seed": x[0], "weight": x[1]} for x in parameters["with_variations"] - ] - - rfc_dict["variations"] = variations - - if "init_img" in parameters: - rfc_dict["type"] = "img2img" - rfc_dict["strength"] = parameters["strength"] - rfc_dict["fit"] = parameters["fit"] # TODO: Noncompliant - rfc_dict["orig_hash"] = calculate_init_img_hash(parameters["init_img"]) - rfc_dict["init_image_path"] = parameters["init_img"] # TODO: Noncompliant - rfc_dict["sampler"] = "ddim" # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS - if "init_mask" in parameters: - rfc_dict["mask_hash"] = calculate_init_img_hash( - parameters["init_mask"] - ) # TODO: Noncompliant - rfc_dict["mask_image_path"] = parameters["init_mask"] # TODO: Noncompliant - else: - rfc_dict["type"] = "txt2img" - - metadata["image"] = rfc_dict - - return metadata - - -def make_unique_init_image_filename(name): - uuid = uuid4().hex - split = os.path.splitext(name) - name = f"{split[0]}.{uuid}{split[1]}" - return name - - -def write_log_message(message, log_path=log_path): - """Logs the filename and parameters used to generate or process that image to log file""" - message = f"{message}\n" - with open(log_path, "a", encoding="utf-8") as file: - file.writelines(message) - - -def save_image( - image, command, metadata, output_dir, step_index=None, postprocessing=False -): - pngwriter = PngWriter(output_dir) - prefix = pngwriter.unique_prefix() - - seed = "unknown_seed" - - if "image" in metadata: - if "seed" in metadata["image"]: - seed = metadata["image"]["seed"] - - filename = f"{prefix}.{seed}" - - if step_index: - filename += f".{step_index}" - if postprocessing: - filename += f".postprocessed" - - filename += ".png" - - path = pngwriter.save_image_and_prompt_to_png( - image=image, dream_prompt=command, metadata=metadata, name=filename - ) - - return path - - -def calculate_real_steps(steps, strength, has_init_image): - return math.floor(strength * steps) if has_init_image else steps - - -def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters): - canceled.clear() - - step_index = 1 - prior_variations = ( - generation_parameters["with_variations"] - if "with_variations" in generation_parameters - else [] - ) - """ - If a result image is used as an init image, and then deleted, we will want to be - able to use it as an init image in the future. Need to copy it. - - If the init/mask image doesn't exist in the init_image_path/mask_image_path, - make a unique filename for it and copy it there. - """ - if "init_img" in generation_parameters: - filename = os.path.basename(generation_parameters["init_img"]) - if not os.path.exists(os.path.join(init_image_path, filename)): - unique_filename = make_unique_init_image_filename(filename) - new_path = os.path.join(init_image_path, unique_filename) - shutil.copy(generation_parameters["init_img"], new_path) - generation_parameters["init_img"] = new_path - if "init_mask" in generation_parameters: - filename = os.path.basename(generation_parameters["init_mask"]) - if not os.path.exists(os.path.join(mask_image_path, filename)): - unique_filename = make_unique_init_image_filename(filename) - new_path = os.path.join(init_image_path, unique_filename) - shutil.copy(generation_parameters["init_img"], new_path) - generation_parameters["init_mask"] = new_path - - totalSteps = calculate_real_steps( - steps=generation_parameters["steps"], - strength=generation_parameters["strength"] - if "strength" in generation_parameters - else None, - has_init_image="init_img" in generation_parameters, - ) - - progress = { - "currentStep": 1, - "totalSteps": totalSteps, - "currentIteration": 1, - "totalIterations": generation_parameters["iterations"], - "currentStatus": "Preparing", - "isProcessing": True, - "currentStatusHasSteps": False, - } - - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - def image_progress(sample, step): - if canceled.is_set(): - raise CanceledException - - nonlocal step_index - nonlocal generation_parameters - nonlocal progress - - progress["currentStep"] = step + 1 - progress["currentStatus"] = "Generating" - progress["currentStatusHasSteps"] = True - - if ( - generation_parameters["progress_images"] - and step % 5 == 0 - and step < generation_parameters["steps"] - 1 - ): - image = generate.sample_to_image(sample) - - metadata = parameters_to_generated_image_metadata(generation_parameters) - command = parameters_to_command(generation_parameters) - path = save_image(image, command, metadata, intermediate_path, step_index=step_index, postprocessing=False) - - step_index += 1 - socketio.emit( - "intermediateResult", - { - "url": os.path.relpath(path), - "mtime": os.path.getmtime(path), - "metadata": metadata, - }, - ) - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - def image_done(image, seed, first_seed): - nonlocal generation_parameters - nonlocal esrgan_parameters - nonlocal gfpgan_parameters - nonlocal progress - - step_index = 1 - nonlocal prior_variations - - progress["currentStatus"] = "Generation complete" - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - all_parameters = generation_parameters - postprocessing = False - - if ( - "variation_amount" in all_parameters - and all_parameters["variation_amount"] > 0 - ): - first_seed = first_seed or seed - this_variation = [[seed, all_parameters["variation_amount"]]] - all_parameters["with_variations"] = prior_variations + this_variation - all_parameters["seed"] = first_seed - elif ("with_variations" in all_parameters): - all_parameters["seed"] = first_seed - else: - all_parameters["seed"] = seed - - if esrgan_parameters: - progress["currentStatus"] = "Upscaling" - progress["currentStatusHasSteps"] = False - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - image = esrgan.process( - image=image, - upsampler_scale=esrgan_parameters["level"], - strength=esrgan_parameters["strength"], - seed=seed, - ) - - postprocessing = True - all_parameters["upscale"] = [ - esrgan_parameters["level"], - esrgan_parameters["strength"], - ] - - if gfpgan_parameters: - progress["currentStatus"] = "Fixing faces" - progress["currentStatusHasSteps"] = False - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - image = gfpgan.process( - image=image, strength=gfpgan_parameters["strength"], seed=seed - ) - postprocessing = True - all_parameters["facetool_strength"] = gfpgan_parameters["strength"] - - progress["currentStatus"] = "Saving image" - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - metadata = parameters_to_generated_image_metadata(all_parameters) - command = parameters_to_command(all_parameters) - - path = save_image( - image, command, metadata, result_path, postprocessing=postprocessing - ) - - print(f'>> Image generated: "{path}"') - write_log_message(f'[Generated] "{path}": {command}') - - if progress["totalIterations"] > progress["currentIteration"]: - progress["currentStep"] = 1 - progress["currentIteration"] += 1 - progress["currentStatus"] = "Iteration finished" - progress["currentStatusHasSteps"] = False - else: - progress["currentStep"] = 0 - progress["totalSteps"] = 0 - progress["currentIteration"] = 0 - progress["totalIterations"] = 0 - progress["currentStatus"] = "Finished" - progress["isProcessing"] = False - - socketio.emit("progressUpdate", progress) - eventlet.sleep(0) - - socketio.emit( - "generationResult", - { - "url": os.path.relpath(path), - "mtime": os.path.getmtime(path), - "metadata": metadata, - }, - ) - eventlet.sleep(0) - - try: - generate.prompt2image( - **generation_parameters, - step_callback=image_progress, - image_callback=image_done, - ) - - except KeyboardInterrupt: - raise - except CanceledException: - pass - except Exception as e: - socketio.emit("error", {"message": (str(e))}) - print("\n") - traceback.print_exc() - print("\n") - - -""" -END ADDITIONAL FUNCTIONS -""" - - -if __name__ == "__main__": - print(f">> Starting server at http://{host}:{port}") - socketio.run(app, host=host, port=port) diff --git a/frontend/src/app/invokeai.d.ts b/frontend/src/app/invokeai.d.ts index e107003e20..7293a1e9f0 100644 --- a/frontend/src/app/invokeai.d.ts +++ b/frontend/src/app/invokeai.d.ts @@ -12,6 +12,8 @@ * 'gfpgan'. */ +import { Category as GalleryCategory } from '../features/gallery/gallerySlice'; + /** * TODO: * Once an image has been generated, if it is postprocessed again, @@ -105,14 +107,15 @@ export declare type Metadata = SystemConfig & { image: GeneratedImageMetadata | PostProcessedImageMetadata; }; -// An Image has a UUID, url (path?) and Metadata. +// An Image has a UUID, url, modified timestamp, width, height and maybe metadata export declare type Image = { uuid: string; url: string; mtime: number; - metadata: Metadata; + metadata?: Metadata; width: number; height: number; + category: GalleryCategory; }; // GalleryImages is an array of Image. @@ -167,13 +170,9 @@ export declare type SystemStatusResponse = SystemStatus; export declare type SystemConfigResponse = SystemConfig; -export declare type ImageResultResponse = { - url: string; - mtime: number; - metadata: Metadata; - width: number; - height: number; -}; +export declare type ImageResultResponse = Omit; + +export declare type ImageUploadResponse = Omit; export declare type ErrorResponse = { message: string; @@ -183,11 +182,13 @@ export declare type ErrorResponse = { export declare type GalleryImagesResponse = { images: Array>; areMoreImagesAvailable: boolean; + category: GalleryCategory; }; -export declare type ImageUrlAndUuidResponse = { +export declare type ImageDeletedResponse = { uuid: string; url: string; + category: GalleryCategory; }; export declare type ImageUrlResponse = { diff --git a/frontend/src/app/socketio/actions.ts b/frontend/src/app/socketio/actions.ts index 5e17ee7764..34023c0b50 100644 --- a/frontend/src/app/socketio/actions.ts +++ b/frontend/src/app/socketio/actions.ts @@ -1,4 +1,5 @@ import { createAction } from '@reduxjs/toolkit'; +import { GalleryCategory } from '../../features/gallery/gallerySlice'; import { InvokeTabName } from '../../features/tabs/InvokeTabs'; import * as InvokeAI from '../invokeai'; @@ -15,8 +16,8 @@ export const generateImage = createAction( export const runESRGAN = createAction('socketio/runESRGAN'); export const runFacetool = createAction('socketio/runFacetool'); export const deleteImage = createAction('socketio/deleteImage'); -export const requestImages = createAction('socketio/requestImages'); -export const requestNewImages = createAction( +export const requestImages = createAction('socketio/requestImages'); +export const requestNewImages = createAction( 'socketio/requestNewImages' ); export const cancelProcessing = createAction( diff --git a/frontend/src/app/socketio/emitters.ts b/frontend/src/app/socketio/emitters.ts index 0c8742f29a..be19e3baef 100644 --- a/frontend/src/app/socketio/emitters.ts +++ b/frontend/src/app/socketio/emitters.ts @@ -5,6 +5,11 @@ import { frontendToBackendParameters, FrontendToBackendParametersConfig, } from '../../common/util/parameterTranslation'; +import { + GalleryCategory, + GalleryState, +} from '../../features/gallery/gallerySlice'; +import { OptionsState } from '../../features/options/optionsSlice'; import { addLogEntry, errorOccurred, @@ -108,7 +113,8 @@ const makeSocketIOEmitters = ( }, emitRunESRGAN: (imageToProcess: InvokeAI.Image) => { dispatch(setIsProcessing(true)); - const { upscalingLevel, upscalingStrength } = getState().options; + const options: OptionsState = getState().options; + const { upscalingLevel, upscalingStrength } = options; const esrganParameters = { upscale: [upscalingLevel, upscalingStrength], }; @@ -128,8 +134,8 @@ const makeSocketIOEmitters = ( }, emitRunFacetool: (imageToProcess: InvokeAI.Image) => { dispatch(setIsProcessing(true)); - const { facetoolType, facetoolStrength, codeformerFidelity } = - getState().options; + const options: OptionsState = getState().options; + const { facetoolType, facetoolStrength, codeformerFidelity } = options; const facetoolParameters: Record = { facetool_strength: facetoolStrength, @@ -156,16 +162,18 @@ const makeSocketIOEmitters = ( ); }, emitDeleteImage: (imageToDelete: InvokeAI.Image) => { - const { url, uuid } = imageToDelete; - socketio.emit('deleteImage', url, uuid); + const { url, uuid, category } = imageToDelete; + socketio.emit('deleteImage', url, uuid, category); }, - emitRequestImages: () => { - const { earliest_mtime } = getState().gallery; - socketio.emit('requestImages', earliest_mtime); + emitRequestImages: (category: GalleryCategory) => { + const gallery: GalleryState = getState().gallery; + const { earliest_mtime } = gallery.categories[category]; + socketio.emit('requestImages', category, earliest_mtime); }, - emitRequestNewImages: () => { - const { latest_mtime } = getState().gallery; - socketio.emit('requestLatestImages', latest_mtime); + emitRequestNewImages: (category: GalleryCategory) => { + const gallery: GalleryState = getState().gallery; + const { latest_mtime } = gallery.categories[category]; + socketio.emit('requestLatestImages', category, latest_mtime); }, emitCancelProcessing: () => { socketio.emit('cancel'); diff --git a/frontend/src/app/socketio/listeners.ts b/frontend/src/app/socketio/listeners.ts index 1d1e319e95..7b5bd3e29c 100644 --- a/frontend/src/app/socketio/listeners.ts +++ b/frontend/src/app/socketio/listeners.ts @@ -21,12 +21,14 @@ import { addGalleryImages, addImage, clearIntermediateImage, + GalleryState, removeImage, setIntermediateImage, } from '../../features/gallery/gallerySlice'; import { - setInitialImagePath, + clearInitialImage, + setInitialImage, setMaskPath, } from '../../features/options/optionsSlice'; import { requestImages, requestNewImages } from './actions'; @@ -48,10 +50,18 @@ const makeSocketIOListeners = ( try { dispatch(setIsConnected(true)); dispatch(setCurrentStatus('Connected')); - if (getState().gallery.latest_mtime) { - dispatch(requestNewImages()); + const gallery: GalleryState = getState().gallery; + + if (gallery.categories.user.latest_mtime) { + dispatch(requestNewImages('user')); } else { - dispatch(requestImages()); + dispatch(requestImages('user')); + } + + if (gallery.categories.result.latest_mtime) { + dispatch(requestNewImages('result')); + } else { + dispatch(requestImages('result')); } } catch (e) { console.error(e); @@ -83,8 +93,11 @@ const makeSocketIOListeners = ( try { dispatch( addImage({ - uuid: uuidv4(), - ...data, + category: 'result', + image: { + uuid: uuidv4(), + ...data, + }, }) ); dispatch( @@ -125,8 +138,11 @@ const makeSocketIOListeners = ( try { dispatch( addImage({ - uuid: uuidv4(), - ...data, + category: 'result', + image: { + uuid: uuidv4(), + ...data, + }, }) ); @@ -180,7 +196,7 @@ const makeSocketIOListeners = ( * Callback to run when we receive a 'galleryImages' event. */ onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => { - const { images, areMoreImagesAvailable } = data; + const { images, areMoreImagesAvailable, category } = data; /** * the logic here ideally would be in the reducer but we have a side effect: @@ -189,19 +205,18 @@ const makeSocketIOListeners = ( // Generate a UUID for each image const preparedImages = images.map((image): InvokeAI.Image => { - const { url, metadata, mtime, width, height } = image; return { uuid: uuidv4(), - url, - mtime, - metadata, - width, - height, + ...image, }; }); dispatch( - addGalleryImages({ images: preparedImages, areMoreImagesAvailable }) + addGalleryImages({ + images: preparedImages, + areMoreImagesAvailable, + category, + }) ); dispatch( @@ -220,7 +235,12 @@ const makeSocketIOListeners = ( const { intermediateImage } = getState().gallery; if (intermediateImage) { - dispatch(addImage(intermediateImage)); + dispatch( + addImage({ + category: 'result', + image: intermediateImage, + }) + ); dispatch( addLogEntry({ timestamp: dateFormat(new Date(), 'isoDateTime'), @@ -241,14 +261,17 @@ const makeSocketIOListeners = ( /** * Callback to run when we receive a 'imageDeleted' event. */ - onImageDeleted: (data: InvokeAI.ImageUrlAndUuidResponse) => { - const { url, uuid } = data; - dispatch(removeImage(uuid)); + onImageDeleted: (data: InvokeAI.ImageDeletedResponse) => { + const { url, uuid, category } = data; - const { initialImagePath, maskPath } = getState().options; + // remove image from gallery + dispatch(removeImage(data)); - if (initialImagePath === url) { - dispatch(setInitialImagePath('')); + // remove references to image in options + const { initialImage, maskPath } = getState().options; + + if (initialImage?.url === url || initialImage === url) { + dispatch(clearInitialImage()); } if (maskPath === url) { @@ -262,18 +285,25 @@ const makeSocketIOListeners = ( }) ); }, - /** - * Callback to run when we receive a 'initialImageUploaded' event. - */ - onInitialImageUploaded: (data: InvokeAI.ImageUrlResponse) => { - const { url } = data; - dispatch(setInitialImagePath(url)); - dispatch( - addLogEntry({ - timestamp: dateFormat(new Date(), 'isoDateTime'), - message: `Initial image uploaded: ${url}`, - }) - ); + onInitialImageUploaded: (data: InvokeAI.ImageUploadResponse) => { + const image = { + uuid: uuidv4(), + ...data, + }; + + try { + dispatch(addImage({ image, category: 'user' })); + dispatch(setInitialImage(image)); + + dispatch( + addLogEntry({ + timestamp: dateFormat(new Date(), 'isoDateTime'), + message: `Image uploaded: ${data.url}`, + }) + ); + } catch (e) { + console.error(e); + } }, /** * Callback to run when we receive a 'maskImageUploaded' event. diff --git a/frontend/src/app/socketio/middleware.ts b/frontend/src/app/socketio/middleware.ts index 99de806853..748c848bf7 100644 --- a/frontend/src/app/socketio/middleware.ts +++ b/frontend/src/app/socketio/middleware.ts @@ -100,13 +100,20 @@ export const socketioMiddleware = () => { onProcessingCanceled(); }); - socketio.on('imageDeleted', (data: InvokeAI.ImageUrlAndUuidResponse) => { + socketio.on('imageDeleted', (data: InvokeAI.ImageDeletedResponse) => { onImageDeleted(data); }); - socketio.on('initialImageUploaded', (data: InvokeAI.ImageUrlResponse) => { - onInitialImageUploaded(data); - }); + // socketio.on('initialImageUploaded', (data: InvokeAI.ImageUrlResponse) => { + // onInitialImageUploaded(data); + // }); + + socketio.on( + 'initialImageUploaded', + (data: InvokeAI.ImageUploadResponse) => { + onInitialImageUploaded(data); + } + ); socketio.on('maskImageUploaded', (data: InvokeAI.ImageUrlResponse) => { onMaskImageUploaded(data); @@ -152,12 +159,12 @@ export const socketioMiddleware = () => { } case 'socketio/requestImages': { - emitRequestImages(); + emitRequestImages(action.payload); break; } case 'socketio/requestNewImages': { - emitRequestNewImages(); + emitRequestNewImages(action.payload); break; } diff --git a/frontend/src/common/hooks/useCheckParameters.ts b/frontend/src/common/hooks/useCheckParameters.ts index 1127f40852..5c23200442 100644 --- a/frontend/src/common/hooks/useCheckParameters.ts +++ b/frontend/src/common/hooks/useCheckParameters.ts @@ -23,7 +23,7 @@ export const useCheckParametersSelector = createSelector( shouldGenerateVariations: options.shouldGenerateVariations, seedWeights: options.seedWeights, maskPath: options.maskPath, - initialImagePath: options.initialImagePath, + initialImage: options.initialImage, seed: options.seed, activeTabName: tabMap[options.activeTab], // system @@ -49,7 +49,7 @@ const useCheckParameters = (): boolean => { shouldGenerateVariations, seedWeights, maskPath, - initialImagePath, + initialImage, seed, activeTabName, isProcessing, @@ -63,7 +63,7 @@ const useCheckParameters = (): boolean => { return false; } - if (activeTabName === 'img2img' && !initialImagePath) { + if (activeTabName === 'img2img' && !initialImage) { return false; } @@ -72,7 +72,7 @@ const useCheckParameters = (): boolean => { } // Cannot generate with a mask without img2img - if (maskPath && !initialImagePath) { + if (maskPath && !initialImage) { return false; } @@ -100,8 +100,8 @@ const useCheckParameters = (): boolean => { }, [ prompt, maskPath, - initialImagePath, isProcessing, + initialImage, isConnected, shouldGenerateVariations, seedWeights, diff --git a/frontend/src/common/util/parameterTranslation.ts b/frontend/src/common/util/parameterTranslation.ts index 19945b3166..d3d8e0c28c 100644 --- a/frontend/src/common/util/parameterTranslation.ts +++ b/frontend/src/common/util/parameterTranslation.ts @@ -47,7 +47,7 @@ export const frontendToBackendParameters = ( seamless, hiresFix, img2imgStrength, - initialImagePath, + initialImage, shouldFitToWidthHeight, shouldGenerateVariations, variationAmount, @@ -89,8 +89,9 @@ export const frontendToBackendParameters = ( } // img2img exclusive parameters - if (generationMode === 'img2img') { - generationParameters.init_img = initialImagePath; + if (generationMode === 'img2img' && initialImage) { + generationParameters.init_img = + typeof initialImage === 'string' ? initialImage : initialImage.url; generationParameters.strength = img2imgStrength; generationParameters.fit = shouldFitToWidthHeight; } diff --git a/frontend/src/features/gallery/CurrentImageButtons.tsx b/frontend/src/features/gallery/CurrentImageButtons.tsx index cd0b909285..07ed8d561d 100644 --- a/frontend/src/features/gallery/CurrentImageButtons.tsx +++ b/frontend/src/features/gallery/CurrentImageButtons.tsx @@ -8,7 +8,7 @@ import { RootState } from '../../app/store'; import { setActiveTab, setAllParameters, - setInitialImagePath, + setInitialImage, setSeed, setShouldShowImageDetails, } from '../options/optionsSlice'; @@ -85,7 +85,7 @@ const CurrentImageButtons = ({ image }: CurrentImageButtonsProps) => { useAppSelector(systemSelector); const handleClickUseAsInitialImage = () => { - dispatch(setInitialImagePath(image.url)); + dispatch(setInitialImage(image)); dispatch(setActiveTab(1)); }; @@ -114,7 +114,8 @@ const CurrentImageButtons = ({ image }: CurrentImageButtonsProps) => { ); const handleClickUseAllParameters = () => - dispatch(setAllParameters(image.metadata)); + image.metadata && dispatch(setAllParameters(image.metadata)); + useHotkeys( 'a', () => { @@ -139,9 +140,7 @@ const CurrentImageButtons = ({ image }: CurrentImageButtonsProps) => { [image] ); - // Non-null assertion: this button is disabled if there is no seed. - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const handleClickUseSeed = () => dispatch(setSeed(image.metadata.image.seed)); + const handleClickUseSeed = () => image.metadata && dispatch(setSeed(image.metadata.image.seed)); useHotkeys( 's', () => { diff --git a/frontend/src/features/gallery/CurrentImagePreview.tsx b/frontend/src/features/gallery/CurrentImagePreview.tsx index 93a137de4e..62a14c864e 100644 --- a/frontend/src/features/gallery/CurrentImagePreview.tsx +++ b/frontend/src/features/gallery/CurrentImagePreview.tsx @@ -10,11 +10,15 @@ import _ from 'lodash'; const imagesSelector = createSelector( (state: RootState) => state.gallery, (gallery: GalleryState) => { - const currentImageIndex = gallery.images.findIndex( + const { currentCategory } = gallery; + + const tempImages = gallery.categories[currentCategory].images; + const currentImageIndex = tempImages.findIndex( (i) => i.uuid === gallery?.currentImage?.uuid ); - const imagesLength = gallery.images.length; + const imagesLength = tempImages.length; return { + currentCategory, isOnFirstImage: currentImageIndex === 0, isOnLastImage: !isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1, @@ -35,7 +39,7 @@ export default function CurrentImagePreview(props: CurrentImagePreviewProps) { const { imageToDisplay } = props; const dispatch = useAppDispatch(); - const { isOnFirstImage, isOnLastImage } = useAppSelector(imagesSelector); + const { isOnFirstImage, isOnLastImage, currentCategory } = useAppSelector(imagesSelector); const shouldShowImageDetails = useAppSelector( (state: RootState) => state.options.shouldShowImageDetails @@ -53,11 +57,11 @@ export default function CurrentImagePreview(props: CurrentImagePreviewProps) { }; const handleClickPrevButton = () => { - dispatch(selectPrevImage()); + dispatch(selectPrevImage(currentCategory)); }; const handleClickNextButton = () => { - dispatch(selectNextImage()); + dispatch(selectNextImage(currentCategory)); }; return ( diff --git a/frontend/src/features/gallery/HoverableImage.tsx b/frontend/src/features/gallery/HoverableImage.tsx index bb980d7f35..99f9edead8 100644 --- a/frontend/src/features/gallery/HoverableImage.tsx +++ b/frontend/src/features/gallery/HoverableImage.tsx @@ -19,7 +19,7 @@ import { setActiveTab, setAllImageToImageParameters, setAllTextToImageParameters, - setInitialImagePath, + setInitialImage, setPrompt, setSeed, } from '../options/optionsSlice'; @@ -58,7 +58,7 @@ const HoverableImage = memo((props: HoverableImageProps) => { const handleMouseOut = () => setIsHovered(false); const handleUsePrompt = () => { - dispatch(setPrompt(image.metadata.image.prompt)); + image.metadata && dispatch(setPrompt(image.metadata.image.prompt)); toast({ title: 'Prompt Set', status: 'success', @@ -68,7 +68,7 @@ const HoverableImage = memo((props: HoverableImageProps) => { }; const handleUseSeed = () => { - dispatch(setSeed(image.metadata.image.seed)); + image.metadata && dispatch(setSeed(image.metadata.image.seed)); toast({ title: 'Seed Set', status: 'success', @@ -78,7 +78,7 @@ const HoverableImage = memo((props: HoverableImageProps) => { }; const handleSendToImageToImage = () => { - dispatch(setInitialImagePath(image.url)); + dispatch(setInitialImage(image)); if (activeTabName !== 'img2img') { dispatch(setActiveTab('img2img')); } @@ -104,7 +104,7 @@ const HoverableImage = memo((props: HoverableImageProps) => { }; const handleUseAllParameters = () => { - dispatch(setAllTextToImageParameters(metadata)); + metadata && dispatch(setAllTextToImageParameters(metadata)); toast({ title: 'Parameters Set', status: 'success', diff --git a/frontend/src/features/gallery/ImageGallery.scss b/frontend/src/features/gallery/ImageGallery.scss index 3bb00b769d..7e62890676 100644 --- a/frontend/src/features/gallery/ImageGallery.scss +++ b/frontend/src/features/gallery/ImageGallery.scss @@ -126,6 +126,22 @@ } } +.image-gallery-category-btn-group { + width: 100% !important; + column-gap: 0 !important; + justify-content: stretch !important; + + button { + flex-grow: 1; + &[data-selected='true'] { + background-color: var(--accent-color); + &:hover { + background-color: var(--accent-color-hover); + } + } + } +} + // from https://css-tricks.com/a-grid-of-logos-in-squares/ .image-gallery { display: grid; diff --git a/frontend/src/features/gallery/ImageGallery.tsx b/frontend/src/features/gallery/ImageGallery.tsx index 373bb5852c..6fe5fff354 100644 --- a/frontend/src/features/gallery/ImageGallery.tsx +++ b/frontend/src/features/gallery/ImageGallery.tsx @@ -11,6 +11,7 @@ import IAIIconButton from '../../common/components/IAIIconButton'; import { selectNextImage, selectPrevImage, + setCurrentCategory, setGalleryImageMinimumWidth, setGalleryImageObjectFit, setGalleryScrollPosition, @@ -20,7 +21,7 @@ import { } from './gallerySlice'; import HoverableImage from './HoverableImage'; import { setShouldShowGallery } from '../gallery/gallerySlice'; -import { Spacer, useToast } from '@chakra-ui/react'; +import { ButtonGroup, Spacer, useToast } from '@chakra-ui/react'; import { CSSTransition } from 'react-transition-group'; import { Direction } from 're-resizable/lib/resizer'; import { imageGallerySelector } from './gallerySliceSelectors'; @@ -36,8 +37,8 @@ export default function ImageGallery() { const { images, + currentCategory, currentImageUuid, - areMoreImagesAvailable, shouldPinGallery, shouldShowGallery, galleryScrollPosition, @@ -47,6 +48,7 @@ export default function ImageGallery() { galleryImageObjectFit, shouldHoldGalleryOpen, shouldAutoSwitchToNewImages, + areMoreImagesAvailable, } = useAppSelector(imageGallerySelector); const [gallerySize, setGallerySize] = useState({ @@ -128,7 +130,7 @@ export default function ImageGallery() { }; const handleClickLoadMore = () => { - dispatch(requestImages()); + dispatch(requestImages(currentCategory)); }; const handleChangeGalleryImageMinimumWidth = (v: number) => { @@ -151,13 +153,21 @@ export default function ImageGallery() { [shouldShowGallery] ); - useHotkeys('left', () => { - dispatch(selectPrevImage()); - }); + useHotkeys( + 'left', + () => { + dispatch(selectPrevImage(currentCategory)); + }, + [currentCategory] + ); - useHotkeys('right', () => { - dispatch(selectNextImage()); - }); + useHotkeys( + 'right', + () => { + dispatch(selectNextImage(currentCategory)); + }, + [currentCategory] + ); useHotkeys( 'shift+p', @@ -317,7 +327,7 @@ export default function ImageGallery() { ) : null}
+
+ + + + +
{ - const { images, currentImageUuid, areMoreImagesAvailable } = useAppSelector( - (state: RootState) => state.gallery - ); - const dispatch = useAppDispatch(); - - const { isOpen, onOpen, onClose } = useDisclosure(); - - /** - * I don't like that this needs to rerender whenever the current image is changed. - * What if we have a large number of images? I suppose pagination (planned) will - * mitigate this issue. - * - * TODO: Refactor if performance complaints, or after migrating to new API which supports pagination. - */ - - const handleClickLoadMore = () => { - dispatch(requestImages()); - }; - - useHotkeys( - 'g', - () => { - if (isOpen) { - onClose(); - } else { - onOpen(); - } - }, - [isOpen] - ); - - useHotkeys( - 'left', - () => { - dispatch(selectPrevImage()); - }, - [] - ); - - useHotkeys( - 'right', - () => { - dispatch(selectNextImage()); - }, - [] - ); - - return ( -
- - - -
- Your Invocations - -
- -
- {images.length ? ( -
- {images.map((image) => { - const { uuid } = image; - const isSelected = currentImageUuid === uuid; - return ( - - ); - })} -
- ) : ( -
- -

No Images In Gallery

-
- )} - -
-
-
-
-
- ); -}; - -export default ImageGalleryOld; diff --git a/frontend/src/features/gallery/ImageMetaDataViewer/ImageMetadataViewer.tsx b/frontend/src/features/gallery/ImageMetaDataViewer/ImageMetadataViewer.tsx index 0c2cc12942..080541419e 100644 --- a/frontend/src/features/gallery/ImageMetaDataViewer/ImageMetadataViewer.tsx +++ b/frontend/src/features/gallery/ImageMetaDataViewer/ImageMetadataViewer.tsx @@ -20,7 +20,6 @@ import { setHeight, setHiresFix, setImg2imgStrength, - setInitialImagePath, setMaskPath, setPrompt, setSampler, @@ -32,6 +31,7 @@ import { setUpscalingLevel, setUpscalingStrength, setWidth, + setInitialImage, } from '../../options/optionsSlice'; import promptToString from '../../../common/util/promptToString'; import { seedWeightsToString } from '../../../common/util/seedWeightPairs'; @@ -248,7 +248,7 @@ const ImageMetadataViewer = memo( label="Initial image" value={init_image_path} isLink - onClick={() => dispatch(setInitialImagePath(init_image_path))} + onClick={() => dispatch(setInitialImage(init_image_path))} /> )} {mask_image_path && ( diff --git a/frontend/src/features/gallery/gallerySlice.ts b/frontend/src/features/gallery/gallerySlice.ts index 2bc1c6d9a9..d517d0d86b 100644 --- a/frontend/src/features/gallery/gallerySlice.ts +++ b/frontend/src/features/gallery/gallerySlice.ts @@ -3,16 +3,27 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import _, { clamp } from 'lodash'; import * as InvokeAI from '../../app/invokeai'; +export type GalleryCategory = 'user' | 'result'; + +export type AddImagesPayload = { + images: Array; + areMoreImagesAvailable: boolean; + category: GalleryCategory; +}; + type GalleryImageObjectFitType = 'contain' | 'cover'; +export type Gallery = { + images: InvokeAI.Image[]; + latest_mtime?: number; + earliest_mtime?: number; + areMoreImagesAvailable: boolean; +}; + export interface GalleryState { currentImage?: InvokeAI.Image; currentImageUuid: string; - images: Array; intermediateImage?: InvokeAI.Image; - areMoreImagesAvailable: boolean; - latest_mtime?: number; - earliest_mtime?: number; shouldPinGallery: boolean; shouldShowGallery: boolean; galleryScrollPosition: number; @@ -20,12 +31,15 @@ export interface GalleryState { galleryImageObjectFit: GalleryImageObjectFitType; shouldHoldGalleryOpen: boolean; shouldAutoSwitchToNewImages: boolean; + categories: { + user: Gallery; + result: Gallery; + }; + currentCategory: GalleryCategory; } const initialState: GalleryState = { currentImageUuid: '', - images: [], - areMoreImagesAvailable: true, shouldPinGallery: true, shouldShowGallery: true, galleryScrollPosition: 0, @@ -33,6 +47,21 @@ const initialState: GalleryState = { galleryImageObjectFit: 'cover', shouldHoldGalleryOpen: false, shouldAutoSwitchToNewImages: true, + currentCategory: 'result', + categories: { + user: { + images: [], + latest_mtime: undefined, + earliest_mtime: undefined, + areMoreImagesAvailable: true, + }, + result: { + images: [], + latest_mtime: undefined, + earliest_mtime: undefined, + areMoreImagesAvailable: true, + }, + }, }; export const gallerySlice = createSlice({ @@ -43,10 +72,15 @@ export const gallerySlice = createSlice({ state.currentImage = action.payload; state.currentImageUuid = action.payload.uuid; }, - removeImage: (state, action: PayloadAction) => { - const uuid = action.payload; + removeImage: ( + state, + action: PayloadAction + ) => { + const { uuid, category } = action.payload; - const newImages = state.images.filter((image) => image.uuid !== uuid); + const tempImages = state.categories[category as GalleryCategory].images; + + const newImages = tempImages.filter((image) => image.uuid !== uuid); if (uuid === state.currentImageUuid) { /** @@ -58,7 +92,7 @@ export const gallerySlice = createSlice({ * * Get the currently selected image's index. */ - const imageToDeleteIndex = state.images.findIndex( + const imageToDeleteIndex = tempImages.findIndex( (image) => image.uuid === uuid ); @@ -84,24 +118,35 @@ export const gallerySlice = createSlice({ : ''; } - state.images = newImages; + state.categories[category as GalleryCategory].images = newImages; }, - addImage: (state, action: PayloadAction) => { - const newImage = action.payload; + addImage: ( + state, + action: PayloadAction<{ + image: InvokeAI.Image; + category: GalleryCategory; + }> + ) => { + const { image: newImage, category } = action.payload; const { uuid, url, mtime } = newImage; + const tempCategory = state.categories[category as GalleryCategory]; + // Do not add duplicate images - if (state.images.find((i) => i.url === url && i.mtime === mtime)) { + if (tempCategory.images.find((i) => i.url === url && i.mtime === mtime)) { return; } - state.images.unshift(newImage); + tempCategory.images.unshift(newImage); if (state.shouldAutoSwitchToNewImages) { state.currentImageUuid = uuid; state.currentImage = newImage; + if (category === 'result') { + state.currentCategory = 'result'; + } } state.intermediateImage = undefined; - state.latest_mtime = mtime; + tempCategory.latest_mtime = mtime; }, setIntermediateImage: (state, action: PayloadAction) => { state.intermediateImage = action.payload; @@ -109,49 +154,53 @@ export const gallerySlice = createSlice({ clearIntermediateImage: (state) => { state.intermediateImage = undefined; }, - selectNextImage: (state) => { - const { images, currentImage } = state; + selectNextImage: (state, action: PayloadAction) => { + const category = action.payload; + const { currentImage } = state; + const tempImages = state.categories[category].images; + if (currentImage) { - const currentImageIndex = images.findIndex( + const currentImageIndex = tempImages.findIndex( (i) => i.uuid === currentImage.uuid ); - if (_.inRange(currentImageIndex, 0, images.length)) { - const newCurrentImage = images[currentImageIndex + 1]; + if (_.inRange(currentImageIndex, 0, tempImages.length)) { + const newCurrentImage = tempImages[currentImageIndex + 1]; state.currentImage = newCurrentImage; state.currentImageUuid = newCurrentImage.uuid; } } }, - selectPrevImage: (state) => { - const { images, currentImage } = state; + selectPrevImage: (state, action: PayloadAction) => { + const category = action.payload; + const { currentImage } = state; + const tempImages = state.categories[category].images; + if (currentImage) { - const currentImageIndex = images.findIndex( + const currentImageIndex = tempImages.findIndex( (i) => i.uuid === currentImage.uuid ); - if (_.inRange(currentImageIndex, 1, images.length + 1)) { - const newCurrentImage = images[currentImageIndex - 1]; + if (_.inRange(currentImageIndex, 1, tempImages.length + 1)) { + const newCurrentImage = tempImages[currentImageIndex - 1]; state.currentImage = newCurrentImage; state.currentImageUuid = newCurrentImage.uuid; } } }, - addGalleryImages: ( - state, - action: PayloadAction<{ - images: Array; - areMoreImagesAvailable: boolean; - }> - ) => { - const { images, areMoreImagesAvailable } = action.payload; + addGalleryImages: (state, action: PayloadAction) => { + const { images, areMoreImagesAvailable, category } = action.payload; + const tempImages = state.categories[category].images; + + // const prevImages = category === 'user' ? state.userImages : state.resultImages + if (images.length > 0) { // Filter images that already exist in the gallery const newImages = images.filter( (newImage) => - !state.images.find( + !tempImages.find( (i) => i.url === newImage.url && i.mtime === newImage.mtime ) ); - state.images = state.images + state.categories[category].images = tempImages .concat(newImages) .sort((a, b) => b.mtime - a.mtime); @@ -162,11 +211,14 @@ export const gallerySlice = createSlice({ } // keep track of the timestamps of latest and earliest images received - state.latest_mtime = images[0].mtime; - state.earliest_mtime = images[images.length - 1].mtime; + state.categories[category].latest_mtime = images[0].mtime; + state.categories[category].earliest_mtime = + images[images.length - 1].mtime; } + if (areMoreImagesAvailable !== undefined) { - state.areMoreImagesAvailable = areMoreImagesAvailable; + state.categories[category].areMoreImagesAvailable = + areMoreImagesAvailable; } }, setShouldPinGallery: (state, action: PayloadAction) => { @@ -193,6 +245,9 @@ export const gallerySlice = createSlice({ setShouldAutoSwitchToNewImages: (state, action: PayloadAction) => { state.shouldAutoSwitchToNewImages = action.payload; }, + setCurrentCategory: (state, action: PayloadAction) => { + state.currentCategory = action.payload; + }, }, }); @@ -212,6 +267,7 @@ export const { setGalleryImageObjectFit, setShouldHoldGalleryOpen, setShouldAutoSwitchToNewImages, + setCurrentCategory, } = gallerySlice.actions; export default gallerySlice.reducer; diff --git a/frontend/src/features/gallery/gallerySliceSelectors.ts b/frontend/src/features/gallery/gallerySliceSelectors.ts index 260182f542..4dd031fcb1 100644 --- a/frontend/src/features/gallery/gallerySliceSelectors.ts +++ b/frontend/src/features/gallery/gallerySliceSelectors.ts @@ -8,24 +8,22 @@ export const imageGallerySelector = createSelector( [(state: RootState) => state.gallery, (state: RootState) => state.options], (gallery: GalleryState, options: OptionsState) => { const { - images, + categories, + currentCategory, currentImageUuid, - areMoreImagesAvailable, shouldPinGallery, shouldShowGallery, galleryScrollPosition, galleryImageMinimumWidth, galleryImageObjectFit, shouldHoldGalleryOpen, - shouldAutoSwitchToNewImages + shouldAutoSwitchToNewImages, } = gallery; const { activeTab } = options; return { - images, currentImageUuid, - areMoreImagesAvailable, shouldPinGallery, shouldShowGallery, galleryScrollPosition, @@ -34,7 +32,11 @@ export const imageGallerySelector = createSelector( galleryGridTemplateColumns: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, auto))`, activeTabName: tabMap[activeTab], shouldHoldGalleryOpen, - shouldAutoSwitchToNewImages + shouldAutoSwitchToNewImages, + images: categories[currentCategory].images, + areMoreImagesAvailable: + categories[currentCategory].areMoreImagesAvailable, + currentCategory, }; } ); diff --git a/frontend/src/features/options/optionsSlice.ts b/frontend/src/features/options/optionsSlice.ts index 8372293b5a..22a2c03c1c 100644 --- a/frontend/src/features/options/optionsSlice.ts +++ b/frontend/src/features/options/optionsSlice.ts @@ -27,8 +27,7 @@ export interface OptionsState { codeformerFidelity: number; upscalingLevel: UpscalingLevel; upscalingStrength: number; - shouldUseInitImage: boolean; - initialImagePath: string | null; + initialImage?: InvokeAI.Image | string; // can be an Image or url maskPath: string; seamless: boolean; hiresFix: boolean; @@ -58,9 +57,7 @@ const initialOptionsState: OptionsState = { seed: 0, seamless: false, hiresFix: false, - shouldUseInitImage: false, img2imgStrength: 0.75, - initialImagePath: null, maskPath: '', shouldFitToWidthHeight: true, shouldGenerateVariations: false, @@ -137,14 +134,6 @@ export const optionsSlice = createSlice({ setUpscalingStrength: (state, action: PayloadAction) => { state.upscalingStrength = action.payload; }, - setShouldUseInitImage: (state, action: PayloadAction) => { - state.shouldUseInitImage = action.payload; - }, - setInitialImagePath: (state, action: PayloadAction) => { - const newInitialImagePath = action.payload; - state.shouldUseInitImage = newInitialImagePath ? true : false; - state.initialImagePath = newInitialImagePath; - }, setMaskPath: (state, action: PayloadAction) => { state.maskPath = action.payload; }, @@ -170,9 +159,6 @@ export const optionsSlice = createSlice({ if (key === 'seed') { temp.shouldRandomizeSeed = false; } - if (key === 'initialImagePath' && value === '') { - temp.shouldUseInitImage = false; - } return temp; }, setShouldGenerateVariations: (state, action: PayloadAction) => { @@ -236,13 +222,10 @@ export const optionsSlice = createSlice({ action.payload.image; if (type === 'img2img') { - if (init_image_path) state.initialImagePath = init_image_path; + if (init_image_path) state.initialImage = init_image_path; if (mask_image_path) state.maskPath = mask_image_path; if (strength) state.img2imgStrength = strength; if (typeof fit === 'boolean') state.shouldFitToWidthHeight = fit; - state.shouldUseInitImage = true; - } else { - state.shouldUseInitImage = false; } }, setAllParameters: (state, action: PayloadAction) => { @@ -267,13 +250,10 @@ export const optionsSlice = createSlice({ } = action.payload.image; if (type === 'img2img') { - if (init_image_path) state.initialImagePath = init_image_path; + if (init_image_path) state.initialImage = init_image_path; if (mask_image_path) state.maskPath = mask_image_path; if (strength) state.img2imgStrength = strength; if (typeof fit === 'boolean') state.shouldFitToWidthHeight = fit; - state.shouldUseInitImage = true; - } else { - state.shouldUseInitImage = false; } if (variations && variations.length > 0) { @@ -335,6 +315,15 @@ export const optionsSlice = createSlice({ setShowDualDisplay: (state, action: PayloadAction) => { state.showDualDisplay = action.payload; }, + setInitialImage: ( + state, + action: PayloadAction + ) => { + state.initialImage = action.payload; + }, + clearInitialImage: (state) => { + state.initialImage = undefined; + }, }, }); @@ -357,8 +346,6 @@ export const { setCodeformerFidelity, setUpscalingLevel, setUpscalingStrength, - setShouldUseInitImage, - setInitialImagePath, setMaskPath, resetSeed, resetOptionsState, @@ -377,6 +364,8 @@ export const { setAllTextToImageParameters, setAllImageToImageParameters, setShowDualDisplay, + setInitialImage, + clearInitialImage, } = optionsSlice.actions; export default optionsSlice.reducer; diff --git a/frontend/src/features/system/Console.scss b/frontend/src/features/system/Console.scss index 1aad0b2641..d8c5d1249f 100644 --- a/frontend/src/features/system/Console.scss +++ b/frontend/src/features/system/Console.scss @@ -1,4 +1,5 @@ .console { + width: 100vw; display: flex; flex-direction: column; background: var(--console-bg-color); diff --git a/frontend/src/features/tabs/ImageToImage/ImageToImageDisplay.tsx b/frontend/src/features/tabs/ImageToImage/ImageToImageDisplay.tsx index 7e7a57a5aa..573b674e56 100644 --- a/frontend/src/features/tabs/ImageToImage/ImageToImageDisplay.tsx +++ b/frontend/src/features/tabs/ImageToImage/ImageToImageDisplay.tsx @@ -10,8 +10,8 @@ import ImageMetadataViewer from '../../gallery/ImageMetaDataViewer/ImageMetadata import InitImagePreview from './InitImagePreview'; export default function ImageToImageDisplay() { - const initialImagePath = useAppSelector( - (state: RootState) => state.options.initialImagePath + const initialImage = useAppSelector( + (state: RootState) => state.options.initialImage ); const { currentImage, intermediateImage } = useAppSelector( @@ -33,7 +33,7 @@ export default function ImageToImageDisplay() { : { gridAutoRows: 'auto' } } > - {initialImagePath ? ( + {initialImage ? ( <> {imageToDisplay ? ( <> diff --git a/frontend/src/features/tabs/ImageToImage/InitImagePreview.tsx b/frontend/src/features/tabs/ImageToImage/InitImagePreview.tsx index 9e4ba3489c..45a888da75 100644 --- a/frontend/src/features/tabs/ImageToImage/InitImagePreview.tsx +++ b/frontend/src/features/tabs/ImageToImage/InitImagePreview.tsx @@ -2,12 +2,10 @@ import { IconButton, Image, useToast } from '@chakra-ui/react'; import React, { SyntheticEvent } from 'react'; import { MdClear } from 'react-icons/md'; import { RootState, useAppDispatch, useAppSelector } from '../../../app/store'; -import { setInitialImagePath } from '../../options/optionsSlice'; +import { clearInitialImage } from '../../options/optionsSlice'; export default function InitImagePreview() { - const initialImagePath = useAppSelector( - (state: RootState) => state.options.initialImagePath - ); + const { initialImage } = useAppSelector((state: RootState) => state.options); const dispatch = useAppDispatch(); @@ -15,7 +13,7 @@ export default function InitImagePreview() { const handleClickResetInitialImage = (e: SyntheticEvent) => { e.stopPropagation(); - dispatch(setInitialImagePath(null)); + dispatch(clearInitialImage()); }; const alertMissingInitImage = () => { @@ -25,7 +23,7 @@ export default function InitImagePreview() { status: 'error', isClosable: true, }); - dispatch(setInitialImagePath(null)); + dispatch(clearInitialImage()); }; return ( @@ -33,18 +31,20 @@ export default function InitImagePreview() {

Initial Image

} />
- {initialImagePath && ( + {initialImage && (
diff --git a/frontend/src/features/tabs/ImageToImage/InitialImageOverlay.tsx b/frontend/src/features/tabs/ImageToImage/InitialImageOverlay.tsx index acc49a701f..0c342821cf 100644 --- a/frontend/src/features/tabs/ImageToImage/InitialImageOverlay.tsx +++ b/frontend/src/features/tabs/ImageToImage/InitialImageOverlay.tsx @@ -3,14 +3,14 @@ import React from 'react'; import { RootState, useAppSelector } from '../../../app/store'; export default function InitialImageOverlay() { - const initialImagePath = useAppSelector( - (state: RootState) => state.options.initialImagePath + const initialImage = useAppSelector( + (state: RootState) => state.options.initialImage ); - return initialImagePath ? ( + return initialImage ? ( diff --git a/frontend/src/features/tabs/InvokeWorkarea.tsx b/frontend/src/features/tabs/InvokeWorkarea.tsx index faed672931..080635e794 100644 --- a/frontend/src/features/tabs/InvokeWorkarea.tsx +++ b/frontend/src/features/tabs/InvokeWorkarea.tsx @@ -12,9 +12,8 @@ type InvokeWorkareaProps = { const InvokeWorkarea = (props: InvokeWorkareaProps) => { const { optionsPanel, className, children } = props; - const { shouldShowGallery, shouldHoldGalleryOpen } = useAppSelector( - (state: RootState) => state.gallery - ); + const { shouldShowGallery, shouldHoldGalleryOpen, shouldPinGallery } = + useAppSelector((state: RootState) => state.gallery); return (
{
{children}
- {!(shouldShowGallery || shouldHoldGalleryOpen) && ( + {!(shouldShowGallery || (shouldHoldGalleryOpen && !shouldPinGallery)) && ( )}