InvokeAI/backend/invoke_ai_web_server.py

1251 lines
45 KiB
Python
Raw Normal View History

2022-09-26 20:15:32 +00:00
import eventlet
import glob
import os
import shutil
import mimetypes
import traceback
import math
2022-10-27 04:24:00 +00:00
import io
import base64
2022-09-26 20:15:32 +00:00
from flask import Flask, redirect, send_from_directory
from flask_socketio import SocketIO
from PIL import Image
from uuid import uuid4
from threading import Event
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
2022-10-24 09:16:52 +00:00
from ldm.invoke.prompt_parser import split_weighted_subprompts
2022-09-26 20:15:32 +00:00
from backend.modules.parameters import parameters_to_command
2022-09-26 20:15:32 +00:00
# Loading Arguments
# Runs once, at import time of this module. `args` is then read as a module
# global by InvokeAIWebServer (host, port, outdir, gui, web_develop,
# web_verbose, model); `opt.cors` is read from the parser object itself.
# NOTE(review): Args.parse_args presumably reads sys.argv — confirm before
# importing this module from anything other than the CLI entry point.
opt = Args()
args = opt.parse_args()
class InvokeAIWebServer:
    """Flask + Socket.IO front end for the InvokeAI generator.

    Serves the static web UI and the ``/outputs`` folder, and translates
    Socket.IO events from the browser into calls on the generator
    (``generate``) and the postprocessors (``gfpgan``, ``codeformer``,
    ``esrgan``). Progress, results and errors are pushed back to the client
    as Socket.IO events. Configuration is read from the module-level
    ``args`` / ``opt`` objects.
    """

    def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
        self.host = args.host
        self.port = args.port

        self.generate = generate
        self.gfpgan = gfpgan
        self.codeformer = codeformer
        self.esrgan = esrgan

        # Set by the "cancel" socket event; polled by the generation callbacks.
        self.canceled = Event()

    def run(self):
        """Prepare the output folders, then build the Flask app and start serving."""
        self.setup_app()
        self.setup_flask()

    def setup_flask(self):
        """Create the Flask app and Socket.IO server, register routes and
        socket listeners, then serve either inside a flaskwebgui window
        (``--gui``) or as a plain Socket.IO server."""
        # Fix missing mimetypes on Windows
        mimetypes.add_type("application/javascript", ".js")
        mimetypes.add_type("text/css", ".css")

        # Socket IO
        verbose = bool(args.web_verbose)
        socketio_args = {
            "logger": verbose,
            "engineio_logger": verbose,
            "max_http_buffer_size": 10000000,
            "ping_interval": (50, 50),
            "ping_timeout": 60,
        }
        if opt.cors:
            socketio_args["cors_allowed_origins"] = opt.cors

        self.app = Flask(
            __name__, static_url_path="", static_folder="../frontend/dist/"
        )
        self.socketio = SocketIO(self.app, **socketio_args)

        # Keep Server Alive Route
        @self.app.route("/flaskwebgui-keep-server-alive")
        def keep_alive():
            return {"message": "Server Running"}

        # Outputs Route
        self.app.config["OUTPUTS_FOLDER"] = os.path.abspath(args.outdir)

        @self.app.route("/outputs/<path:file_path>")
        def outputs(file_path):
            return send_from_directory(self.app.config["OUTPUTS_FOLDER"], file_path)

        # Base Route
        @self.app.route("/")
        def serve():
            if args.web_develop:
                # Development mode: hand the UI off to the frontend dev server.
                return redirect("http://127.0.0.1:5173")
            return send_from_directory(self.app.static_folder, "index.html")

        self.load_socketio_listeners(self.socketio)

        if args.gui:
            print(">> Launching Invoke AI GUI")
            # In development mode the server must outlive the GUI window.
            close_server_on_exit = not args.web_develop
            try:
                from flaskwebgui import FlaskUI

                FlaskUI(
                    app=self.app,
                    socketio=self.socketio,
                    start_server="flask-socketio",
                    host=self.host,
                    port=self.port,
                    width=1600,
                    height=1000,
                    idle_interval=10,
                    close_server_on_exit=close_server_on_exit,
                ).run()
            except KeyboardInterrupt:
                import sys

                sys.exit(0)
        else:
            print(">> Started Invoke AI Web Server!")
            if self.host == "0.0.0.0":
                print(
                    f"Point your browser at http://localhost:{self.port} or use the host's DNS name or IP address."
                )
            else:
                print(
                    ">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
                )
                print(f">> Point your browser at http://{self.host}:{self.port}")

            self.socketio.run(app=self.app, host=self.host, port=self.port)

    def setup_app(self):
        """Set the client-facing URL prefixes and the on-disk folders for
        results, intermediates, init images and masks, creating the folders."""
        self.result_url = "outputs/"
        self.init_image_url = "outputs/init-images/"
        self.mask_image_url = "outputs/mask-images/"
        self.intermediate_url = "outputs/intermediates/"

        # location for "finished" images
        self.result_path = args.outdir
        # temporary path for intermediates
        self.intermediate_path = os.path.join(self.result_path, "intermediates/")
        # path for user-uploaded init images and masks
        self.init_image_path = os.path.join(self.result_path, "init-images/")
        self.mask_image_path = os.path.join(self.result_path, "mask-images/")
        # txt log
        self.log_path = os.path.join(self.result_path, "invoke_log.txt")

        # make all output paths
        for path in (
            self.result_path,
            self.intermediate_path,
            self.init_image_path,
            self.mask_image_path,
        ):
            os.makedirs(path, exist_ok=True)

    def load_socketio_listeners(self, socketio):
        """Register every Socket.IO event handler on ``socketio``.

        Handlers do not let exceptions escape: failures are sent to the
        client as an ``"error"`` event and the traceback is printed.
        """

        @socketio.on("requestSystemConfig")
        def handle_request_capabilities():
            print(">> System config requested")
            config = self.get_system_config()
            socketio.emit("systemConfig", config)

        @socketio.on("requestModelChange")
        def handle_set_model(model_name: str):
            try:
                print(f">> Model change requested: {model_name}")
                model = self.generate.set_model(model_name)
                model_list = self.generate.model_cache.list_models()
                event = "modelChangeFailed" if model is None else "modelChanged"
                socketio.emit(
                    event, {"model_name": model_name, "model_list": model_list}
                )
            except Exception as e:
                self._report_exception(e)

        @socketio.on("requestLatestImages")
        def handle_request_latest_images(category, latest_mtime):
            """Send every gallery image newer than ``latest_mtime`` (no paging)."""
            try:
                image_array = [
                    self._gallery_entry(path, mtime, category)
                    for path, mtime in self._list_gallery_images(category)
                    if mtime > latest_mtime
                ]
                socketio.emit(
                    "galleryImages",
                    {"images": image_array, "category": category},
                )
            except Exception as e:
                self._report_exception(e)

        @socketio.on("requestImages")
        def handle_request_images(category, earliest_mtime=None):
            """Send one page of gallery images older than ``earliest_mtime``
            (or the newest page when it is not given)."""
            try:
                page_size = 50
                entries = self._list_gallery_images(category)
                if earliest_mtime:
                    entries = [
                        (path, mtime)
                        for path, mtime in entries
                        if mtime < earliest_mtime
                    ]

                # Only claim more images exist when something lies beyond this page.
                are_more_images_available = len(entries) > page_size
                image_array = [
                    self._gallery_entry(path, mtime, category)
                    for path, mtime in entries[:page_size]
                ]
                socketio.emit(
                    "galleryImages",
                    {
                        "images": image_array,
                        "areMoreImagesAvailable": are_more_images_available,
                        "category": category,
                    },
                )
            except Exception as e:
                self._report_exception(e)

        @socketio.on("generateImage")
        def handle_generate_image_event(
            generation_parameters, esrgan_parameters, facetool_parameters
        ):
            try:
                # truncate long init_mask base64 if needed
                printable_parameters = generation_parameters
                if "init_mask" in generation_parameters:
                    printable_parameters = {
                        **generation_parameters,
                        "init_mask": generation_parameters["init_mask"][:20] + "...",
                    }
                print(
                    f">> Image generation requested: {printable_parameters}\nESRGAN parameters: {esrgan_parameters}\nFacetool parameters: {facetool_parameters}"
                )
                self.generate_images(
                    generation_parameters,
                    esrgan_parameters,
                    facetool_parameters,
                )
            except Exception as e:
                self._report_exception(e)

        @socketio.on("runPostprocessing")
        def handle_run_postprocessing(original_image, postprocessing_parameters):
            try:
                print(
                    f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
                )

                progress = Progress()
                socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                original_image_path = self.get_image_path_from_url(
                    original_image["url"]
                )
                image = Image.open(original_image_path)

                seed = (
                    original_image["metadata"]["seed"]
                    if "metadata" in original_image
                    and "seed" in original_image["metadata"]
                    else "unknown_seed"
                )

                postprocessing_type = postprocessing_parameters["type"]
                if postprocessing_type == "esrgan":
                    progress.set_current_status("Upscaling (ESRGAN)")
                elif postprocessing_type == "gfpgan":
                    progress.set_current_status("Restoring Faces (GFPGAN)")
                elif postprocessing_type == "codeformer":
                    progress.set_current_status("Restoring Faces (Codeformer)")

                socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                if postprocessing_type == "esrgan":
                    image = self.esrgan.process(
                        image=image,
                        upsampler_scale=postprocessing_parameters["upscale"][0],
                        strength=postprocessing_parameters["upscale"][1],
                        seed=seed,
                    )
                elif postprocessing_type == "gfpgan":
                    image = self.gfpgan.process(
                        image=image,
                        strength=postprocessing_parameters["facetool_strength"],
                        seed=seed,
                    )
                elif postprocessing_type == "codeformer":
                    image = self.codeformer.process(
                        image=image,
                        strength=postprocessing_parameters["facetool_strength"],
                        fidelity=postprocessing_parameters["codeformer_fidelity"],
                        seed=seed,
                        device=self._facetool_device(),
                    )
                else:
                    raise TypeError(
                        f"{postprocessing_type} is not a valid postprocessing type"
                    )

                progress.set_current_status("Saving Image")
                socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                postprocessing_parameters["seed"] = seed
                metadata = self.parameters_to_post_processed_image_metadata(
                    parameters=postprocessing_parameters,
                    original_image_path=original_image_path,
                )
                command = parameters_to_command(postprocessing_parameters)

                width, height = image.size
                path = self.save_result_image(
                    image,
                    command,
                    metadata,
                    self.result_path,
                    postprocessing=postprocessing_type,
                )

                self.write_log_message(
                    f'[Postprocessed] "{original_image_path}" > "{path}": {postprocessing_parameters}'
                )

                progress.mark_complete()
                socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                socketio.emit(
                    "postprocessingResult",
                    {
                        "url": self.get_url_from_image_path(path),
                        "mtime": os.path.getmtime(path),
                        "metadata": metadata,
                        "width": width,
                        "height": height,
                    },
                )
            except Exception as e:
                self._report_exception(e)

        @socketio.on("cancel")
        def handle_cancel():
            print(">> Cancel processing requested")
            self.canceled.set()

        # TODO: I think this needs a safety mechanism.
        @socketio.on("deleteImage")
        def handle_delete_image(url, uuid, category):
            try:
                print(f'>> Delete requested "{url}"')
                from send2trash import send2trash

                path = self.get_image_path_from_url(url)
                print(path)
                send2trash(path)
                socketio.emit(
                    "imageDeleted",
                    {"url": url, "uuid": uuid, "category": category},
                )
            except Exception as e:
                self._report_exception(e)

        # TODO: I think this needs a safety mechanism.
        @socketio.on("uploadImage")
        def handle_upload_image(image_bytes, name, destination):
            try:
                print(f'>> Image upload requested "{name}"')
                file_path = self.save_file_unique_uuid_name(
                    bytes=image_bytes, name=name, path=self.init_image_path
                )
                mtime = os.path.getmtime(file_path)
                # Context manager releases the file handle once the size is read.
                with Image.open(file_path) as uploaded:
                    width, height = uploaded.size
                print(file_path)
                socketio.emit(
                    "imageUploaded",
                    {
                        "url": self.get_url_from_image_path(file_path),
                        "mtime": mtime,
                        "width": width,
                        "height": height,
                        "category": "user",
                        "destination": destination,
                    },
                )
            except Exception as e:
                self._report_exception(e)

        # TODO: I think this needs a safety mechanism.
        @socketio.on("uploadMaskImage")
        def handle_upload_mask_image(image_bytes, name):
            try:
                print(f'>> Mask image upload requested "{name}"')
                file_path = self.save_file_unique_uuid_name(
                    bytes=image_bytes, name=name, path=self.mask_image_path
                )
                socketio.emit(
                    "maskImageUploaded",
                    {"url": self.get_url_from_image_path(file_path)},
                )
            except Exception as e:
                self._report_exception(e)

    # App Functions
    def _report_exception(self, e):
        """Send ``e`` to the client as an "error" event and print the traceback.

        Must be called from inside an ``except`` block so that
        ``traceback.print_exc()`` can see the exception being handled.
        """
        self.socketio.emit("error", {"message": str(e)})
        print("\n")
        traceback.print_exc()
        print("\n")

    def _facetool_device(self):
        """Device for CodeFormer: the generator's device, or "cpu" when that is "mps"."""
        return "cpu" if str(self.generate.device) == "mps" else self.generate.device

    def _list_gallery_images(self, category):
        """Return ``(path, mtime)`` pairs for a gallery category, newest first.

        ``category == "result"`` lists the results folder; any other value
        lists the init-images folder. A file that disappears between the glob
        and the stat (e.g. deleted concurrently) is skipped instead of
        failing the whole request.
        """
        base_path = (
            self.result_path if category == "result" else self.init_image_path
        )
        entries = []
        for path in glob.glob(os.path.join(base_path, "*.png")):
            try:
                entries.append((path, os.path.getmtime(path)))
            except FileNotFoundError:
                continue
        entries.sort(key=lambda entry: entry[1], reverse=True)
        return entries

    def _gallery_entry(self, path, mtime, category):
        """Build the client-side description of one gallery image."""
        metadata = retrieve_metadata(path)
        # Context manager releases the file handle once the size is read.
        with Image.open(path) as image:
            width, height = image.size
        return {
            "url": self.get_url_from_image_path(path),
            "mtime": mtime,
            "metadata": metadata["sd-metadata"],
            "width": width,
            "height": height,
            "category": category,
        }

    def get_system_config(self):
        """Describe the app and the active model; also used as the top level
        of generated-image metadata."""
        model_list = self.generate.model_cache.list_models()
        return {
            "model": "stable diffusion",
            "model_id": args.model,
            "model_hash": self.generate.model_hash,
            "app_id": APP_ID,
            "app_version": APP_VERSION,
            "model_list": model_list,
        }

    def generate_images(
        self, generation_parameters, esrgan_parameters, facetool_parameters
    ):
        """Run one generation request, streaming progress and results to the client.

        The init-image URL in ``generation_parameters`` is resolved to a file
        path. For inpainting requests (``init_mask`` present), the init image
        and the base64-encoded mask are cropped to ``bounding_box``. The
        parameters are then forwarded to ``self.generate.prompt2image``.

        Intermediate images are announced as "intermediateResult" events.
        Each final image optionally gets ESRGAN and facetool postprocessing,
        is pasted back for inpainting, then saved and announced as a
        "generationResult" event. Cancellation via ``self.canceled`` emits
        "processingCanceled"; other errors are reported as "error" events.
        """
        try:
            self.canceled.clear()

            step_index = 1
            prior_variations = generation_parameters.get("with_variations", [])

            # TODO: If a result image is used as an init image, and then deleted,
            # we will want to be able to use it as an init image in the future.
            # Need to handle this case.

            # We need to give absolute paths to the generator, stash the URLs for later
            init_img_url = None
            init_img_path = None
            if "init_img" in generation_parameters:
                init_img_url = generation_parameters["init_img"]
                init_img_path = self.get_image_path_from_url(init_img_url)
                generation_parameters["init_img"] = init_img_path

            if "init_mask" in generation_parameters:
                if init_img_path is None:
                    raise ValueError("init_mask was supplied without an init_img")
                bounding_box = generation_parameters["bounding_box"]
                # copy the region of the init image which we will inpaint
                generation_parameters["init_img"] = copy_image_from_bounding_box(
                    Image.open(init_img_path), **bounding_box
                )

                if generation_parameters["is_mask_empty"]:
                    generation_parameters["init_mask"] = None
                else:
                    # the mask arrives as base64-encoded image data
                    mask_image = Image.open(
                        io.BytesIO(
                            base64.decodebytes(
                                bytes(generation_parameters["init_mask"], "utf-8")
                            )
                        )
                    )
                    generation_parameters["init_mask"] = copy_image_from_bounding_box(
                        mask_image, **bounding_box
                    )

            progress = Progress(generation_parameters=generation_parameters)

            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            def image_progress(sample, step):
                nonlocal step_index
                if self.canceled.is_set():
                    raise CanceledException

                progress.set_current_step(step + 1)
                progress.set_current_status("Generating")
                progress.set_current_status_has_steps(True)

                if (
                    generation_parameters["progress_images"]
                    and step % 5 == 0
                    and step < generation_parameters["steps"] - 1
                ):
                    image = self.generate.sample_to_image(sample)
                    metadata = self.parameters_to_generated_image_metadata(
                        generation_parameters
                    )
                    command = parameters_to_command(generation_parameters)
                    width, height = image.size
                    path = self.save_result_image(
                        image,
                        command,
                        metadata,
                        self.intermediate_path,
                        step_index=step_index,
                        postprocessing=False,
                    )
                    step_index += 1
                    self.socketio.emit(
                        "intermediateResult",
                        {
                            "url": self.get_url_from_image_path(path),
                            "mtime": os.path.getmtime(path),
                            "metadata": metadata,
                            "width": width,
                            "height": height,
                        },
                    )

                if generation_parameters["progress_latents"]:
                    image = self.generate.sample_to_lowres_estimated_image(sample)
                    width, height = image.size
                    # Report 8x the preview's size (presumably the latent
                    # downsampling factor, i.e. the final output size).
                    width *= 8
                    height *= 8
                    buffered = io.BytesIO()
                    image.save(buffered, format="PNG")
                    img_base64 = "data:image/png;base64," + base64.b64encode(
                        buffered.getvalue()
                    ).decode("UTF-8")
                    self.socketio.emit(
                        "intermediateResult",
                        {
                            "url": img_base64,
                            "isBase64": True,
                            "mtime": 0,
                            "metadata": {},
                            "width": width,
                            "height": height,
                        },
                    )

                self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

            def image_done(image, seed, first_seed):
                nonlocal step_index
                if self.canceled.is_set():
                    raise CanceledException

                # Restart intermediate numbering for the next image. (This used
                # to be a plain local assignment without `nonlocal`, i.e. a no-op.)
                step_index = 1
                progress.set_current_status("Generation Complete")

                self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                # NOTE: an alias, not a copy — keys written below persist in
                # generation_parameters for later iterations' callbacks.
                all_parameters = generation_parameters
                postprocessing = False

                if (
                    "variation_amount" in all_parameters
                    and all_parameters["variation_amount"] > 0
                ):
                    first_seed = first_seed or seed
                    this_variation = [[seed, all_parameters["variation_amount"]]]
                    all_parameters["with_variations"] = (
                        prior_variations + this_variation
                    )
                    all_parameters["seed"] = first_seed
                elif "with_variations" in all_parameters:
                    all_parameters["seed"] = first_seed
                else:
                    all_parameters["seed"] = seed

                if self.canceled.is_set():
                    raise CanceledException

                if esrgan_parameters:
                    progress.set_current_status("Upscaling")
                    progress.set_current_status_has_steps(False)
                    self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                    eventlet.sleep(0)

                    image = self.esrgan.process(
                        image=image,
                        upsampler_scale=esrgan_parameters["level"],
                        strength=esrgan_parameters["strength"],
                        seed=seed,
                    )

                    postprocessing = True
                    all_parameters["upscale"] = [
                        esrgan_parameters["level"],
                        esrgan_parameters["strength"],
                    ]

                if self.canceled.is_set():
                    raise CanceledException

                if facetool_parameters:
                    facetool_type = facetool_parameters["type"]
                    if facetool_type == "gfpgan":
                        progress.set_current_status("Restoring Faces (GFPGAN)")
                    elif facetool_type == "codeformer":
                        progress.set_current_status("Restoring Faces (Codeformer)")

                    progress.set_current_status_has_steps(False)
                    self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                    eventlet.sleep(0)

                    if facetool_type == "gfpgan":
                        image = self.gfpgan.process(
                            image=image,
                            strength=facetool_parameters["strength"],
                            seed=seed,
                        )
                    elif facetool_type == "codeformer":
                        image = self.codeformer.process(
                            image=image,
                            strength=facetool_parameters["strength"],
                            fidelity=facetool_parameters["codeformer_fidelity"],
                            seed=seed,
                            device=self._facetool_device(),
                        )
                        all_parameters["codeformer_fidelity"] = facetool_parameters[
                            "codeformer_fidelity"
                        ]

                    postprocessing = True
                    all_parameters["facetool_strength"] = facetool_parameters[
                        "strength"
                    ]
                    all_parameters["facetool_type"] = facetool_type

                progress.set_current_status("Saving Image")
                self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                # paste the inpainting image back onto the original
                if "init_mask" in generation_parameters:
                    image = paste_image_into_bounding_box(
                        Image.open(init_img_path),
                        image,
                        **generation_parameters["bounding_box"],
                    )

                # restore the stashed URLS and discard the paths, we are about to send the result to client
                if "init_img" in all_parameters:
                    all_parameters["init_img"] = init_img_url
                if "init_mask" in all_parameters:
                    all_parameters["init_mask"] = ""  # TODO: store the mask in metadata

                metadata = self.parameters_to_generated_image_metadata(all_parameters)
                command = parameters_to_command(all_parameters)

                width, height = image.size
                path = self.save_result_image(
                    image,
                    command,
                    metadata,
                    self.result_path,
                    postprocessing=postprocessing,
                )

                print(f'>> Image generated: "{path}"')
                self.write_log_message(f'[Generated] "{path}": {command}')

                if progress.total_iterations > progress.current_iteration:
                    progress.set_current_step(1)
                    progress.set_current_status("Iteration complete")
                    progress.set_current_status_has_steps(False)
                else:
                    progress.mark_complete()

                self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                self.socketio.emit(
                    "generationResult",
                    {
                        "url": self.get_url_from_image_path(path),
                        "mtime": os.path.getmtime(path),
                        "metadata": metadata,
                        "width": width,
                        "height": height,
                    },
                )
                eventlet.sleep(0)

                progress.set_current_iteration(progress.current_iteration + 1)

            self.generate.prompt2image(
                **generation_parameters,
                step_callback=image_progress,
                image_callback=image_done,
            )

        except KeyboardInterrupt:
            self.socketio.emit("processingCanceled")
            raise
        except CanceledException:
            self.socketio.emit("processingCanceled")
        except Exception as e:
            print(e)
            self._report_exception(e)

    def parameters_to_generated_image_metadata(self, parameters):
        """Build the metadata dict for a generated image.

        It is the system config with an ``"image"`` entry holding the
        parameters named in RFC #266, plus postprocessing, sampler, weighted
        subprompts, variations and the txt2img/img2img type. Returns None after
        reporting the error to the client if anything fails.
        """
        try:
            # top-level metadata minus `image` or `images`
            metadata = self.get_system_config()
            # keep only the image keys mentioned in RFC #266
            rfc266_img_fields = {
                "type",
                "postprocessing",
                "sampler",
                "prompt",
                "seed",
                "variations",
                "steps",
                "cfg_scale",
                "threshold",
                "perlin",
                "step_number",
                "width",
                "height",
                "extra",
                "seamless",
                "hires_fix",
            }
            rfc_dict = {
                key: value
                for key, value in parameters.items()
                if key in rfc266_img_fields
            }

            # 'postprocessing' is either null or a list of postprocessing steps
            postprocessing = []
            if "facetool_strength" in parameters:
                facetool_parameters = {
                    "type": str(parameters["facetool_type"]),
                    "strength": float(parameters["facetool_strength"]),
                }
                if parameters["facetool_type"] == "codeformer":
                    facetool_parameters["fidelity"] = float(
                        parameters["codeformer_fidelity"]
                    )
                postprocessing.append(facetool_parameters)

            if "upscale" in parameters:
                postprocessing.append(
                    {
                        "type": "esrgan",
                        "scale": int(parameters["upscale"][0]),
                        "strength": float(parameters["upscale"][1]),
                    }
                )

            rfc_dict["postprocessing"] = postprocessing or None

            # semantic drift
            rfc_dict["sampler"] = parameters["sampler_name"]

            # display weighted subprompts (liable to change)
            subprompts = split_weighted_subprompts(
                parameters["prompt"], skip_normalize=True
            )
            rfc_dict["prompt"] = [{"prompt": x[0], "weight": x[1]} for x in subprompts]

            # 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
            rfc_dict["variations"] = [
                {"seed": x[0], "weight": x[1]}
                for x in parameters.get("with_variations", [])
            ]

            if "init_img" in parameters:
                rfc_dict["type"] = "img2img"
                rfc_dict["strength"] = parameters["strength"]
                rfc_dict["fit"] = parameters["fit"]  # TODO: Noncompliant
                rfc_dict["orig_hash"] = calculate_init_img_hash(
                    self.get_image_path_from_url(parameters["init_img"])
                )
                rfc_dict["init_image_path"] = parameters["init_img"]  # TODO: Noncompliant
                # TODO: record a mask hash/path once masks are stored as files.
            else:
                rfc_dict["type"] = "txt2img"

            metadata["image"] = rfc_dict
            return metadata

        except Exception as e:
            self._report_exception(e)

    def parameters_to_post_processed_image_metadata(
        self, parameters, original_image_path
    ):
        """Return the original image's ``sd-metadata`` with one postprocessing step appended.

        Raises (and reports) TypeError for an unknown ``parameters["type"]``;
        returns None after reporting any error to the client.
        """
        try:
            current_metadata = retrieve_metadata(original_image_path)["sd-metadata"]
            postprocessing_metadata = {}

            # if we don't have an original image metadata to reconstruct,
            # need to record the original image and its hash
            if "image" not in current_metadata:
                current_metadata["image"] = {}

                orig_hash = calculate_init_img_hash(
                    self.get_image_path_from_url(original_image_path)
                )
                # Store the path itself (it used to be wrapped in a 1-tuple by a stray comma).
                postprocessing_metadata["orig_path"] = original_image_path
                postprocessing_metadata["orig_hash"] = orig_hash

            postprocessing_type = parameters["type"]
            if postprocessing_type == "esrgan":
                postprocessing_metadata["type"] = "esrgan"
                postprocessing_metadata["scale"] = parameters["upscale"][0]
                postprocessing_metadata["strength"] = parameters["upscale"][1]
            elif postprocessing_type == "gfpgan":
                postprocessing_metadata["type"] = "gfpgan"
                postprocessing_metadata["strength"] = parameters["facetool_strength"]
            elif postprocessing_type == "codeformer":
                postprocessing_metadata["type"] = "codeformer"
                postprocessing_metadata["strength"] = parameters["facetool_strength"]
                postprocessing_metadata["fidelity"] = parameters["codeformer_fidelity"]
            else:
                raise TypeError(f"Invalid type: {parameters['type']}")

            image_metadata = current_metadata["image"]
            if isinstance(image_metadata.get("postprocessing"), list):
                image_metadata["postprocessing"].append(postprocessing_metadata)
            else:
                image_metadata["postprocessing"] = [postprocessing_metadata]

            return current_metadata

        except Exception as e:
            self._report_exception(e)

    def save_result_image(
        self,
        image,
        command,
        metadata,
        output_dir,
        step_index=None,
        postprocessing=False,
    ):
        """Save ``image`` as a PNG in ``output_dir`` with its command and metadata embedded.

        The file name is ``<prefix>.<uuid8>.<seed>[.<step_index>][.postprocessed].png``.
        Returns the absolute path, or None after reporting an error.
        """
        try:
            pngwriter = PngWriter(output_dir)

            number_prefix = pngwriter.unique_prefix()
            truncated_uuid = uuid4().hex[:8]

            seed = "unknown_seed"
            if "image" in metadata and "seed" in metadata["image"]:
                seed = metadata["image"]["seed"]

            filename = f"{number_prefix}.{truncated_uuid}.{seed}"
            if step_index:
                filename += f".{step_index}"
            if postprocessing:
                filename += ".postprocessed"
            filename += ".png"

            path = pngwriter.save_image_and_prompt_to_png(
                image=image,
                dream_prompt=command,
                metadata=metadata,
                name=filename,
            )
            return os.path.abspath(path)

        except Exception as e:
            self._report_exception(e)

    def make_unique_init_image_filename(self, name):
        """Return ``name`` with a full uuid4 hex string inserted before its extension."""
        try:
            stem, ext = os.path.splitext(name)
            return f"{stem}.{uuid4().hex}{ext}"
        except Exception as e:
            self._report_exception(e)

    def calculate_real_steps(self, steps, strength, has_init_image):
        """Number of denoising steps actually run: ``floor(strength * steps)``
        for init-image requests, otherwise ``steps``."""
        return math.floor(strength * steps) if has_init_image else steps

    def write_log_message(self, message):
        """Logs the filename and parameters used to generate or process that image to log file"""
        try:
            with open(self.log_path, "a", encoding="utf-8") as file:
                file.write(f"{message}\n")
        except Exception as e:
            self._report_exception(e)

    def get_image_path_from_url(self, url):
        """Given a url to an image used by the client, returns the absolute file path to that image.

        The folder is chosen by the URL's parent directory name
        ("init-images", "mask-images", "intermediates"; anything else maps
        to the results folder). Only the URL's basename is kept, so the
        result always lies directly inside one of the output folders.
        """
        try:
            folders = {
                "init-images": self.init_image_path,
                "mask-images": self.mask_image_path,
                "intermediates": self.intermediate_path,
            }
            parent = os.path.basename(os.path.dirname(url))
            folder = folders.get(parent, self.result_path)
            return os.path.abspath(os.path.join(folder, os.path.basename(url)))
        except Exception as e:
            self._report_exception(e)

    def get_url_from_image_path(self, path):
        """Given an absolute file path to an image, returns the URL that the client can use to load the image.

        Routing looks only at the file's parent directory name, so an output
        folder whose own path merely contains e.g. "intermediates" is not
        misclassified.
        """
        try:
            url_prefixes = {
                "init-images": self.init_image_url,
                "mask-images": self.mask_image_url,
                "intermediates": self.intermediate_url,
            }
            parent = os.path.basename(os.path.dirname(path))
            prefix = url_prefixes.get(parent, self.result_url)
            return os.path.join(prefix, os.path.basename(path))
        except Exception as e:
            self._report_exception(e)

    def save_file_unique_uuid_name(self, bytes, name, path):
        """Write ``bytes`` into folder ``path`` under a uniquified version of ``name``.

        ``name`` comes from the client, so only its final component is used;
        a crafted name such as "../../x.png" cannot escape ``path``. An
        8-hex-digit uuid4 slice is inserted before the extension. Returns the
        written file path, or None after reporting an error.
        """
        try:
            stem, ext = os.path.splitext(os.path.basename(name))
            truncated_uuid = uuid4().hex[:8]
            file_path = os.path.join(path, f"{stem}.{truncated_uuid}{ext}")

            os.makedirs(os.path.dirname(file_path), exist_ok=True)

            # Close (and flush) before returning: callers stat/open it right away.
            with open(file_path, "wb") as new_file:
                new_file.write(bytes)

            return file_path
        except Exception as e:
            self._report_exception(e)
