mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
d1a2c4cd8c
* Implements rudimentary api * Fixes blocking in API * Adds UI to monorepo > src/frontend/ * Updates frontend/README * Reverts conda env name to `ldm` * Fixes environment yamls * CORS config for testing * Fixes LogViewer position * API WID * Adds actions to image viewer * Increases vite chunkSizeWarningLimit to 1500 * Implements init image * Implements state persistence in localStorage * Improve progress data handling * Final build * Fixes mimetypes error on windows * Adds error logging * Fixes bugged img2img strength component * Adds sourcemaps to dev build * Fixes missing key * Changes connection status indicator to text * Adds ability to serve other hosts than localhost * Adding Flask API server * Removes source maps from config * Fixes prop transfer * Add missing packages and add CORS support * Adding API doc * Remove defaults from openapi doc * Adds basic error handling for server config query * Mostly working socket.io implementation. * Fixes bug preventing mask upload * Fixes bug with sampler name not written to metadata * UI Overhaul, numerous fixes Co-authored-by: Kyle Schouviller <kyle0654@hotmail.com> Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
150 lines
4.7 KiB
Python
150 lines
4.7 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
"""Application module."""
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from flask import Flask
|
|
from flask_cors import CORS
|
|
from flask_socketio import SocketIO, join_room, leave_room
|
|
from omegaconf import OmegaConf
|
|
from dependency_injector.wiring import inject, Provide
|
|
from server import views
|
|
from server.containers import Container
|
|
from server.services import GeneratorService, SignalService
|
|
|
|
# `socketio` is obtained through dependency injection (instead of being built
# inside run_app) so that requesting it here is what initializes it.
@inject
def initialize_app(
    app: Flask,
    socketio: SocketIO = Provide[Container.socketio],
) -> SocketIO:
    """Attach the container-provided Socket.IO server to a Flask app.

    Args:
        app: Flask application that the Socket.IO server is bound to.
        socketio: SocketIO instance resolved from ``Container.socketio`` by
            ``@inject``; callers normally leave it unset.

    Returns:
        The same SocketIO instance after ``init_app(app)`` has been called.
    """
    socketio.init_app(app)  # bind the Socket.IO server to this Flask app
    return socketio
|
|
|
|
# Asking the container for the signal and generator services is what warms
# up their processing queues; nothing else needs to happen here.
# TODO: Initialize these a better way?
@inject
def initialize_generator(
    signal_service: SignalService = Provide[Container.signal_service],
    generator_service: GeneratorService = Provide[Container.generator_service],
):
    """Resolve the signal and generator services purely for their side effects.

    The body is intentionally empty: the useful work is ``@inject`` resolving
    the ``Provide[...]`` defaults when this function is called.

    Args:
        signal_service: Injected ``SignalService``; callers leave it unset.
        generator_service: Injected ``GeneratorService``; callers leave it unset.
    """
|
|
|
|
|
|
def _register_routes(app: Flask) -> None:
    """Register the web UI routes and the REST API routes on ``app``."""
    # Web routes: the UI entry page, its bundled assets and runtime config.
    app.add_url_rule('/', view_func=views.WebIndex.as_view('web_index', 'index.html'))
    app.add_url_rule('/index.css', view_func=views.WebIndex.as_view('web_index_css', 'index.css'))
    app.add_url_rule('/index.js', view_func=views.WebIndex.as_view('web_index_js', 'index.js'))
    app.add_url_rule('/config.js', view_func=views.WebConfig.as_view('web_config'))

    # API routes
    app.add_url_rule('/api/jobs', view_func=views.ApiJobs.as_view('api_jobs'))
    app.add_url_rule('/api/cancel', view_func=views.ApiCancel.as_view('api_cancel'))

    # TODO: Get storage root from config
    app.add_url_rule('/api/images/<string:dreamId>', view_func=views.ApiImages.as_view('api_images', '../'))
    app.add_url_rule(
        '/api/intermediates/<string:dreamId>/<string:step>',
        view_func=views.ApiIntermediates.as_view('api_intermediates', '../'),
    )


def _print_startup_banner(host, port) -> None:
    """Tell the user which URL to open for the given bind address."""
    print(">> Started Stable Diffusion api server!")
    if host == '0.0.0.0':
        print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
    else:
        print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
        print(f">> Point your browser at http://{host}:{port}.")


