mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
38bdb440d0
- dangling debug messages in several files, introduced during testing of the external root directory - these need to be removed before they are interpreted as errors by users
1268 lines
46 KiB
Python
1268 lines
46 KiB
Python
import eventlet
|
|
import glob
|
|
import os
|
|
import shutil
|
|
import mimetypes
|
|
import traceback
|
|
import math
|
|
import io
|
|
import base64
|
|
|
|
from flask import Flask, redirect, send_from_directory
|
|
from flask_socketio import SocketIO
|
|
from PIL import Image
|
|
from uuid import uuid4
|
|
from threading import Event
|
|
|
|
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
|
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
|
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
|
|
|
from backend.modules.parameters import parameters_to_command
|
|
|
|
# Parse the command line once, at import time. `opt` is the Args wrapper and
# `args` the parsed result; both are module-level names read by the code below.
opt = Args()
args = opt.parse_args()

# Root directory for static files and relative paths; '..' is used when no
# root_dir was supplied, and a leading '~' is expanded.
args.root_dir = os.path.expanduser(args.root_dir if args.root_dir else '..')

# A relative output directory is resolved against the root directory.
if not os.path.isabs(args.outdir):
    args.outdir = os.path.join(args.root_dir, args.outdir)
|
|
|
|
class InvokeAIWebServer:
|
|
def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
    """Remember the network settings and the processing backends.

    Args:
        generate: generator object; its ``prompt2image``, ``set_model``,
            ``model_cache`` and ``device`` members are used by other methods.
        gfpgan: face-restoration backend exposing ``process``.
        codeformer: face-restoration backend exposing ``process``.
        esrgan: upscaling backend exposing ``process``.
    """
    # Host and port come from the parsed command-line arguments.
    self.host, self.port = args.host, args.port

    # Processing backends, stored as given.
    self.generate = generate
    self.gfpgan = gfpgan
    self.codeformer = codeformer
    self.esrgan = esrgan

    # Set by the "cancel" socket event and checked by the generation callbacks.
    self.canceled = Event()
|
|
|
|
def run(self):
    """Start the server: call ``setup_app`` first, then ``setup_flask``."""
    # setup_app must run first: it defines the output paths the handlers use.
    self.setup_app()
    # setup_flask builds the app, registers the handlers and starts serving.
    self.setup_flask()
|
|
|
|
def setup_flask(self):
    """Build the Flask app and Socket.IO server, register routes, then serve.

    Creates ``self.app`` and ``self.socketio``, registers the keep-alive,
    ``/outputs/...`` and ``/`` routes, attaches the Socket.IO listeners, and
    finally either launches the flaskwebgui window (``args.gui``) or runs
    the Socket.IO server on ``self.host`` / ``self.port``.
    """
    # Fix missing mimetypes on Windows
    for mime_type, extension in (
        ("application/javascript", ".js"),
        ("text/css", ".css"),
    ):
        mimetypes.add_type(mime_type, extension)

    # Socket.IO settings; both loggers follow the --web_verbose switch.
    verbose = bool(args.web_verbose)
    socketio_args = dict(
        logger=verbose,
        engineio_logger=verbose,
        max_http_buffer_size=10000000,
        ping_interval=(50, 50),
        ping_timeout=60,
    )
    if opt.cors:
        socketio_args["cors_allowed_origins"] = opt.cors

    static_dir = os.path.join(args.root_dir,"frontend/dist")
    self.app = Flask(__name__, static_url_path="", static_folder=static_dir)
    self.socketio = SocketIO(self.app, **socketio_args)

    # Keep Server Alive Route
    @self.app.route("/flaskwebgui-keep-server-alive")
    def keep_alive():
        return {"message": "Server Running"}

    # Outputs Route: files are served from the absolute output directory.
    self.app.config["OUTPUTS_FOLDER"] = os.path.abspath(args.outdir)

    @self.app.route("/outputs/<path:file_path>")
    def outputs(file_path):
        return send_from_directory(self.app.config["OUTPUTS_FOLDER"], file_path)

    # Base Route: redirect to the dev server in develop mode, else serve the
    # built frontend's index page.
    @self.app.route("/")
    def serve():
        if args.web_develop:
            return redirect("http://127.0.0.1:5173")
        return send_from_directory(self.app.static_folder, "index.html")

    self.load_socketio_listeners(self.socketio)

    if not args.gui:
        print(">> Started Invoke AI Web Server!")
        if self.host == "0.0.0.0":
            print(
                f"Point your browser at http://localhost:{self.port} or use the host's DNS name or IP address."
            )
        else:
            print(
                ">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
            )
            print(f">> Point your browser at http://{self.host}:{self.port}")
        self.socketio.run(app=self.app, host=self.host, port=self.port)
        return

    print(">> Launching Invoke AI GUI")
    # In develop mode the server is left running when the window closes.
    close_server_on_exit = not args.web_develop
    try:
        from flaskwebgui import FlaskUI

        gui = FlaskUI(
            app=self.app,
            socketio=self.socketio,
            start_server="flask-socketio",
            host=self.host,
            port=self.port,
            width=1600,
            height=1000,
            idle_interval=10,
            close_server_on_exit=close_server_on_exit,
        )
        gui.run()
    except KeyboardInterrupt:
        import sys

        sys.exit(0)
