feat(dev_reload): notice when files with Invocation classes are re-loaded [WIP]

This commit is contained in:
Kevin Turner 2023-08-25 16:23:23 -07:00
parent 877348af49
commit 751fe68d16
2 changed files with 41 additions and 13 deletions

View File

@ -1,11 +1,13 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import asyncio
import logging
import mimetypes
import socket
from inspect import signature
from pathlib import Path
from typing import Literal
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -17,23 +19,16 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from invokeai.app.services.graph import update_invocations_union
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
import mimetypes
from invokeai.version.invokeai_version import __version__
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, BaseInvocationOutput, UIConfigBase
import torch
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
if torch.backends.mps.is_available():
# noinspection PyUnresolvedReferences
@ -223,7 +218,9 @@ def invoke_api():
exc_info=e,
)
else:
jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
from invokeai.app.util.dev_reload import start_reloader
start_reloader()
port = find_port(app_config.port)
if port != app_config.port:

View File

@ -0,0 +1,31 @@
import jurigged
from jurigged.codetools import ClassDefinition
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(name=__name__)
def reload_nodes(path: str, codefile: jurigged.CodeFile):
"""Callback function for jurigged post-run events."""
# Things we have access to here:
# codefile.module:module - the module object associated with this file
# codefile.module_name:str - the full module name (its key in sys.modules)
# codefile.root:ModuleCode - an AST of the current source
# This is only reading top-level statements, not walking the whole AST, but class definition should be top-level, right?
class_names = [statement.name for statement in codefile.root.children if isinstance(statement, ClassDefinition)]
classes = [getattr(codefile.module, name) for name in class_names]
invocations = [cls for cls in classes if issubclass(cls, BaseInvocation)]
# outputs = [cls for cls in classes if issubclass(cls, BaseInvocationOutput)]
# We should assume jurigged has already replaced all references to methods of these classes,
# but it hasn't re-executed any annotations on them (like @title or @tags).
# We need to re-do anything that involved introspection like BaseInvocation.get_all_subclasses()
logger.info("File reloaded: %s contains invocation classes %s", path, invocations)
def start_reloader():
watcher = jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
watcher.postrun.register(reload_nodes, apply_history=False)