def run_app(config, host, port) -> Flask:
    """Build the Flask/Socket.IO API server and serve it on ``host:port``.

    Args:
        config: Mapping loaded into the dependency-injection container via
            ``container.config.from_dict``.
        host: Address passed to ``socketio.run`` to bind to.
        port: Port passed to ``socketio.run``.

    Returns:
        The Flask application, once ``socketio.run`` returns (i.e. after the
        server has stopped serving). Previously this function was annotated
        ``-> Flask`` but implicitly returned ``None``.
    """
    # The bundled web UI is served from static/dream_web at the URL root.
    # Passing static_folder to the constructor is equivalent to assigning
    # app.static_folder afterwards, but keeps the app's setup in one place.
    static_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../static/dream_web/'))
    app = Flask(__name__, static_url_path='', static_folder=static_dir)

    # Set up dependency injection container
    container = Container()
    container.config.from_dict(config)
    container.wire(modules=[__name__])
    app.container = container

    # Set up CORS (API routes only)
    CORS(app, resources={r'/api/*': {'origins': '*'}})

    _register_routes(app)

    # Initialize: bind Socket.IO to the app and warm up the injected services.
    socketio = initialize_app(app)
    initialize_generator()

    _print_startup_banner(host, port)

    # Run the app
    socketio.run(app, host=host, port=port)
    return app
|
|
|
|
|
|
def _reject_deprecated_options(opt):
    """Exit with status -1 if a removed command-line option was supplied."""
    if opt.laion400m:
        print('--laion400m flag has been deprecated. Please use --model laion400m instead.', file=sys.stderr)
        sys.exit(-1)
    if opt.weights != 'model':
        print(
            '--weights argument has been deprecated. Please configure ./configs/models.yaml, '
            'and call it using --model instead.',
            file=sys.stderr,
        )
        sys.exit(-1)


def _load_model_settings(opt):
    """Return ``(width, height, config, weights)`` for ``opt.model`` from ``opt.config``.

    Exits with status -1 if the config file cannot be read, the model entry is
    missing, or one of its fields is missing.
    """
    try:
        models = OmegaConf.load(opt.config)
        model = models[opt.model]
        return model.width, model.height, model.config, model.weights
    # OSError covers FileNotFoundError/IOError. AttributeError is needed because
    # OmegaConf reports a missing attribute-style field (e.g. `.width`) as
    # ConfigAttributeError, which is not a KeyError.
    except (OSError, KeyError, AttributeError) as e:
        print(f'{e}. Aborting.', file=sys.stderr)
        sys.exit(-1)


def main():
    """Initialize command-line parsers and the diffusion model, then run the API server.

    Exits with status -1 if a deprecated option is given or the model
    configuration cannot be loaded.
    """
    from scripts.dream import create_argv_parser

    arg_parser = create_argv_parser()
    opt = arg_parser.parse_args()

    _reject_deprecated_options(opt)
    width, height, config, weights = _load_model_settings(opt)

    print('* Initializing, be patient...\n')
    sys.path.append('.')
    # NOTE(review): presumably this resolves to the stdlib `logging` module as
    # re-exported by pytorch_lightning; it also imports pytorch_lightning before
    # its logger level is lowered below — confirm before replacing it.
    from pytorch_lightning import logging

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers
    transformers.logging.set_verbosity_error()

    app_config = {
        "model": {
            "width": width,
            "height": height,
            "sampler_name": opt.sampler_name,
            "weights": weights,
            "full_precision": opt.full_precision,
            "config": config,
            "grid": opt.grid,
            # Always falsy here: a truthy --laion400m exits in _reject_deprecated_options.
            "latent_diffusion_weights": opt.laion400m,
            "embedding_path": opt.embedding_path,
            "device_type": opt.device,
        }
    }

    # Make sure the output directory exists (race-free, no separate exists() check).
    # NOTE(review): a relative opt.outdir is created relative to the launch
    # directory, before the os.chdir below — confirm this is intended.
    os.makedirs(opt.outdir, exist_ok=True)

    # gets rid of annoying messages about random seed
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

    print('\n* starting api server...')
    # Change working directory to the stable-diffusion directory
    os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

    # Start server; Ctrl-C is treated as a normal shutdown.
    try:
        run_app(app_config, opt.host, opt.port)
    except KeyboardInterrupt:
        pass
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: start the API server when this module is run directly.
    main()
|