|
|
|
|
def setup_app(self):
    """Define the URL prefixes and filesystem paths, and create the folders.

    All filesystem paths are derived from ``args.outdir``. The four output
    directories are created if they do not exist yet.
    """
    # URL prefixes the client uses to load each kind of image
    self.result_url = "outputs/"
    self.init_image_url = "outputs/init-images/"
    self.mask_image_url = "outputs/mask-images/"
    self.intermediate_url = "outputs/intermediates/"

    output_root = args.outdir
    # location for "finished" images
    self.result_path = output_root
    # temporary path for intermediates
    self.intermediate_path = os.path.join(output_root, "intermediates/")
    # path for user-uploaded init images and masks
    self.init_image_path = os.path.join(output_root, "init-images/")
    self.mask_image_path = os.path.join(output_root, "mask-images/")
    # txt log
    self.log_path = os.path.join(output_root, "invoke_log.txt")

    # make all output paths
    for directory in (
        self.result_path,
        self.intermediate_path,
        self.init_image_path,
        self.mask_image_path,
    ):
        os.makedirs(directory, exist_ok=True)
|
|
|
|
def load_socketio_listeners(self, socketio):
    """Register every Socket.IO event handler on ``socketio``.

    Handlers registered: ``requestSystemConfig``, ``requestModelChange``,
    ``requestLatestImages``, ``requestImages``, ``generateImage``,
    ``runPostprocessing``, ``cancel``, ``deleteImage``, ``uploadImage`` and
    ``uploadMaskImage``. Apart from ``requestSystemConfig`` and ``cancel``,
    each handler catches any exception, emits an ``error`` event carrying
    the message, and prints the traceback.

    Args:
        socketio: the Socket.IO server object the handlers are attached to.
    """

    def report_error(error):
        """Emit the error text to the client and print the active traceback.

        Must be called from inside an ``except`` block so that
        ``traceback.print_exc`` still sees the exception being handled.
        """
        self.socketio.emit("error", {"message": (str(error))})
        print("\n")
        traceback.print_exc()
        print("\n")

    def sorted_gallery_files(category):
        """Return ``(path, mtime)`` pairs for a gallery category, newest first.

        ``"result"`` lists the result directory; any other category lists
        the init-image directory. Each mtime is read once so sorting,
        filtering and the emitted value all agree.
        """
        base_path = (
            self.result_path if category == "result" else self.init_image_path
        )
        paths = []
        for pattern in ("*.png", "*.jpg", "*.jpeg"):
            paths.extend(glob.glob(os.path.join(base_path, pattern)))
        files = [(path, os.path.getmtime(path)) for path in paths]
        files.sort(key=lambda entry: entry[1], reverse=True)
        return files

    def describe_gallery_files(files, category):
        """Build the client-facing description for each ``(path, mtime)`` pair."""
        image_array = []
        for path, mtime in files:
            # Only PNG files are asked for embedded metadata.
            if os.path.splitext(path)[1] == ".png":
                sd_metadata = retrieve_metadata(path)["sd-metadata"]
            else:
                sd_metadata = {}

            # Close the file handle as soon as the size has been read.
            with Image.open(path) as opened:
                width, height = opened.size

            image_array.append(
                {
                    "url": self.get_url_from_image_path(path),
                    "mtime": mtime,
                    "metadata": sd_metadata,
                    "width": width,
                    "height": height,
                    "category": category,
                }
            )
        return image_array

    @socketio.on("requestSystemConfig")
    def handle_request_capabilities():
        print(">> System config requested")
        config = self.get_system_config()
        socketio.emit("systemConfig", config)

    @socketio.on("requestModelChange")
    def handle_set_model(model_name: str):
        try:
            print(f">> Model change requested: {model_name}")
            model = self.generate.set_model(model_name)
            model_list = self.generate.model_cache.list_models()
            # set_model returning None is reported to the client as a failure.
            event = "modelChangeFailed" if model is None else "modelChanged"
            socketio.emit(
                event,
                {"model_name": model_name, "model_list": model_list},
            )
        except Exception as e:
            report_error(e)

    @socketio.on("requestLatestImages")
    def handle_request_latest_images(category, latest_mtime):
        try:
            # Only files modified after the client's newest known image.
            files = [
                entry
                for entry in sorted_gallery_files(category)
                if entry[1] > latest_mtime
            ]
            socketio.emit(
                "galleryImages",
                {
                    "images": describe_gallery_files(files, category),
                    "category": category,
                },
            )
        except Exception as e:
            report_error(e)

    @socketio.on("requestImages")
    def handle_request_images(category, earliest_mtime=None):
        try:
            page_size = 50

            files = sorted_gallery_files(category)
            if earliest_mtime:
                # Only files older than the client's oldest known image.
                files = [entry for entry in files if entry[1] < earliest_mtime]

            # More images exist only when the candidates exceed one page;
            # exactly one page of candidates leaves nothing further to load.
            areMoreImagesAvailable = len(files) > page_size

            socketio.emit(
                "galleryImages",
                {
                    "images": describe_gallery_files(files[:page_size], category),
                    "areMoreImagesAvailable": areMoreImagesAvailable,
                    "category": category,
                },
            )
        except Exception as e:
            report_error(e)

    @socketio.on("generateImage")
    def handle_generate_image_event(
        generation_parameters, esrgan_parameters, facetool_parameters
    ):
        try:
            # truncate long init_mask base64 if needed (for logging only)
            printable_parameters = generation_parameters
            if "init_mask" in generation_parameters:
                printable_parameters = {
                    **generation_parameters,
                    "init_mask": generation_parameters["init_mask"][:20] + "...",
                }
            print(
                f">> Image generation requested: {printable_parameters}\nESRGAN parameters: {esrgan_parameters}\nFacetool parameters: {facetool_parameters}"
            )
            self.generate_images(
                generation_parameters,
                esrgan_parameters,
                facetool_parameters,
            )
        except Exception as e:
            report_error(e)

    @socketio.on("runPostprocessing")
    def handle_run_postprocessing(original_image, postprocessing_parameters):
        try:
            print(
                f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
            )

            progress = Progress()

            socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            original_image_path = self.get_image_path_from_url(
                original_image["url"]
            )

            image = Image.open(original_image_path)

            seed = (
                original_image["metadata"]["seed"]
                if "metadata" in original_image
                and "seed" in original_image["metadata"]
                else "unknown_seed"
            )

            postprocessing_type = postprocessing_parameters["type"]

            # Announce the step first; an unknown type is rejected below,
            # after this progress update has been sent.
            if postprocessing_type == "esrgan":
                progress.set_current_status("Upscaling (ESRGAN)")
            elif postprocessing_type == "gfpgan":
                progress.set_current_status("Restoring Faces (GFPGAN)")
            elif postprocessing_type == "codeformer":
                progress.set_current_status("Restoring Faces (Codeformer)")

            socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            if postprocessing_type == "esrgan":
                image = self.esrgan.process(
                    image=image,
                    upsampler_scale=postprocessing_parameters["upscale"][0],
                    strength=postprocessing_parameters["upscale"][1],
                    seed=seed,
                )
            elif postprocessing_type == "gfpgan":
                image = self.gfpgan.process(
                    image=image,
                    strength=postprocessing_parameters["facetool_strength"],
                    seed=seed,
                )
            elif postprocessing_type == "codeformer":
                image = self.codeformer.process(
                    image=image,
                    strength=postprocessing_parameters["facetool_strength"],
                    fidelity=postprocessing_parameters["codeformer_fidelity"],
                    seed=seed,
                    # Codeformer is run on the CPU when the generator uses "mps".
                    device="cpu"
                    if str(self.generate.device) == "mps"
                    else self.generate.device,
                )
            else:
                raise TypeError(
                    f'{postprocessing_parameters["type"]} is not a valid postprocessing type'
                )

            progress.set_current_status("Saving Image")
            socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            postprocessing_parameters["seed"] = seed
            metadata = self.parameters_to_post_processed_image_metadata(
                parameters=postprocessing_parameters,
                original_image_path=original_image_path,
            )

            command = parameters_to_command(postprocessing_parameters)

            (width, height) = image.size

            path = self.save_result_image(
                image,
                command,
                metadata,
                self.result_path,
                postprocessing=postprocessing_parameters["type"],
            )

            self.write_log_message(
                f'[Postprocessed] "{original_image_path}" > "{path}": {postprocessing_parameters}'
            )

            progress.mark_complete()
            socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            socketio.emit(
                "postprocessingResult",
                {
                    "url": self.get_url_from_image_path(path),
                    "mtime": os.path.getmtime(path),
                    "metadata": metadata,
                    "width": width,
                    "height": height,
                },
            )
        except Exception as e:
            report_error(e)

    @socketio.on("cancel")
    def handle_cancel():
        print(">> Cancel processing requested")
        self.canceled.set()

    # TODO: I think this needs a safety mechanism.
    @socketio.on("deleteImage")
    def handle_delete_image(url, uuid, category):
        try:
            print(f'>> Delete requested "{url}"')
            from send2trash import send2trash

            path = self.get_image_path_from_url(url)
            send2trash(path)
            socketio.emit(
                "imageDeleted",
                {"url": url, "uuid": uuid, "category": category},
            )
        except Exception as e:
            report_error(e)

    # TODO: I think this needs a safety mechanism.
    @socketio.on("uploadImage")
    def handle_upload_image(file_bytes, name, destination):
        try:
            print(f'>> Image upload requested "{name}"')
            file_path = self.save_file_unique_uuid_name(
                bytes=file_bytes, name=name, path=self.init_image_path
            )
            mtime = os.path.getmtime(file_path)
            # Close the file handle as soon as the size has been read.
            with Image.open(file_path) as uploaded:
                width, height = uploaded.size
            socketio.emit(
                "imageUploaded",
                {
                    "url": self.get_url_from_image_path(file_path),
                    "mtime": mtime,
                    "width": width,
                    "height": height,
                    "category": "user",
                    "destination": destination,
                },
            )
        except Exception as e:
            report_error(e)

    # TODO: I think this needs a safety mechanism.
    @socketio.on("uploadMaskImage")
    def handle_upload_mask_image(file_bytes, name):
        try:
            print(f'>> Mask image upload requested "{name}"')

            file_path = self.save_file_unique_uuid_name(
                bytes=file_bytes, name=name, path=self.mask_image_path
            )

            socketio.emit(
                "maskImageUploaded",
                {
                    "url": self.get_url_from_image_path(file_path),
                },
            )
        except Exception as e:
            report_error(e)