class Progress:
    """Progress state for one request, sent to the client via ``to_formatted_dict``.

    With ``generation_parameters`` the step total is derived from ``steps``
    (scaled by ``strength`` when an init image is present) and the iteration
    total from ``iterations``; without them both totals are 1.
    """

    def __init__(self, generation_parameters=None):
        if generation_parameters:
            total_steps = self._calculate_real_steps(
                steps=generation_parameters["steps"],
                strength=generation_parameters.get("strength"),
                has_init_image="init_img" in generation_parameters,
            )
            total_iterations = generation_parameters["iterations"]
        else:
            total_steps = 1
            total_iterations = 1

        self.current_step = 1
        self.total_steps = total_steps
        self.current_iteration = 1
        self.total_iterations = total_iterations
        self.current_status = "Preparing"
        self.is_processing = True
        self.current_status_has_steps = False
        self.has_error = False

    def set_current_step(self, current_step):
        self.current_step = current_step

    def set_total_steps(self, total_steps):
        self.total_steps = total_steps

    def set_current_iteration(self, current_iteration):
        self.current_iteration = current_iteration

    def set_total_iterations(self, total_iterations):
        self.total_iterations = total_iterations

    def set_current_status(self, current_status):
        self.current_status = current_status

    def set_is_processing(self, is_processing):
        self.is_processing = is_processing

    def set_current_status_has_steps(self, current_status_has_steps):
        self.current_status_has_steps = current_status_has_steps

    def set_has_error(self, has_error):
        self.has_error = has_error

    def mark_complete(self):
        """Zero all counters and flag processing as finished."""
        self.current_status = "Processing Complete"
        self.current_step = self.total_steps = 0
        self.current_iteration = self.total_iterations = 0
        self.is_processing = False

    def to_formatted_dict(
        self,
    ):
        """Return the state with the camelCase keys the client expects."""
        formatted = {}
        formatted["currentStep"] = self.current_step
        formatted["totalSteps"] = self.total_steps
        formatted["currentIteration"] = self.current_iteration
        formatted["totalIterations"] = self.total_iterations
        formatted["currentStatus"] = self.current_status
        formatted["isProcessing"] = self.is_processing
        formatted["currentStatusHasSteps"] = self.current_status_has_steps
        formatted["hasError"] = self.has_error
        return formatted

    def _calculate_real_steps(self, steps, strength, has_init_image):
        if not has_init_image:
            return steps
        return math.floor(strength * steps)
class CanceledException(Exception):
    """Raised from the generation callbacks once the user has asked to cancel."""
"""
Crops an image to a bounding box.
"""
def copy_image_from_bounding_box(image, x, y, width, height):
with image as im:
bounds = (x, y, x + width, y + height)
im_cropped = im.crop(bounds)
return im_cropped
"""
Pastes an image onto another with a bounding box.
"""
def paste_image_into_bounding_box(recipient_image, donor_image, x, y, width, height):
2022-10-27 04:24:00 +00:00
with recipient_image as im:
bounds = (x, y, x + width, y + height)
im.paste(donor_image, bounds)
return recipient_image