From acc5199f85c2f921cdac8b285fa9906efa726ec4 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Tue, 27 Sep 2022 09:15:32 +1300 Subject: [PATCH] Integrate New WebUI with dream.py --- backend/invoke_ai_web_server.py | 813 ++++++++++++++++++++++++++++++++ backend/modules/parameters.py | 2 +- ldm/dream/args.py | 17 + scripts/dream.py | 38 +- 4 files changed, 843 insertions(+), 27 deletions(-) create mode 100644 backend/invoke_ai_web_server.py diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py new file mode 100644 index 0000000000..037baa40fa --- /dev/null +++ b/backend/invoke_ai_web_server.py @@ -0,0 +1,813 @@ +import eventlet +import glob +import os +import shutil + +from flask import Flask, redirect, send_from_directory +from flask_socketio import SocketIO +from PIL import Image +from uuid import uuid4 +from threading import Event + +from ldm.dream.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash +from ldm.dream.pngwriter import PngWriter, retrieve_metadata +from ldm.dream.conditioning import split_weighted_subprompts + +from backend.modules.parameters import parameters_to_command + +# Loading Arguments +opt = Args() +args = opt.parse_args() + + +class InvokeAIWebServer: + def __init__(self, generate, gfpgan, codeformer, esrgan) -> None: + self.host = args.host + self.port = args.port + + self.generate = generate + self.gfpgan = gfpgan + self.codeformer = codeformer + self.esrgan = esrgan + + self.canceled = Event() + + def run(self): + self.setup_app() + self.setup_flask() + + def setup_flask(self): + # Socket IO + logger = True if args.web_verbose else False + engineio_logger = True if args.web_verbose else False + max_http_buffer_size = 10000000 + + # CORS Allowed Setup + cors_allowed_origins = ['http://127.0.0.1:5173', 'http://localhost:5173'] + additional_allowed_origins = ( + opt.cors if opt.cors else [] + ) # additional CORS allowed origins + if self.host == '127.0.0.1': + cors_allowed_origins.extend( + [ + f'http://{self.host}:{self.port}', + f'http://localhost:{self.port}', + ] + ) + cors_allowed_origins = ( + cors_allowed_origins + additional_allowed_origins + ) + + self.app = Flask( + __name__, static_url_path='', static_folder='../frontend/dist/' + ) + + self.socketio = SocketIO( + self.app, + logger=logger, + engineio_logger=engineio_logger, + max_http_buffer_size=max_http_buffer_size, + cors_allowed_origins=cors_allowed_origins, + ping_interval=(50, 50), + ping_timeout=60, + ) + + # Outputs Route + self.app.config['OUTPUTS_FOLDER'] = f'../{args.outdir}' + + @self.app.route('/outputs/') + def outputs(filename): + return send_from_directory( + self.app.config['OUTPUTS_FOLDER'], filename + ) + + # Base Route + @self.app.route('/') + def serve(): + if args.web_develop: + return redirect('http://127.0.0.1:5173') + else: + return send_from_directory( + self.app.static_folder, 'index.html' + ) + + self.load_socketio_listeners(self.socketio) + + print('>> Started Invoke AI Web Server!') + if self.host == '0.0.0.0': + print( + f"Point your browser at http://localhost:{self.port} or use the host's DNS name or IP address." + ) + else: + print( + '>> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.' + ) + print(f'>> Point your browser at http://{self.host}:{self.port}') + + self.socketio.run(app=self.app, host=self.host, port=self.port) + + def setup_app(self): + # location for "finished" images + self.result_path = args.outdir + # temporary path for intermediates + self.intermediate_path = os.path.join( + self.result_path, 'intermediates/' + ) + # path for user-uploaded init images and masks + self.init_image_path = os.path.join(self.result_path, 'init-images/') + self.mask_image_path = os.path.join(self.result_path, 'mask-images/') + # txt log + self.log_path = os.path.join(self.result_path, 'dream_log.txt') + # make all output paths + [ + os.makedirs(path, exist_ok=True) + for path in [ + self.result_path, + self.intermediate_path, + self.init_image_path, + self.mask_image_path, + ] + ] + + def load_socketio_listeners(self, socketio): + @socketio.on('requestSystemConfig') + def handle_request_capabilities(): + print(f'>> System config requested') + config = self.get_system_config() + socketio.emit('systemConfig', config) + + @socketio.on('requestImages') + def handle_request_images(page=1, offset=0, last_mtime=None): + chunk_size = 50 + + if last_mtime: + print(f'>> Latest images requested') + else: + print( + f'>> Page {page} of images requested (page size {chunk_size} offset {offset})' + ) + + paths = glob.glob(os.path.join(self.result_path, '*.png')) + sorted_paths = sorted( + paths, key=lambda x: os.path.getmtime(x), reverse=True + ) + + if last_mtime: + image_paths = filter( + lambda x: os.path.getmtime(x) > last_mtime, sorted_paths + ) + else: + + image_paths = sorted_paths[ + slice( + chunk_size * (page - 1) + offset, + chunk_size * page + offset, + ) + ] + page = page + 1 + + image_array = [] + + for path in image_paths: + metadata = retrieve_metadata(path) + image_array.append( + { + 'url': path, + 'mtime': os.path.getmtime(path), + 'metadata': metadata['sd-metadata'], + } + ) + + socketio.emit( + 'galleryImages', + { + 'images': image_array, + 'nextPage': page, + 'offset': offset, + 'onlyNewImages': True if last_mtime else False, + }, + ) + + @socketio.on('generateImage') + def handle_generate_image_event( + generation_parameters, esrgan_parameters, gfpgan_parameters + ): + print( + f'>> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}' + ) + self.generate_images( + generation_parameters, esrgan_parameters, gfpgan_parameters + ) + + @socketio.on('runESRGAN') + def handle_run_esrgan_event(original_image, esrgan_parameters): + print( + f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}' + ) + progress = { + 'currentStep': 1, + 'totalSteps': 1, + 'currentIteration': 1, + 'totalIterations': 1, + 'currentStatus': 'Preparing', + 'isProcessing': True, + 'currentStatusHasSteps': False, + } + + socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + image = Image.open(original_image['url']) + + seed = ( + original_image['metadata']['seed'] + if 'seed' in original_image['metadata'] + else 'unknown_seed' + ) + + progress['currentStatus'] = 'Upscaling' + socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + image = self.esrgan.process( + image=image, + upsampler_scale=esrgan_parameters['upscale'][0], + strength=esrgan_parameters['upscale'][1], + seed=seed, + ) + + progress['currentStatus'] = 'Saving image' + socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + esrgan_parameters['seed'] = seed + metadata = self.parameters_to_post_processed_image_metadata( + parameters=esrgan_parameters, + original_image_path=original_image['url'], + type='esrgan', + ) + command = parameters_to_command(esrgan_parameters) + + path = self.save_image( + image, + command, + metadata, + self.result_path, + postprocessing='esrgan', + ) + + self.write_log_message( + f'[Upscaled] "{original_image["url"]}" > "{path}": {command}' + ) + + progress['currentStatus'] = 'Finished' + progress['currentStep'] = 0 + progress['totalSteps'] = 0 + progress['currentIteration'] = 0 + progress['totalIterations'] = 0 + progress['isProcessing'] = False + socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + socketio.emit( + 'esrganResult', + { + 'url': os.path.relpath(path), + 'mtime': os.path.getmtime(path), + 'metadata': metadata, + }, + ) + + @socketio.on('runGFPGAN') + def handle_run_gfpgan_event(original_image, gfpgan_parameters): + print( + f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}' + ) + progress = { + 'currentStep': 1, + 'totalSteps': 1, + 'currentIteration': 1, + 'totalIterations': 1, + 'currentStatus': 'Preparing', + 'isProcessing': True, + 'currentStatusHasSteps': False, + } + + socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + image = Image.open(original_image['url']) + + seed = ( + original_image['metadata']['seed'] + if 'seed' in original_image['metadata'] + else 'unknown_seed' + ) + + progress['currentStatus'] = 'Fixing faces' + socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + image = self.gfpgan.process( + image=image, + strength=gfpgan_parameters['gfpgan_strength'], + seed=seed, + ) + + progress['currentStatus'] = 'Saving image' + socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + gfpgan_parameters['seed'] = seed + metadata = self.parameters_to_post_processed_image_metadata( + parameters=gfpgan_parameters, + original_image_path=original_image['url'], + type='gfpgan', + ) + command = parameters_to_command(gfpgan_parameters) + + path = self.save_image( + image, + command, + metadata, + self.result_path, + postprocessing='gfpgan', + ) + + self.write_log_message( + f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}' + ) + + progress['currentStatus'] = 'Finished' + progress['currentStep'] = 0 + progress['totalSteps'] = 0 + progress['currentIteration'] = 0 + progress['totalIterations'] = 0 + progress['isProcessing'] = False + socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + socketio.emit( + 'gfpganResult', + { + 'url': os.path.relpath(path), + 'mtime': os.path.getmtime(path), + 'metadata': metadata, + }, + ) + + @socketio.on('cancel') + def handle_cancel(): + print(f'>> Cancel processing requested') + self.canceled.set() + socketio.emit('processingCanceled') + + # TODO: I think this needs a safety mechanism. + @socketio.on('deleteImage') + def handle_delete_image(path, uuid): + print(f'>> Delete requested "{path}"') + from send2trash import send2trash + + send2trash(path) + socketio.emit('imageDeleted', {'url': path, 'uuid': uuid}) + + # TODO: I think this needs a safety mechanism. + @socketio.on('uploadInitialImage') + def handle_upload_initial_image(bytes, name): + print(f'>> Init image upload requested "{name}"') + uuid = uuid4().hex + split = os.path.splitext(name) + name = f'{split[0]}.{uuid}{split[1]}' + file_path = os.path.join(self.init_image_path, name) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + newFile = open(file_path, 'wb') + newFile.write(bytes) + socketio.emit( + 'initialImageUploaded', {'url': file_path, 'uuid': ''} + ) + + # TODO: I think this needs a safety mechanism. + @socketio.on('uploadMaskImage') + def handle_upload_mask_image(bytes, name): + print(f'>> Mask image upload requested "{name}"') + uuid = uuid4().hex + split = os.path.splitext(name) + name = f'{split[0]}.{uuid}{split[1]}' + file_path = os.path.join(self.mask_image_path, name) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + newFile = open(file_path, 'wb') + newFile.write(bytes) + socketio.emit('maskImageUploaded', {'url': file_path, 'uuid': ''}) + + # App Functions + def get_system_config(self): + return { + 'model': 'stable diffusion', + 'model_id': args.model, + 'model_hash': self.generate.model_hash, + 'app_id': APP_ID, + 'app_version': APP_VERSION, + } + + def generate_images( + self, generation_parameters, esrgan_parameters, gfpgan_parameters + ): + self.canceled.clear() + + step_index = 1 + prior_variations = ( + generation_parameters['with_variations'] + if 'with_variations' in generation_parameters + else [] + ) + """ + If a result image is used as an init image, and then deleted, we will want to be + able to use it as an init image in the future. Need to copy it. + + If the init/mask image doesn't exist in the init_image_path/mask_image_path, + make a unique filename for it and copy it there. + """ + if 'init_img' in generation_parameters: + filename = os.path.basename(generation_parameters['init_img']) + if not os.path.exists( + os.path.join(self.init_image_path, filename) + ): + unique_filename = self.make_unique_init_image_filename( + filename + ) + new_path = os.path.join(self.init_image_path, unique_filename) + shutil.copy(generation_parameters['init_img'], new_path) + generation_parameters['init_img'] = new_path + if 'init_mask' in generation_parameters: + filename = os.path.basename(generation_parameters['init_mask']) + if not os.path.exists( + os.path.join(self.mask_image_path, filename) + ): + unique_filename = self.make_unique_init_image_filename( + filename + ) + new_path = os.path.join( + self.init_image_path, unique_filename + ) + shutil.copy(generation_parameters['init_img'], new_path) + generation_parameters['init_mask'] = new_path + + totalSteps = self.calculate_real_steps( + steps=generation_parameters['steps'], + strength=generation_parameters['strength'] + if 'strength' in generation_parameters + else None, + has_init_image='init_img' in generation_parameters, + ) + + progress = { + 'currentStep': 1, + 'totalSteps': totalSteps, + 'currentIteration': 1, + 'totalIterations': generation_parameters['iterations'], + 'currentStatus': 'Preparing', + 'isProcessing': True, + 'currentStatusHasSteps': False, + } + + self.socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + def image_progress(sample, step): + if self.canceled.is_set(): + raise CanceledException + + nonlocal step_index + nonlocal generation_parameters + nonlocal progress + + progress['currentStep'] = step + 1 + progress['currentStatus'] = 'Generating' + progress['currentStatusHasSteps'] = True + + if ( + generation_parameters['progress_images'] + and step % 5 == 0 + and step < generation_parameters['steps'] - 1 + ): + image = self.generate.sample_to_image(sample) + metadata = self.parameters_to_generated_image_metadata(generation_parameters) + command = parameters_to_command(generation_parameters) + path = self.save_image(image, command, metadata, self.intermediate_path, step_index=step_index, postprocessing=False) + + step_index += 1 + self.socketio.emit( + 'intermediateResult', + { + 'url': os.path.relpath(path), + 'mtime': os.path.getmtime(path), + 'metadata': metadata, + }, + ) + self.socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + def image_done(image, seed, first_seed): + nonlocal generation_parameters + nonlocal esrgan_parameters + nonlocal gfpgan_parameters + nonlocal progress + + step_index = 1 + nonlocal prior_variations + + progress['currentStatus'] = 'Generation complete' + self.socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + all_parameters = generation_parameters + postprocessing = False + + if ( + 'variation_amount' in all_parameters + and all_parameters['variation_amount'] > 0 + ): + first_seed = first_seed or seed + this_variation = [[seed, all_parameters['variation_amount']]] + all_parameters['with_variations'] = ( + prior_variations + this_variation + ) + all_parameters['seed'] = first_seed + elif 'with_variations' in all_parameters: + all_parameters['seed'] = first_seed + else: + all_parameters['seed'] = seed + + if esrgan_parameters: + progress['currentStatus'] = 'Upscaling' + progress['currentStatusHasSteps'] = False + self.socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + image = self.esrgan.process( + image=image, + upsampler_scale=esrgan_parameters['level'], + strength=esrgan_parameters['strength'], + seed=seed, + ) + + postprocessing = True + all_parameters['upscale'] = [ + esrgan_parameters['level'], + esrgan_parameters['strength'], + ] + + if gfpgan_parameters: + progress['currentStatus'] = 'Fixing faces' + progress['currentStatusHasSteps'] = False + self.socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + image = self.gfpgan.process( + image=image, + strength=gfpgan_parameters['strength'], + seed=seed, + ) + postprocessing = True + all_parameters['gfpgan_strength'] = gfpgan_parameters[ + 'strength' + ] + + progress['currentStatus'] = 'Saving image' + self.socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + metadata = self.parameters_to_generated_image_metadata( + all_parameters + ) + command = parameters_to_command(all_parameters) + + path = self.save_image( + image, + command, + metadata, + self.result_path, + postprocessing=postprocessing, + ) + + print(f'>> Image generated: "{path}"') + self.write_log_message(f'[Generated] "{path}": {command}') + + if progress['totalIterations'] > progress['currentIteration']: + progress['currentStep'] = 1 + progress['currentIteration'] += 1 + progress['currentStatus'] = 'Iteration finished' + progress['currentStatusHasSteps'] = False + else: + progress['currentStep'] = 0 + progress['totalSteps'] = 0 + progress['currentIteration'] = 0 + progress['totalIterations'] = 0 + progress['currentStatus'] = 'Finished' + progress['isProcessing'] = False + + self.socketio.emit('progressUpdate', progress) + eventlet.sleep(0) + + self.socketio.emit( + 'generationResult', + { + 'url': os.path.relpath(path), + 'mtime': os.path.getmtime(path), + 'metadata': metadata, + }, + ) + eventlet.sleep(0) + + try: + self.generate.prompt2image( + **generation_parameters, + step_callback=image_progress, + image_callback=image_done, + ) + + except KeyboardInterrupt: + raise + except CanceledException: + pass + except Exception as e: + self.socketio.emit('error', {'message': (str(e))}) + print('\n') + import traceback + + traceback.print_exc() + print('\n') + + def parameters_to_generated_image_metadata(self, parameters): + # top-level metadata minus `image` or `images` + metadata = self.get_system_config() + # remove any image keys not mentioned in RFC #266 + rfc266_img_fields = [ + 'type', + 'postprocessing', + 'sampler', + 'prompt', + 'seed', + 'variations', + 'steps', + 'cfg_scale', + 'step_number', + 'width', + 'height', + 'extra', + 'seamless', + ] + + rfc_dict = {} + + for item in parameters.items(): + key, value = item + if key in rfc266_img_fields: + rfc_dict[key] = value + + postprocessing = [] + + # 'postprocessing' is either null or an + if 'gfpgan_strength' in parameters: + + postprocessing.append( + { + 'type': 'gfpgan', + 'strength': float(parameters['gfpgan_strength']), + } + ) + + if 'upscale' in parameters: + postprocessing.append( + { + 'type': 'esrgan', + 'scale': int(parameters['upscale'][0]), + 'strength': float(parameters['upscale'][1]), + } + ) + + rfc_dict['postprocessing'] = ( + postprocessing if len(postprocessing) > 0 else None + ) + + # semantic drift + rfc_dict['sampler'] = parameters['sampler_name'] + + # display weighted subprompts (liable to change) + subprompts = split_weighted_subprompts(parameters['prompt']) + subprompts = [{'prompt': x[0], 'weight': x[1]} for x in subprompts] + rfc_dict['prompt'] = subprompts + + # 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs + variations = [] + + if 'with_variations' in parameters: + variations = [ + {'seed': x[0], 'weight': x[1]} + for x in parameters['with_variations'] + ] + + rfc_dict['variations'] = variations + + if 'init_img' in parameters: + rfc_dict['type'] = 'img2img' + rfc_dict['strength'] = parameters['strength'] + rfc_dict['fit'] = parameters['fit'] # TODO: Noncompliant + rfc_dict['orig_hash'] = calculate_init_img_hash( + parameters['init_img'] + ) + rfc_dict['init_image_path'] = parameters[ + 'init_img' + ] # TODO: Noncompliant + rfc_dict[ + 'sampler' + ] = 'ddim' # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS + if 'init_mask' in parameters: + rfc_dict['mask_hash'] = calculate_init_img_hash( + parameters['init_mask'] + ) # TODO: Noncompliant + rfc_dict['mask_image_path'] = parameters[ + 'init_mask' + ] # TODO: Noncompliant + else: + rfc_dict['type'] = 'txt2img' + + metadata['image'] = rfc_dict + + return metadata + + def parameters_to_post_processed_image_metadata( + self, parameters, original_image_path, type + ): + # top-level metadata minus `image` or `images` + metadata = self.get_system_config() + + orig_hash = calculate_init_img_hash(original_image_path) + + image = {'orig_path': original_image_path, 'orig_hash': orig_hash} + + if type == 'esrgan': + image['type'] = 'esrgan' + image['scale'] = parameters['upscale'][0] + image['strength'] = parameters['upscale'][1] + elif type == 'gfpgan': + image['type'] = 'gfpgan' + image['strength'] = parameters['gfpgan_strength'] + else: + raise TypeError(f'Invalid type: {type}') + + metadata['image'] = image + return metadata + + def save_image( + self, + image, + command, + metadata, + output_dir, + step_index=None, + postprocessing=False, + ): + pngwriter = PngWriter(output_dir) + prefix = pngwriter.unique_prefix() + + seed = 'unknown_seed' + + if 'image' in metadata: + if 'seed' in metadata['image']: + seed = metadata['image']['seed'] + + filename = f'{prefix}.{seed}' + + if step_index: + filename += f'.{step_index}' + if postprocessing: + filename += f'.postprocessed' + + filename += '.png' + + path = pngwriter.save_image_and_prompt_to_png( + image=image, dream_prompt=command, metadata=metadata, name=filename + ) + + return path + + def make_unique_init_image_filename(self, name): + uuid = uuid4().hex + split = os.path.splitext(name) + name = f'{split[0]}.{uuid}{split[1]}' + return name + + def calculate_real_steps(self, steps, strength, has_init_image): + import math + return math.floor(strength * steps) if has_init_image else steps + + def write_log_message(self, message): + """Logs the filename and parameters used to generate or process that image to log file""" + message = f'{message}\n' + with open(self.log_path, 'a', encoding='utf-8') as file: + file.writelines(message) + + +class CanceledException(Exception): + pass diff --git a/backend/modules/parameters.py b/backend/modules/parameters.py index ec0cfe8272..d15167e792 100644 --- a/backend/modules/parameters.py +++ b/backend/modules/parameters.py @@ -1,4 +1,4 @@ -from modules.parse_seed_weights import parse_seed_weights +from backend.modules.parse_seed_weights import parse_seed_weights import argparse SAMPLER_CHOICES = [ diff --git a/ldm/dream/args.py b/ldm/dream/args.py index 1822e5d112..7e0d44b352 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -422,6 +422,23 @@ class Args(object): action='store_true', help='Start in web server mode.', ) + web_server_group.add_argument( + '--web_develop', + dest='web_develop', + action='store_true', + help='Start in web server development mode.', + ) + web_server_group.add_argument( + "--web_verbose", + action="store_true", + help="Enables verbose logging", + ) + web_server_group.add_argument( + "--cors", + nargs="*", + type=str, + help="Additional allowed origins, comma-separated", + ) web_server_group.add_argument( '--host', type=str, diff --git a/scripts/dream.py b/scripts/dream.py index cac8c2aee4..c9eb6a0497 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -12,11 +12,12 @@ sys.path.append('.') # corrects a weird problem on Macs import ldm.dream.readline from ldm.dream.args import Args, metadata_dumps, metadata_from_png from ldm.dream.pngwriter import PngWriter -from ldm.dream.server import DreamServer, ThreadingDreamServer from ldm.dream.image_util import make_grid from ldm.dream.log import write_log from omegaconf import OmegaConf +from backend.invoke_ai_web_server import InvokeAIWebServer + # Placeholder to be replaced with proper class that tracks the # outputs and associates with the prompt that generated them. # Just want to get the formatting look right for now. @@ -111,16 +112,16 @@ def main(): #set additional option gen.free_gpu_mem = opt.free_gpu_mem + # web server loops forever + if opt.web: + invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan) + sys.exit(0) + if not infile: print( "\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)" ) - # web server loops forever - if opt.web: - dream_server_loop(gen, opt.host, opt.port, opt.outdir, gfpgan) - sys.exit(0) - main_loop(gen, opt, infile) # TODO: main_loop() has gotten busy. Needs to be refactored. @@ -414,35 +415,20 @@ def get_next_command(infile=None) -> str: # command string print(f'#{command}') return command -def dream_server_loop(gen, host, port, outdir, gfpgan): +def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan): print('\n* --web was specified, starting web server...') # Change working directory to the stable-diffusion directory os.chdir( os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) ) - - # Start server - DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for - DreamServer.outdir = outdir - DreamServer.gfpgan_model_exists = False - if gfpgan is not None: - DreamServer.gfpgan_model_exists = gfpgan.gfpgan_model_exists - - dream_server = ThreadingDreamServer((host, port)) - print(">> Started Stable Diffusion dream server!") - if host == '0.0.0.0': - print( - f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.") - else: - print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.") - print(f">> Point your browser at http://{host}:{port}") + + invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan) try: - dream_server.serve_forever() + invoke_ai_web_server.run() except KeyboardInterrupt: pass - - dream_server.server_close() + def write_log_message(results, log_path): """logs the name of the output image, prompt, and prompt args to the terminal and log file"""