|
|
|
|
# App Functions
|
|
def get_system_config(self):
    """Return the system description sent to the client.

    The dict carries the fixed model label, the model id from the command
    line, the generator's model hash, the application id and version, and
    the model list reported by the generator's model cache.
    """
    available_models = self.generate.model_cache.list_models()
    config = dict(
        model="stable diffusion",
        model_id=args.model,
        model_hash=self.generate.model_hash,
        app_id=APP_ID,
        app_version=APP_VERSION,
        model_list=available_models,
    )
    return config
|
|
|
|
def generate_images(
    self, generation_parameters, esrgan_parameters, facetool_parameters
):
    """Run one generation request, streaming progress and results to the client.

    Args:
        generation_parameters: dict forwarded as keyword arguments to
            ``self.generate.prompt2image``. It is modified in place: the
            init image URL is replaced by a path (or a cropped image when
            inpainting), the mask is decoded, and result fields such as
            ``seed`` are written back while each image is finished.
        esrgan_parameters: falsy to skip upscaling, otherwise a dict with
            ``level`` and ``strength``.
        facetool_parameters: falsy to skip face restoration, otherwise a
            dict with ``type`` and ``strength`` (plus ``codeformer_fidelity``
            for the ``codeformer`` type).

    Emits ``progressUpdate``, ``intermediateResult`` and ``generationResult``
    events. A cancel request emits ``processingCanceled``; any other
    exception is emitted as an ``error`` event and its traceback printed.
    ``KeyboardInterrupt`` is re-raised after notifying the client.
    """
    try:
        self.canceled.clear()

        step_index = 1
        prior_variations = (
            generation_parameters["with_variations"]
            if "with_variations" in generation_parameters
            else []
        )

        # TODO:
        # If a result image is used as an init image, and then deleted, we will
        # want to be able to use it as an init image in the future. Need to
        # handle this case.

        # We need to give absolute paths to the generator, stash the URLs for later
        init_img_url = None
        # Stays None when no init image was supplied, so the inpainting branch
        # below can detect that instead of hitting an unbound local.
        init_img_path = None

        if "init_img" in generation_parameters:
            init_img_url = generation_parameters["init_img"]
            init_img_path = self.get_image_path_from_url(init_img_url)
            generation_parameters["init_img"] = init_img_path

        # NOTE: the mask arrives as base64 data (decoded below), not as a URL,
        # so no mask URL needs to be stashed.

        if "init_mask" in generation_parameters:
            if init_img_path is None:
                # A mask can only be applied to an init image.
                raise ValueError(
                    "An init mask was supplied without an init image"
                )

            # grab an Image of the init image
            original_image = Image.open(init_img_path)

            # copy a region from it which we will inpaint
            cropped_init_image = copy_image_from_bounding_box(
                original_image, **generation_parameters["bounding_box"]
            )
            generation_parameters["init_img"] = cropped_init_image

            if generation_parameters["is_mask_empty"]:
                generation_parameters["init_mask"] = None
            else:
                # grab an Image of the mask
                mask_image = Image.open(
                    io.BytesIO(
                        base64.decodebytes(
                            bytes(generation_parameters["init_mask"], "utf-8")
                        )
                    )
                )
                generation_parameters["init_mask"] = mask_image

        # Progress works out the effective step count from the parameters.
        progress = Progress(generation_parameters=generation_parameters)

        self.socketio.emit("progressUpdate", progress.to_formatted_dict())
        eventlet.sleep(0)

        def image_progress(sample, step):
            """Per-step callback: report progress and optional previews."""
            if self.canceled.is_set():
                raise CanceledException

            nonlocal step_index

            progress.set_current_step(step + 1)
            progress.set_current_status("Generating")
            progress.set_current_status_has_steps(True)

            # Save a full intermediate image every `save_intermediates`
            # steps, but never for the last step.
            if (
                generation_parameters["progress_images"]
                and step % generation_parameters['save_intermediates'] == 0
                and step < generation_parameters["steps"] - 1
            ):
                image = self.generate.sample_to_image(sample)
                metadata = self.parameters_to_generated_image_metadata(
                    generation_parameters
                )
                command = parameters_to_command(generation_parameters)

                (width, height) = image.size

                path = self.save_result_image(
                    image,
                    command,
                    metadata,
                    self.intermediate_path,
                    step_index=step_index,
                    postprocessing=False,
                )

                step_index += 1
                self.socketio.emit(
                    "intermediateResult",
                    {
                        "url": self.get_url_from_image_path(path),
                        "mtime": os.path.getmtime(path),
                        "metadata": metadata,
                        "width": width,
                        "height": height,
                    },
                )

            # Cheap latent preview sent inline as a base64 PNG data URL.
            if generation_parameters["progress_latents"]:
                image = self.generate.sample_to_lowres_estimated_image(sample)
                (width, height) = image.size
                # The reported size is eight times the preview's own size.
                width *= 8
                height *= 8
                buffered = io.BytesIO()
                image.save(buffered, format="PNG")
                img_base64 = "data:image/png;base64," + base64.b64encode(
                    buffered.getvalue()
                ).decode("UTF-8")
                self.socketio.emit(
                    "intermediateResult",
                    {
                        "url": img_base64,
                        "isBase64": True,
                        "mtime": 0,
                        "metadata": {},
                        "width": width,
                        "height": height,
                    },
                )

            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

        def image_done(image, seed, first_seed):
            """Per-image callback: postprocess, save and announce the result."""
            if self.canceled.is_set():
                raise CanceledException

            # NOTE(review): the original assigned a function-local
            # `step_index = 1` here, which had no effect; intermediate step
            # numbering therefore continues across iterations, as before.

            # paste the inpainting image back onto the original
            if "init_mask" in generation_parameters:
                image = paste_image_into_bounding_box(
                    Image.open(init_img_path),
                    image,
                    **generation_parameters["bounding_box"],
                )

            progress.set_current_status("Generation Complete")

            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            # Alias, not a copy: the writes below land in generation_parameters.
            all_parameters = generation_parameters
            postprocessing = False

            if (
                "variation_amount" in all_parameters
                and all_parameters["variation_amount"] > 0
            ):
                first_seed = first_seed or seed
                this_variation = [[seed, all_parameters["variation_amount"]]]
                all_parameters["with_variations"] = (
                    prior_variations + this_variation
                )
                all_parameters["seed"] = first_seed
            elif "with_variations" in all_parameters:
                all_parameters["seed"] = first_seed
            else:
                all_parameters["seed"] = seed

            if self.canceled.is_set():
                raise CanceledException

            if esrgan_parameters:
                progress.set_current_status("Upscaling")
                progress.set_current_status_has_steps(False)
                self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                image = self.esrgan.process(
                    image=image,
                    upsampler_scale=esrgan_parameters["level"],
                    strength=esrgan_parameters["strength"],
                    seed=seed,
                )

                postprocessing = True
                all_parameters["upscale"] = [
                    esrgan_parameters["level"],
                    esrgan_parameters["strength"],
                ]

            if self.canceled.is_set():
                raise CanceledException

            if facetool_parameters:
                if facetool_parameters["type"] == "gfpgan":
                    progress.set_current_status("Restoring Faces (GFPGAN)")
                elif facetool_parameters["type"] == "codeformer":
                    progress.set_current_status("Restoring Faces (Codeformer)")

                progress.set_current_status_has_steps(False)
                self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                if facetool_parameters["type"] == "gfpgan":
                    image = self.gfpgan.process(
                        image=image,
                        strength=facetool_parameters["strength"],
                        seed=seed,
                    )
                elif facetool_parameters["type"] == "codeformer":
                    image = self.codeformer.process(
                        image=image,
                        strength=facetool_parameters["strength"],
                        fidelity=facetool_parameters["codeformer_fidelity"],
                        seed=seed,
                        # Codeformer is run on the CPU when the generator uses "mps".
                        device="cpu"
                        if str(self.generate.device) == "mps"
                        else self.generate.device,
                    )
                    all_parameters["codeformer_fidelity"] = facetool_parameters[
                        "codeformer_fidelity"
                    ]

                postprocessing = True
                all_parameters["facetool_strength"] = facetool_parameters[
                    "strength"
                ]
                all_parameters["facetool_type"] = facetool_parameters["type"]

            progress.set_current_status("Saving Image")
            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            # restore the stashed URLS and discard the paths, we are about to send the result to client
            if "init_img" in all_parameters:
                all_parameters["init_img"] = init_img_url

            if "init_mask" in all_parameters:
                all_parameters["init_mask"] = ""  # TODO: store the mask in metadata

            metadata = self.parameters_to_generated_image_metadata(all_parameters)

            command = parameters_to_command(all_parameters)

            (width, height) = image.size

            path = self.save_result_image(
                image,
                command,
                metadata,
                self.result_path,
                postprocessing=postprocessing,
            )

            print(f'>> Image generated: "{path}"')
            self.write_log_message(f'[Generated] "{path}": {command}')

            if progress.total_iterations > progress.current_iteration:
                progress.set_current_step(1)
                progress.set_current_status("Iteration complete")
                progress.set_current_status_has_steps(False)
            else:
                progress.mark_complete()

            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            self.socketio.emit(
                "generationResult",
                {
                    "url": self.get_url_from_image_path(path),
                    "mtime": os.path.getmtime(path),
                    "metadata": metadata,
                    "width": width,
                    "height": height,
                },
            )
            eventlet.sleep(0)

            progress.set_current_iteration(progress.current_iteration + 1)

        self.generate.prompt2image(
            **generation_parameters,
            step_callback=image_progress,
            image_callback=image_done,
        )

    except KeyboardInterrupt:
        self.socketio.emit("processingCanceled")
        raise
    except CanceledException:
        # Raised by the callbacks after a cancel request; not an error.
        self.socketio.emit("processingCanceled")
    except Exception as e:
        print(e)
        self.socketio.emit("error", {"message": (str(e))})
        print("\n")

        traceback.print_exc()
        print("\n")
