2022-09-16 17:18:15 +00:00
|
|
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
|
|
|
|
"""Containers module."""
|
|
|
|
|
|
|
|
from dependency_injector import containers, providers
|
|
|
|
from flask_socketio import SocketIO
|
|
|
|
from ldm.generate import Generate
|
|
|
|
from server import services
|
|
|
|
|
|
|
|
class Container(containers.DeclarativeContainer):
|
|
|
|
wiring_config = containers.WiringConfiguration(packages=['server'])
|
|
|
|
|
|
|
|
config = providers.Configuration()
|
|
|
|
|
|
|
|
socketio = providers.ThreadSafeSingleton(
|
|
|
|
SocketIO,
|
|
|
|
app = None
|
|
|
|
)
|
|
|
|
|
2022-09-16 20:35:34 +00:00
|
|
|
# TODO: Add a model provider service that provides model(s) dynamically
|
2022-09-16 17:18:15 +00:00
|
|
|
model_singleton = providers.ThreadSafeSingleton(
|
|
|
|
Generate,
|
2022-09-16 20:35:34 +00:00
|
|
|
model = config.model,
|
|
|
|
sampler_name = config.sampler_name,
|
|
|
|
embedding_path = config.embedding_path,
|
|
|
|
full_precision = config.full_precision
|
|
|
|
# config = config.model.config,
|
|
|
|
|
|
|
|
# width = config.model.width,
|
|
|
|
# height = config.model.height,
|
|
|
|
# sampler_name = config.model.sampler_name,
|
|
|
|
# weights = config.model.weights,
|
|
|
|
# full_precision = config.model.full_precision,
|
|
|
|
# grid = config.model.grid,
|
|
|
|
# seamless = config.model.seamless,
|
|
|
|
# embedding_path = config.model.embedding_path,
|
|
|
|
# device_type = config.model.device_type
|
2022-09-16 17:18:15 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
# TODO: get location from config
|
|
|
|
image_storage_service = providers.ThreadSafeSingleton(
|
|
|
|
services.ImageStorageService,
|
|
|
|
'./outputs/img-samples/'
|
|
|
|
)
|
|
|
|
|
|
|
|
# TODO: get location from config
|
|
|
|
image_intermediates_storage_service = providers.ThreadSafeSingleton(
|
|
|
|
services.ImageStorageService,
|
|
|
|
'./outputs/intermediates/'
|
|
|
|
)
|
|
|
|
|
|
|
|
signal_queue_service = providers.ThreadSafeSingleton(
|
|
|
|
services.SignalQueueService
|
|
|
|
)
|
|
|
|
|
|
|
|
signal_service = providers.ThreadSafeSingleton(
|
|
|
|
services.SignalService,
|
|
|
|
socketio = socketio,
|
|
|
|
queue = signal_queue_service
|
|
|
|
)
|
|
|
|
|
|
|
|
generation_queue_service = providers.ThreadSafeSingleton(
|
|
|
|
services.JobQueueService
|
|
|
|
)
|
|
|
|
|
|
|
|
# TODO: get locations from config
|
|
|
|
log_service = providers.ThreadSafeSingleton(
|
|
|
|
services.LogService,
|
|
|
|
'./outputs/img-samples/',
|
|
|
|
'dream_web_log.txt'
|
|
|
|
)
|
|
|
|
|
|
|
|
generator_service = providers.ThreadSafeSingleton(
|
|
|
|
services.GeneratorService,
|
|
|
|
model = model_singleton,
|
|
|
|
queue = generation_queue_service,
|
|
|
|
imageStorage = image_storage_service,
|
|
|
|
intermediateStorage = image_intermediates_storage_service,
|
|
|
|
log = log_service,
|
|
|
|
signal_service = signal_service
|
|
|
|
)
|