mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
b7ffd36cc6
Replace `delete_on_startup: bool` and the associated logic with `ephemeral: bool` and `TemporaryDirectory`. The temp dir is created inside `output_dir`. For example, if `output_dir` is `invokeai/outputs/tensors/`, then the temp dir might be `invokeai/outputs/tensors/tmpvj35ht7b/`. The temp dir is cleaned up when the service is stopped, or when it is garbage-collected if the service was not properly stopped. In the event of a catastrophic crash where the temp files are not cleaned up, the user can delete the temp dir themselves. This situation is unlikely to occur in normal use, but if you kill the process, Python cannot clean up the temp dir itself. This includes running the app in a debugger and killing the debugger process — something I do relatively often. Tests updated.
155 lines
6.5 KiB
Python
155 lines
6.5 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
from logging import Logger
|
|
|
|
import torch
|
|
|
|
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
|
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
|
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
|
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
|
from invokeai.backend.util.logging import InvokeAILogger
|
|
from invokeai.version.invokeai_version import __version__
|
|
|
|
from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
|
from ..services.board_images.board_images_default import BoardImagesService
|
|
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
|
from ..services.boards.boards_default import BoardService
|
|
from ..services.config import InvokeAIAppConfig
|
|
from ..services.download import DownloadQueueService
|
|
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
|
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
|
from ..services.images.images_default import ImageService
|
|
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
|
from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
|
from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
|
from ..services.invocation_services import InvocationServices
|
|
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
|
from ..services.invoker import Invoker
|
|
from ..services.model_install import ModelInstallService
|
|
from ..services.model_manager.model_manager_default import ModelManagerService
|
|
from ..services.model_records import ModelRecordServiceSQL
|
|
from ..services.names.names_default import SimpleNameService
|
|
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
|
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
|
from ..services.shared.graph import GraphExecutionState
|
|
from ..services.urls.urls_default import LocalUrlService
|
|
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
|
from .events import FastAPIEventService
|
|
|
|
|
|
# TODO: is there a better way to achieve this?
|
|
def check_internet(host: str = "http://huggingface.co", timeout: float = 1) -> bool:
    """
    Return true if the internet is reachable.

    It does this by requesting ``host`` (huggingface.co by default) and
    treating any successful response as "reachable".

    :param host: URL to request. Defaults to ``http://huggingface.co``.
    :param timeout: Seconds to wait for the request before giving up. Defaults to 1.
    :returns: True if the request succeeded; False if it raised any exception.
    """
    import urllib.request

    try:
        # Use the response as a context manager so the underlying connection is
        # closed promptly instead of being left open for the garbage collector.
        with urllib.request.urlopen(host, timeout=timeout):
            pass
    except Exception:
        # Deliberate best-effort probe: any failure (DNS, timeout, HTTP error,
        # malformed URL) is reported as "not reachable" rather than raised.
        return False
    return True
|
|
|
|
|
|
# Module-level logger; also used as the default `logger` argument of ApiDependencies.initialize().
logger = InvokeAILogger.get_logger()
|
|
|
|
|
|
class ApiDependencies:
    """Contains and initializes all dependencies for the API"""

    # Assigned by initialize(). Only annotated here, so the attribute does not
    # exist on the class until initialize() has run.
    invoker: Invoker

    @staticmethod
    def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
        """
        Construct every API service and install a new Invoker on ``ApiDependencies.invoker``.

        :param config: Application configuration. Supplies the root path, output path,
            node cache size, and is handed to the services that take a config.
        :param event_handler_id: Passed through to ``FastAPIEventService``.
        :param logger: Logger for startup messages, also passed to the services.
            Defaults to the module-level logger.
        :raises ValueError: If ``config.output_path`` is None.
        """
        logger.info(f"InvokeAI version {__version__}")
        logger.info(f"Root directory = {str(config.root_path)}")
        logger.debug(f"Internet connectivity is {config.internet_available}")

        output_folder = config.output_path
        if output_folder is None:
            raise ValueError("Output folder is not set")

        image_files = DiskImageFileStorage(f"{output_folder}/images")

        db = init_db(config=config, logger=logger, image_files=image_files)

        configuration = config

        board_image_records = SqliteBoardImageRecordStorage(db=db)
        board_images = BoardImagesService()
        board_records = SqliteBoardRecordStorage(db=db)
        boards = BoardService()
        events = FastAPIEventService(event_handler_id)
        graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
        image_records = SqliteImageRecordStorage(db=db)
        images = ImageService()
        invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
        # ephemeral=True: per the accompanying change, these serializers write into a
        # temporary directory inside the given folder that is removed when the service
        # stops. The forward cache presumably keeps recent objects in memory in front
        # of the disk serializer (TODO confirm).
        tensors = ObjectSerializerForwardCache(
            ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
        )
        conditioning = ObjectSerializerForwardCache(
            ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
        )
        model_manager = ModelManagerService(config, logger)
        model_record_service = ModelRecordServiceSQL(db=db)
        download_queue_service = DownloadQueueService(event_bus=events)
        metadata_store = ModelMetadataStore(db=db)
        model_install_service = ModelInstallService(
            app_config=config,
            record_store=model_record_service,
            download_queue=download_queue_service,
            metadata_store=metadata_store,
            event_bus=events,
        )
        names = SimpleNameService()
        performance_statistics = InvocationStatsService()
        processor = DefaultInvocationProcessor()
        queue = MemoryInvocationQueue()
        session_processor = DefaultSessionProcessor()
        session_queue = SqliteSessionQueue(db=db)
        urls = LocalUrlService()
        workflow_records = SqliteWorkflowRecordsStorage(db=db)

        services = InvocationServices(
            board_image_records=board_image_records,
            board_images=board_images,
            board_records=board_records,
            boards=boards,
            configuration=configuration,
            events=events,
            graph_execution_manager=graph_execution_manager,
            image_files=image_files,
            image_records=image_records,
            images=images,
            invocation_cache=invocation_cache,
            logger=logger,
            model_manager=model_manager,
            model_records=model_record_service,
            download_queue=download_queue_service,
            model_install=model_install_service,
            names=names,
            performance_statistics=performance_statistics,
            processor=processor,
            queue=queue,
            session_processor=session_processor,
            session_queue=session_queue,
            urls=urls,
            workflow_records=workflow_records,
            tensors=tensors,
            conditioning=conditioning,
        )

        ApiDependencies.invoker = Invoker(services)
        # Called only after the invoker is installed; what clean() does is defined
        # by the db object returned from init_db (TODO confirm its exact effect).
        db.clean()

    @staticmethod
    def shutdown() -> None:
        """Stop the invoker if one has been created; a no-op if initialize() never ran."""
        # `invoker` is only annotated on the class, so a direct attribute read before
        # initialize() would raise AttributeError; getattr with a default avoids that.
        invoker = getattr(ApiDependencies, "invoker", None)
        if invoker:
            invoker.stop()
|