|
|
|
|
def parameters_to_generated_image_metadata(self, parameters):
    """Build the metadata dict stored with a generated image.

    Starts from ``get_system_config()`` and adds an ``image`` entry holding
    the whitelisted generation parameters plus derived fields
    (postprocessing list, sampler, weighted subprompts, variations and the
    img2img/txt2img specifics).

    Returns the metadata dict. If any exception is raised it is emitted as
    an ``error`` event, its traceback printed, and ``None`` is returned.
    """
    try:
        # top-level metadata minus `image` or `images`
        metadata = self.get_system_config()

        # remove any image keys not mentioned in RFC #266
        rfc266_img_fields = (
            "type",
            "postprocessing",
            "sampler",
            "prompt",
            "seed",
            "variations",
            "steps",
            "cfg_scale",
            "threshold",
            "perlin",
            "step_number",
            "width",
            "height",
            "extra",
            "seamless",
            "hires_fix",
        )
        rfc_dict = {
            key: value
            for key, value in parameters.items()
            if key in rfc266_img_fields
        }

        # 'postprocessing' is either null or a list of postprocessing steps
        postprocessing = []

        if "facetool_strength" in parameters:
            facetool_entry = {
                "type": str(parameters["facetool_type"]),
                "strength": float(parameters["facetool_strength"]),
            }
            if parameters["facetool_type"] == "codeformer":
                facetool_entry["fidelity"] = float(
                    parameters["codeformer_fidelity"]
                )
            postprocessing.append(facetool_entry)

        if "upscale" in parameters:
            scale, strength = parameters["upscale"][0], parameters["upscale"][1]
            postprocessing.append(
                {
                    "type": "esrgan",
                    "scale": int(scale),
                    "strength": float(strength),
                }
            )

        rfc_dict["postprocessing"] = postprocessing or None

        # semantic drift: the parameter is named `sampler_name`
        rfc_dict["sampler"] = parameters["sampler_name"]

        # display weighted subprompts (liable to change)
        rfc_dict["prompt"] = [
            {"prompt": part[0], "weight": part[1]}
            for part in split_weighted_subprompts(
                parameters["prompt"], skip_normalize=True
            )
        ]

        # 'variations' should always exist and be an array, empty or
        # consisting of {'seed': seed, 'weight': weight} pairs
        if "with_variations" in parameters:
            rfc_dict["variations"] = [
                {"seed": pair[0], "weight": pair[1]}
                for pair in parameters["with_variations"]
            ]
        else:
            rfc_dict["variations"] = []

        if "init_img" not in parameters:
            rfc_dict["type"] = "txt2img"
        else:
            rfc_dict["type"] = "img2img"
            rfc_dict["strength"] = parameters["strength"]
            rfc_dict["fit"] = parameters["fit"]  # TODO: Noncompliant
            rfc_dict["orig_hash"] = calculate_init_img_hash(
                self.get_image_path_from_url(parameters["init_img"])
            )
            rfc_dict["init_image_path"] = parameters[
                "init_img"
            ]  # TODO: Noncompliant
            # TODO: mask hash / mask image path are not recorded yet.

        metadata["image"] = rfc_dict

        return metadata

    except Exception as e:
        self.socketio.emit("error", {"message": (str(e))})
        print("\n")

        traceback.print_exc()
        print("\n")
