InvokeAI/server/containers.py
psychedelicious d1a2c4cd8c
React web UI with flask-socketio API (#429)
* Implements rudimentary api
* Fixes blocking in API
* Adds UI to monorepo > src/frontend/
* Updates frontend/README
* Reverts conda env name to `ldm`
* Fixes environment yamls
* CORS config for testing
* Fixes LogViewer position
* API WID
* Adds actions to image viewer
* Increases vite chunkSizeWarningLimit to 1500
* Implements init image
* Implements state persistence in localStorage
* Improve progress data handling
* Final build
* Fixes mimetypes error on windows
* Adds error logging
* Fixes bugged img2img strength component
* Adds sourcemaps to dev build
* Fixes missing key
* Changes connection status indicator to text
* Adds ability to serve other hosts than localhost
* Adding Flask API server
* Removes source maps from config
* Fixes prop transfer
* Add missing packages and add CORS support
* Adding API doc
* Remove defaults from openapi doc
* Adds basic error handling for server config query
* Mostly working socket.io implementation.
* Fixes bug preventing mask upload
* Fixes bug with sampler name not written to metadata
* UI Overhaul, numerous fixes

Co-authored-by: Kyle Schouviller <kyle0654@hotmail.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2022-09-16 13:18:15 -04:00


# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
"""Containers module."""
from dependency_injector import containers, providers
from flask_socketio import SocketIO
from ldm.generate import Generate
from server import services


class Container(containers.DeclarativeContainer):
    # Wire providers into modules of the `server` package.
    wiring_config = containers.WiringConfiguration(packages=['server'])

    config = providers.Configuration()

    # Created without a Flask app; Flask-SocketIO allows binding one later via init_app().
    socketio = providers.ThreadSafeSingleton(
        SocketIO,
        app=None
    )

    # One shared Generate instance, built from the `model` section of the config.
    model_singleton = providers.ThreadSafeSingleton(
        Generate,
        width=config.model.width,
        height=config.model.height,
        sampler_name=config.model.sampler_name,
        weights=config.model.weights,
        full_precision=config.model.full_precision,
        config=config.model.config,
        grid=config.model.grid,
        seamless=config.model.seamless,
        embedding_path=config.model.embedding_path,
        device_type=config.model.device_type
    )

    # TODO: get location from config
    image_storage_service = providers.ThreadSafeSingleton(
        services.ImageStorageService,
        './outputs/img-samples/'
    )

    # TODO: get location from config
    image_intermediates_storage_service = providers.ThreadSafeSingleton(
        services.ImageStorageService,
        './outputs/intermediates/'
    )

    signal_queue_service = providers.ThreadSafeSingleton(
        services.SignalQueueService
    )

    # Depends on the socketio and signal-queue singletons above.
    signal_service = providers.ThreadSafeSingleton(
        services.SignalService,
        socketio=socketio,
        queue=signal_queue_service
    )

    generation_queue_service = providers.ThreadSafeSingleton(
        services.JobQueueService
    )

    # TODO: get locations from config
    log_service = providers.ThreadSafeSingleton(
        services.LogService,
        './outputs/img-samples/',
        'dream_web_log.txt'
    )

    # Ties the model, job queue, storage, log and signal services together.
    generator_service = providers.ThreadSafeSingleton(
        services.GeneratorService,
        model=model_singleton,
        queue=generation_queue_service,
        imageStorage=image_storage_service,
        intermediateStorage=image_intermediates_storage_service,
        log=log_service,
        signal_service=signal_service
    )
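Every service in the container is registered as a `providers.ThreadSafeSingleton`, and some providers (for example `signal_service`) take other providers (`socketio`, `signal_queue_service`) as arguments. The following is a simplified, stdlib-only sketch of that behaviour, not dependency_injector's actual implementation. The factory runs at most once, guarded by a lock, and provider-valued arguments are resolved only when the dependent provider is first called. `ThreadSafeSingletonSketch`, `FakeQueue` and `FakeSignalService` are hypothetical names used purely for illustration.

```python
import threading
from typing import Any, Callable, Dict, List


class ThreadSafeSingletonSketch:
    """Hypothetical stand-in for a thread-safe singleton provider."""

    def __init__(self, factory: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        self._factory = factory
        self._args = args
        self._kwargs = kwargs
        self._lock = threading.Lock()
        self._instance: Any = None
        self._created = False

    @staticmethod
    def _resolve(value: Any) -> Any:
        # Arguments that are themselves providers are instantiated on demand.
        if isinstance(value, ThreadSafeSingletonSketch):
            return value()
        return value

    def __call__(self) -> Any:
        if self._created:  # fast path: instance already built
            return self._instance
        with self._lock:
            if not self._created:  # re-check: another thread may have built it first
                args = [self._resolve(a) for a in self._args]
                kwargs = {k: self._resolve(v) for k, v in self._kwargs.items()}
                self._instance = self._factory(*args, **kwargs)
                self._created = True
        return self._instance


# Hypothetical services mirroring the signal_queue_service -> signal_service wiring.
construction_counts: Dict[str, int] = {"queue": 0, "signal": 0}


class FakeQueue:
    def __init__(self) -> None:
        construction_counts["queue"] += 1


class FakeSignalService:
    def __init__(self, socketio: Any, queue: FakeQueue) -> None:
        construction_counts["signal"] += 1
        self.socketio = socketio
        self.queue = queue


signal_queue = ThreadSafeSingletonSketch(FakeQueue)
signal_service = ThreadSafeSingletonSketch(
    FakeSignalService, socketio="socketio-placeholder", queue=signal_queue
)

# Request the service from several threads at once.
results: List[FakeSignalService] = []
threads = [threading.Thread(target=lambda: results.append(signal_service())) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(construction_counts)  # {'queue': 1, 'signal': 1}
```

Each provider in the sketch has its own lock, and the dependency graph is acyclic. As a result, resolving `queue` while holding `signal_service`'s lock cannot deadlock here, and all eight threads receive the same instance.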
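The `config.model.*` arguments passed to `model_singleton` are not concrete values when the class is defined. dependency_injector's `providers.Configuration` resolves them when the provider is called, after the configuration has been loaded (the library offers loaders such as `from_dict`). The sketch below mimics that lazy lookup with the standard library only. `ConfigSketch` is a hypothetical class, and the `storage.image_dir` key is an invented example of how the hard-coded paths marked `TODO: get location from config` might be moved into configuration.

```python
from typing import Any, Dict, Optional, Tuple


class ConfigSketch:
    """Hypothetical stand-in for a lazily resolved configuration tree."""

    def __init__(self, store: Optional[Dict[str, Any]] = None, path: Tuple[str, ...] = ()) -> None:
        self._store: Dict[str, Any] = store if store is not None else {}
        self._path = path

    def __getattr__(self, name: str) -> "ConfigSketch":
        # Only called for unknown attributes: record the key, don't look it up yet.
        if name.startswith("_"):
            raise AttributeError(name)
        return ConfigSketch(self._store, self._path + (name,))

    def from_dict(self, data: Dict[str, Any]) -> None:
        # Update the shared store in place so nodes created earlier see the values.
        self._store.clear()
        self._store.update(data)

    def __call__(self) -> Any:
        # Walk the recorded key path through the (now loaded) store.
        value: Any = self._store
        for key in self._path:
            value = value[key]
        return value


config = ConfigSketch()
width = config.model.width            # only a key path; nothing is loaded yet
image_dir = config.storage.image_dir  # hypothetical key for the TODO paths

config.from_dict({
    "model": {"width": 512},
    "storage": {"image_dir": "./outputs/img-samples/"},
})
print(width(), image_dir())  # 512 ./outputs/img-samples/
```

Because lookups happen at call time, the configuration can be loaded after the container is declared, which is what allows `model_singleton` to reference `config.model.width` before any values exist.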