mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(dev_reload): notice when files with Invocation classes are re-loaded [WIP]
This commit is contained in:
parent
877348af49
commit
751fe68d16
@ -1,11 +1,13 @@
|
||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@ -17,23 +19,16 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.schema import schema
|
||||
from invokeai.app.services.graph import update_invocations_union
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
import mimetypes
|
||||
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, BaseInvocationOutput, UIConfigBase
|
||||
|
||||
import torch
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
# noinspection PyUnresolvedReferences
|
||||
@ -223,7 +218,9 @@ def invoke_api():
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
|
||||
from invokeai.app.util.dev_reload import start_reloader
|
||||
|
||||
start_reloader()
|
||||
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
|
31
invokeai/app/util/dev_reload.py
Normal file
31
invokeai/app/util/dev_reload.py
Normal file
@ -0,0 +1,31 @@
|
||||
import jurigged
|
||||
from jurigged.codetools import ClassDefinition
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.getLogger(name=__name__)
|
||||
|
||||
|
||||
def reload_nodes(path: str, codefile: jurigged.CodeFile):
|
||||
"""Callback function for jurigged post-run events."""
|
||||
# Things we have access to here:
|
||||
# codefile.module:module - the module object associated with this file
|
||||
# codefile.module_name:str - the full module name (its key in sys.modules)
|
||||
# codefile.root:ModuleCode - an AST of the current source
|
||||
|
||||
# This is only reading top-level statements, not walking the whole AST, but class definition should be top-level, right?
|
||||
class_names = [statement.name for statement in codefile.root.children if isinstance(statement, ClassDefinition)]
|
||||
classes = [getattr(codefile.module, name) for name in class_names]
|
||||
invocations = [cls for cls in classes if issubclass(cls, BaseInvocation)]
|
||||
# outputs = [cls for cls in classes if issubclass(cls, BaseInvocationOutput)]
|
||||
|
||||
# We should assume jurigged has already replaced all references to methods of these classes,
|
||||
# but it hasn't re-executed any annotations on them (like @title or @tags).
|
||||
# We need to re-do anything that involved introspection like BaseInvocation.get_all_subclasses()
|
||||
logger.info("File reloaded: %s contains invocation classes %s", path, invocations)
|
||||
|
||||
|
||||
def start_reloader():
|
||||
watcher = jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
|
||||
watcher.postrun.register(reload_nodes, apply_history=False)
|
Loading…
Reference in New Issue
Block a user