|
|
|
|
def parameters_to_post_processed_image_metadata(
    self, parameters, original_image_path
):
    """Return the original image's metadata extended with one postprocessing entry.

    Args:
        parameters: postprocessing request; ``type`` must be ``esrgan``,
            ``gfpgan`` or ``codeformer`` and the matching strength, scale
            and fidelity keys must be present.
        original_image_path: path of the image that was postprocessed.

    Returns the updated metadata dict. If any exception is raised
    (including ``TypeError`` for an unknown type) it is emitted as an
    ``error`` event, its traceback printed, and ``None`` is returned.
    """
    try:
        current_metadata = retrieve_metadata(original_image_path)["sd-metadata"]
        postprocessing_metadata = {}

        # If there is no original image metadata to reconstruct, record the
        # original image and its hash instead.
        if "image" not in current_metadata:
            current_metadata["image"] = {}

            orig_hash = calculate_init_img_hash(
                self.get_image_path_from_url(original_image_path)
            )

            # Store the path itself; a stray trailing comma used to wrap it
            # in a one-element tuple.
            postprocessing_metadata["orig_path"] = original_image_path
            postprocessing_metadata["orig_hash"] = orig_hash

        if parameters["type"] == "esrgan":
            postprocessing_metadata["type"] = "esrgan"
            postprocessing_metadata["scale"] = parameters["upscale"][0]
            postprocessing_metadata["strength"] = parameters["upscale"][1]
        elif parameters["type"] == "gfpgan":
            postprocessing_metadata["type"] = "gfpgan"
            postprocessing_metadata["strength"] = parameters["facetool_strength"]
        elif parameters["type"] == "codeformer":
            postprocessing_metadata["type"] = "codeformer"
            postprocessing_metadata["strength"] = parameters["facetool_strength"]
            postprocessing_metadata["fidelity"] = parameters["codeformer_fidelity"]
        else:
            raise TypeError(f"Invalid type: {parameters['type']}")

        # Append to an existing postprocessing list, or start a new one when
        # the key is missing or is not a list (for example None).
        image_metadata = current_metadata["image"]
        if isinstance(image_metadata.get("postprocessing"), list):
            image_metadata["postprocessing"].append(postprocessing_metadata)
        else:
            image_metadata["postprocessing"] = [postprocessing_metadata]

        return current_metadata

    except Exception as e:
        self.socketio.emit("error", {"message": (str(e))})
        print("\n")

        traceback.print_exc()
        print("\n")
