# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

"""Dependency-injection container for the web server.

Declares the providers that build and wire together the server's shared
objects: the Socket.IO instance, the image-generation model, and the
storage, queue, signal, log and generator services.
"""

from dependency_injector import containers, providers
from flask_socketio import SocketIO

from ldm.generate import Generate
from server import services

# Hard-coded output locations (see the TODOs on the providers that use them).
_IMAGE_SAMPLES_DIR = './outputs/img-samples/'
_INTERMEDIATES_DIR = './outputs/intermediates/'
_WEB_LOG_FILENAME = 'dream_web_log.txt'


class Container(containers.DeclarativeContainer):
    """Application-wide provider registry.

    Apart from ``config``, every provider is a
    ``providers.ThreadSafeSingleton``. Each object is therefore created once,
    when it is first requested, and the same instance is handed to every
    consumer afterwards.
    """

    # Modules in the `server` package are wired to this container.
    wiring_config = containers.WiringConfiguration(packages=['server'])

    # Runtime settings. They are loaded by whoever builds the container;
    # NOTE(review): the loading site is not visible in this module.
    config = providers.Configuration()

    # Built without a Flask app. Presumably the app is attached later
    # (e.g. via init_app) — verify against the caller.
    socketio = providers.ThreadSafeSingleton(SocketIO, app=None)

    # TODO: Add a model provider service that provides model(s) dynamically.
    # TODO: Read these from a nested `config.model` section and also expose
    # the other Generate options (config, width, height, weights, grid,
    # seamless, device_type).
    model_singleton = providers.ThreadSafeSingleton(
        Generate,
        model=config.model,
        sampler_name=config.sampler_name,
        embedding_path=config.embedding_path,
        full_precision=config.full_precision,
    )

    # Final generated images.
    # TODO: get location from config
    image_storage_service = providers.ThreadSafeSingleton(
        services.ImageStorageService,
        _IMAGE_SAMPLES_DIR,
    )

    # Intermediate images; same service class, separate directory.
    # TODO: get location from config
    image_intermediates_storage_service = providers.ThreadSafeSingleton(
        services.ImageStorageService,
        _INTERMEDIATES_DIR,
    )

    signal_queue_service = providers.ThreadSafeSingleton(
        services.SignalQueueService,
    )

    # Signal service built on the shared Socket.IO instance and signal queue.
    signal_service = providers.ThreadSafeSingleton(
        services.SignalService,
        socketio=socketio,
        queue=signal_queue_service,
    )

    generation_queue_service = providers.ThreadSafeSingleton(
        services.JobQueueService,
    )

    # The log file lives alongside the generated images.
    # TODO: get locations from config
    log_service = providers.ThreadSafeSingleton(
        services.LogService,
        _IMAGE_SAMPLES_DIR,
        _WEB_LOG_FILENAME,
    )

    # Ties everything together: model, job queue, both image stores,
    # log and signal services.
    generator_service = providers.ThreadSafeSingleton(
        services.GeneratorService,
        model=model_singleton,
        queue=generation_queue_service,
        imageStorage=image_storage_service,
        intermediateStorage=image_intermediates_storage_service,
        log=log_service,
        signal_service=signal_service,
    )