|
|
|
|
def save_result_image(
    self,
    image,
    command,
    metadata,
    output_dir,
    step_index=None,
    postprocessing=False,
):
    """Write ``image`` as a PNG into ``output_dir`` and return its absolute path.

    The file name is ``<prefix>.<8 hex chars>.<seed>`` followed by an
    optional ``.<step_index>`` and ``.postprocessed`` marker and the
    ``.png`` extension. The seed is taken from ``metadata["image"]["seed"]``
    when present, otherwise ``unknown_seed`` is used.

    If any exception is raised it is emitted as an ``error`` event, its
    traceback printed, and ``None`` is returned.
    """
    try:
        pngwriter = PngWriter(output_dir)
        number_prefix = pngwriter.unique_prefix()
        short_id = uuid4().hex[:8]

        has_seed = "image" in metadata and "seed" in metadata["image"]
        seed = metadata["image"]["seed"] if has_seed else "unknown_seed"

        name_parts = [f"{number_prefix}", f"{short_id}", f"{seed}"]
        if step_index:
            name_parts.append(f"{step_index}")
        if postprocessing:
            name_parts.append("postprocessed")
        filename = ".".join(name_parts) + ".png"

        saved_path = pngwriter.save_image_and_prompt_to_png(
            image=image,
            dream_prompt=command,
            metadata=metadata,
            name=filename,
        )
        return os.path.abspath(saved_path)

    except Exception as e:
        self.socketio.emit("error", {"message": (str(e))})
        print("\n")

        traceback.print_exc()
        print("\n")
|
|
|
|
def make_unique_init_image_filename(self, name):
    """Return ``name`` with a full hex UUID inserted before its extension.

    If any exception is raised it is emitted as an ``error`` event, its
    traceback printed, and ``None`` is returned.
    """
    try:
        unique_id = uuid4().hex
        stem, extension = os.path.splitext(name)
        return f"{stem}.{unique_id}{extension}"
    except Exception as e:
        self.socketio.emit("error", {"message": (str(e))})
        print("\n")

        traceback.print_exc()
        print("\n")
|
|
|
|
def calculate_real_steps(self, steps, strength, has_init_image):
    """Return the number of steps that will actually run.

    Without an init image this is ``steps`` unchanged; with one it is
    ``floor(strength * steps)``.
    """
    if not has_init_image:
        return steps
    return math.floor(strength * steps)
|
|
|
|
def write_log_message(self, message):
    """Append ``message`` plus a newline to the log file at ``self.log_path``.

    If any exception is raised it is emitted as an ``error`` event and its
    traceback printed.
    """
    try:
        with open(self.log_path, "a", encoding="utf-8") as log_file:
            log_file.write(f"{message}\n")

    except Exception as e:
        self.socketio.emit("error", {"message": (str(e))})
        print("\n")

        traceback.print_exc()
        print("\n")
|
|
|
|
def get_image_path_from_url(self, url):
    """Given a url to an image used by the client, returns the absolute file path to that image

    The directory is chosen by the first marker found in ``url``
    (``init-images``, ``mask-images``, ``intermediates``); otherwise the
    result directory is used. Only the final path component of ``url`` is
    kept. If any exception is raised it is emitted as an ``error`` event,
    its traceback printed, and ``None`` is returned.
    """
    try:
        # (marker in the URL, attribute naming the matching directory),
        # checked in this order; the attribute is looked up only once chosen.
        routes = (
            ("init-images", "init_image_path"),
            ("mask-images", "mask_image_path"),
            ("intermediates", "intermediate_path"),
        )
        directory_attr = "result_path"
        for marker, attr in routes:
            if marker in url:
                directory_attr = attr
                break

        directory = getattr(self, directory_attr)
        return os.path.abspath(os.path.join(directory, os.path.basename(url)))
    except Exception as e:
        self.socketio.emit("error", {"message": (str(e))})
        print("\n")

        traceback.print_exc()
        print("\n")
|
|
|
|
def get_url_from_image_path(self, path):
    """Given an absolute file path to an image, returns the URL that the client can use to load the image

    The URL prefix is chosen by the first marker found in ``path``
    (``init-images``, ``mask-images``, ``intermediates``); otherwise the
    result prefix is used. If any exception is raised it is emitted as an
    ``error`` event, its traceback printed, and ``None`` is returned.
    """
    try:
        # (marker in the path, attribute naming the matching URL prefix),
        # checked in this order; the attribute is looked up only once chosen.
        routes = (
            ("init-images", "init_image_url"),
            ("mask-images", "mask_image_url"),
            ("intermediates", "intermediate_url"),
        )
        prefix_attr = "result_url"
        for marker, attr in routes:
            if marker in path:
                prefix_attr = attr
                break

        return os.path.join(getattr(self, prefix_attr), os.path.basename(path))
    except Exception as e:
        self.socketio.emit("error", {"message": (str(e))})
        print("\n")

        traceback.print_exc()
        print("\n")
|
|
|
|
def save_file_unique_uuid_name(self, bytes, name, path):
    """Write ``bytes`` into directory ``path`` under a unique variant of ``name``.

    Args:
        bytes: the file content to write (parameter name kept for callers
            that pass it by keyword, although it shadows the builtin).
        name: client-supplied file name; only its final path component is
            used, so it cannot point outside ``path``.
        path: target directory; created if missing.

    Returns the path of the written file, named
    ``<stem>.<8 hex chars><extension>``. If any exception is raised it is
    emitted as an ``error`` event, its traceback printed, and ``None`` is
    returned.
    """
    try:
        short_id = uuid4().hex[:8]

        # The name comes from the client: drop any directory part so that
        # something like "../../x.png" cannot escape the upload directory.
        stem, extension = os.path.splitext(os.path.basename(name))
        file_path = os.path.join(path, f"{stem}.{short_id}{extension}")

        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        # Close (and so flush) the file before returning: callers read its
        # mtime and open it as an image straight away.
        with open(file_path, "wb") as new_file:
            new_file.write(bytes)

        return file_path
    except Exception as e:
        self.socketio.emit("error", {"message": (str(e))})
        print("\n")

        traceback.print_exc()
        print("\n")
|
|
|
|
|
|
class Progress:
    """Mutable progress state that is sent to the client as a dict.

    Without generation parameters the totals default to a single step and
    a single iteration; with them, the step total is derived from
    ``steps``, the optional ``strength`` and the presence of ``init_img``,
    and the iteration total is ``iterations``.
    """

    def __init__(self, generation_parameters=None):
        if generation_parameters:
            total_steps = self._calculate_real_steps(
                steps=generation_parameters["steps"],
                strength=generation_parameters.get("strength"),
                has_init_image="init_img" in generation_parameters,
            )
            total_iterations = generation_parameters["iterations"]
        else:
            total_steps = 1
            total_iterations = 1

        self.current_step = 1
        self.total_steps = total_steps
        self.current_iteration = 1
        self.total_iterations = total_iterations
        self.current_status = "Preparing"
        self.is_processing = True
        self.current_status_has_steps = False
        self.has_error = False

    def set_current_step(self, current_step):
        """Set the step currently being processed."""
        self.current_step = current_step

    def set_total_steps(self, total_steps):
        """Set the total number of steps."""
        self.total_steps = total_steps

    def set_current_iteration(self, current_iteration):
        """Set the iteration currently being processed."""
        self.current_iteration = current_iteration

    def set_total_iterations(self, total_iterations):
        """Set the total number of iterations."""
        self.total_iterations = total_iterations

    def set_current_status(self, current_status):
        """Set the status text."""
        self.current_status = current_status

    def set_is_processing(self, is_processing):
        """Set whether processing is under way."""
        self.is_processing = is_processing

    def set_current_status_has_steps(self, current_status_has_steps):
        """Set whether the current status has a step count."""
        self.current_status_has_steps = current_status_has_steps

    def set_has_error(self, has_error):
        """Set the error flag."""
        self.has_error = has_error

    def mark_complete(self):
        """Switch to the finished state: zero all counters, stop processing."""
        self.current_status = "Processing Complete"
        self.current_step = self.total_steps = 0
        self.current_iteration = self.total_iterations = 0
        self.is_processing = False

    def to_formatted_dict(
        self,
    ):
        """Return the state as a dict with the camelCase keys the client reads."""
        formatted = {}
        formatted["currentStep"] = self.current_step
        formatted["totalSteps"] = self.total_steps
        formatted["currentIteration"] = self.current_iteration
        formatted["totalIterations"] = self.total_iterations
        formatted["currentStatus"] = self.current_status
        formatted["isProcessing"] = self.is_processing
        formatted["currentStatusHasSteps"] = self.current_status_has_steps
        formatted["hasError"] = self.has_error
        return formatted

    def _calculate_real_steps(self, steps, strength, has_init_image):
        """Return ``floor(strength * steps)`` with an init image, else ``steps``."""
        if not has_init_image:
            return steps
        return math.floor(strength * steps)
|
|
|
|
|
|
class CanceledException(Exception):
    """Raised by the generation callbacks when a cancel request has been set."""
|
|
|
|
|
|
"""
|
|
Crops an image to a bounding box.
|
|
"""
|
|
|
|
|
|
def copy_image_from_bounding_box(image, x, y, width, height):
    """Crop ``image`` to the box whose top-left corner is ``(x, y)``.

    ``image`` is used as a context manager, so it is exited before the
    cropped result is returned.
    """
    box = (x, y, x + width, y + height)
    with image as source:
        cropped = source.crop(box)
    return cropped
|
|
|
|
|
|
"""
|
|
Pastes an image onto another with a bounding box.
|
|
"""
|
|
|
|
|
|
def paste_image_into_bounding_box(recipient_image, donor_image, x, y, width, height):
    """Paste ``donor_image`` onto ``recipient_image`` inside the given box.

    The box has its top-left corner at ``(x, y)``. ``recipient_image`` is
    used as a context manager while pasting and is the object returned.
    """
    box = (x, y, x + width, y + height)
    with recipient_image as target:
        target.paste(donor_image, box)
    return recipient_image
|