diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py
index 6e291a4c0f..20922b64d3 100644
--- a/invokeai/app/services/session_processor/session_processor_default.py
+++ b/invokeai/app/services/session_processor/session_processor_default.py
@@ -1,14 +1,19 @@
 import traceback
-from threading import BoundedSemaphore, Thread, Event as ThreadEvent
-from typing import Optional
+from contextlib import suppress
+from threading import BoundedSemaphore, Thread
+from threading import Event as ThreadEvent
+from typing import Callable, Optional, Union
 
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.typing import Event as FastAPIEvent
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation
 from invokeai.app.services.events.events_base import EventServiceBase
-from invokeai.app.services.shared.graph_processor import GraphProcessor
+from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
+from invokeai.app.services.session_processor.session_processor_common import CanceledException
 from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
+from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
 from invokeai.app.util.profiler import Profiler
 
 from ..invoker import Invoker
@@ -16,8 +21,160 @@
 from .session_processor_base import SessionProcessorBase
 from .session_processor_common import SessionProcessorStatus
 
+class GraphProcessor:
+    """Process a graph of invocations"""
+
+    def __init__(
+        self,
+        services: InvocationServices,
+        cancel_event: ThreadEvent,
+        profiler: Union[Profiler, None] = None,
+        on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
+        on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
+    ):
+        self.services = services
+        self.profiler = profiler
+        self.cancel_event = cancel_event
+        self.on_before_run_node = on_before_run_node
+        self.on_after_run_node = on_after_run_node
+
+    def run(self, queue_item: SessionQueueItem):
+        """Run the graph"""
+        if not queue_item.session:
+            raise ValueError("Queue item has no session")
+        # If profiling is enabled, start the profiler
+        if self.profiler is not None:
+            self.profiler.start(profile_id=queue_item.session_id)
+        # Loop over invocations until the session is complete or canceled
+        while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
+            # Prepare the next node
+            invocation = queue_item.session.next()
+            if invocation is None:
+                # If there are no more invocations, complete the graph
+                break
+            # Run the node
+            self.run_node(invocation, queue_item)
+        self.complete(queue_item)
+
+    def complete(self, queue_item: SessionQueueItem):
+        """Complete the graph"""
+        self.services.events.emit_graph_execution_complete(
+            queue_batch_id=queue_item.batch_id,
+            queue_item_id=queue_item.item_id,
+            queue_id=queue_item.queue_id,
+            graph_execution_state_id=queue_item.session.id,
+        )
+        # If we are profiling, stop the profiler and dump the profile & stats
+        if self.profiler:
+            profile_path = self.profiler.stop()
+            stats_path = profile_path.with_suffix(".json")
+            self.services.performance_statistics.dump_stats(
+                graph_execution_state_id=queue_item.session.id, output_path=stats_path
+            )
+        # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
+        # we don't care about that - suppress the error.
+        with suppress(GESStatsNotFoundError):
+            self.services.performance_statistics.log_stats(queue_item.session.id)
+            self.services.performance_statistics.reset_stats()
+
+    def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
+        """Run a single node in the graph"""
+        # If we have an on_before_run_node callback, call it
+        if self.on_before_run_node is not None:
+            self.on_before_run_node(invocation, queue_item)
+        try:
+            data = InvocationContextData(
+                invocation=invocation,
+                source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
+                queue_item=queue_item,
+            )
+
+            # Send starting event
+            self.services.events.emit_invocation_started(
+                queue_batch_id=queue_item.batch_id,
+                queue_item_id=queue_item.item_id,
+                queue_id=queue_item.queue_id,
+                graph_execution_state_id=queue_item.session_id,
+                node=invocation.model_dump(),
+                source_node_id=data.source_invocation_id,
+            )
+
+            # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
+            with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
+                context = build_invocation_context(
+                    data=data,
+                    services=self.services,
+                    cancel_event=self.cancel_event,
+                )
+
+                # Invoke the node
+                outputs = invocation.invoke_internal(context=context, services=self.services)
+
+                # Save outputs and history
+                queue_item.session.complete(invocation.id, outputs)
+
+                # Send complete event
+                self.services.events.emit_invocation_complete(
+                    queue_batch_id=queue_item.batch_id,
+                    queue_item_id=queue_item.item_id,
+                    queue_id=queue_item.queue_id,
+                    graph_execution_state_id=queue_item.session.id,
+                    node=invocation.model_dump(),
+                    source_node_id=data.source_invocation_id,
+                    result=outputs.model_dump(),
+                )
+        except KeyboardInterrupt:
+            # TODO(MM2): Create an event for this
+            pass
+        except CanceledException:
+            # When the user cancels the graph, we first set the cancel event. The event is checked
+            # between invocations, in this loop. Some invocations are long-running, and we need to
+            # be able to cancel them mid-execution.
+            #
+            # For example, denoising is a long-running invocation with many steps. A step callback
+            # is executed after each step. This step callback checks if the canceled event is set,
+            # then raises a CanceledException to stop execution immediately.
+            #
+            # When we get a CanceledException, we don't need to do anything - just pass and let the
+            # loop go to its next iteration, and the cancel event will be handled correctly.
+            pass
+        except Exception as e:
+            error = traceback.format_exc()
+
+            # Save error
+            queue_item.session.set_node_error(invocation.id, error)
+            self.services.logger.error(
+                f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
+            )
+            self.services.logger.error(error)
+
+            # Send error event
+            self.services.events.emit_invocation_error(
+                queue_batch_id=queue_item.batch_id,
+                queue_item_id=queue_item.item_id,
+                queue_id=queue_item.queue_id,
+                graph_execution_state_id=queue_item.session.id,
+                node=invocation.model_dump(),
+                source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
+                error_type=e.__class__.__name__,
+                error=error,
+            )
+            pass
+        finally:
+            # If we have an on_after_run_node callback, call it
+            if self.on_after_run_node is not None:
+                self.on_after_run_node(invocation, queue_item)
+
+
 class DefaultSessionProcessor(SessionProcessorBase):
-    def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
+    def start(
+        self,
+        invoker: Invoker,
+        thread_limit: int = 1,
+        polling_interval: int = 1,
+        on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
+        on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
+    ) -> None:
         self._invoker: Invoker = invoker
         self._queue_item: Optional[SessionQueueItem] = None
         self._invocation: Optional[BaseInvocation] = None
@@ -49,6 +206,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
             services=self._invoker.services,
             cancel_event=self._cancel_event,
             profiler=self._profiler,
+            on_before_run_node=on_before_run_node,
+            on_after_run_node=on_after_run_node,
         )
 
         self._thread = Thread(
@@ -154,3 +313,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
                 poll_now_event.clear()
                 self._queue_item = None
                 self._thread_semaphore.release()
+        self._invoker.services.logger.debug("Session processor stopped")
diff --git a/invokeai/app/services/shared/graph_processor.py b/invokeai/app/services/shared/graph_processor.py
deleted file mode 100644
index 763a317f7c..0000000000
--- a/invokeai/app/services/shared/graph_processor.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import traceback
-
-from contextlib import suppress
-from threading import Event
-from typing import Callable, Union
-
-from invokeai.app.invocations.baseinvocation import BaseInvocation
-from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
-from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
-from invokeai.app.util.profiler import Profiler
-from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
-from invokeai.app.services.session_processor.session_processor_common import CanceledException
-
-class GraphProcessor:
-    """Process a graph of invocations"""
-    def __init__(
-        self,
-        services: InvocationServices,
-        cancel_event: Event,
-        profiler: Union[Profiler, None] = None,
-        on_before_run_node: Union[Callable[[BaseInvocation,SessionQueueItem], bool], None] = None,
-        on_after_run_node: Union[Callable[[BaseInvocation,SessionQueueItem], bool], None] = None,
-    ):
-        self.services = services
-        self.profiler = profiler
-        self.cancel_event = cancel_event
-        self.on_before_run_node = on_before_run_node
-        self.on_after_run_node = on_after_run_node
-
-    def run(self, queue_item: SessionQueueItem):
-        """Run the graph"""
-        if not queue_item.session:
-            raise ValueError("Queue item has no session")
-        # If profiling is enabled, start the profiler
-        if self.profiler is not None:
-            self.profiler.start(profile_id=queue_item.session_id)
-        # Loop over invocations until the session is complete or canceled
-        while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
-            # Prepare the next node
-            invocation = queue_item.session.next()
-            if invocation is None:
-                # If there are no more invocations, complete the graph
-                break
-            # Build invocation context (the node-facing API
-            self.run_node(invocation, queue_item)
-        self.complete(queue_item)
-
-    def complete(self, queue_item: SessionQueueItem):
-        """Complete the graph"""
-        self.services.events.emit_graph_execution_complete(
-            queue_batch_id=queue_item.batch_id,
-            queue_item_id=queue_item.item_id,
-            queue_id=queue_item.queue_id,
-            graph_execution_state_id=queue_item.session.id,
-        )
-        # If we are profiling, stop the profiler and dump the profile & stats
-        if self.profiler:
-            profile_path = self.profiler.stop()
-            stats_path = profile_path.with_suffix(".json")
-            self.services.performance_statistics.dump_stats(
-                graph_execution_state_id=queue_item.session.id, output_path=stats_path
-            )
-        # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
-        # we don't care about that - suppress the error.
-        with suppress(GESStatsNotFoundError):
-            self.services.performance_statistics.log_stats(queue_item.session.id)
-            self.services.performance_statistics.reset_stats()
-
-
-    def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
-        """Run a single node in the graph"""
-        # If we have a on_before_run_node callback, call it
-        if self.on_before_run_node is not None:
-            self.on_before_run_node(invocation, queue_item)
-        try:
-            data = InvocationContextData(
-                invocation=invocation,
-                source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
-                queue_item=queue_item,
-            )
-
-            # Send starting event
-            self.services.events.emit_invocation_started(
-                queue_batch_id=queue_item.batch_id,
-                queue_item_id=queue_item.item_id,
-                queue_id=queue_item.queue_id,
-                graph_execution_state_id=queue_item.session_id,
-                node=invocation.model_dump(),
-                source_node_id=data.source_invocation_id,
-            )
-
-            # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
-            with self.services.performance_statistics.collect_stats(
-                invocation, queue_item.session_id
-            ):
-                context = build_invocation_context(
-                    data=data,
-                    services=self.services,
-                    cancel_event=self.cancel_event,
-                )
-
-                # Invoke the node
-                outputs = invocation.invoke_internal(
-                    context=context, services=self.services
-                )
-
-                # Save outputs and history
-                queue_item.session.complete(invocation.id, outputs)
-
-                # Send complete event
-                self.services.events.emit_invocation_complete(
-                    queue_batch_id=queue_item.batch_id,
-                    queue_item_id=queue_item.item_id,
-                    queue_id=queue_item.queue_id,
-                    graph_execution_state_id=queue_item.session.id,
-                    node=invocation.model_dump(),
-                    source_node_id=data.source_invocation_id,
-                    result=outputs.model_dump(),
-                )
-        except KeyboardInterrupt:
-            # TODO(MM2): Create an event for this
-            self.cancel_event.set()
-        except CanceledException:
-            # When the user cancels the graph, we first set the cancel event. The event is checked
-            # between invocations, in this loop. Some invocations are long-running, and we need to
-            # be able to cancel them mid-execution.
-            #
-            # For example, denoising is a long-running invocation with many steps. A step callback
-            # is executed after each step. This step callback checks if the canceled event is set,
-            # then raises a CanceledException to stop execution immediately.
-            #
-            # When we get a CanceledException, we don't need to do anything - just pass and let the
-            # loop go to its next iteration, and the cancel event will be handled correctly.
-            pass
-        except Exception as e:
-            error = traceback.format_exc()
-
-            # Save error
-            queue_item.session.set_node_error(invocation.id, error)
-            self.services.logger.error(
-                f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
-            )
-            self.services.logger.error(error)
-
-            # Send error event
-            self.services.events.emit_invocation_error(
-                queue_batch_id=queue_item.session_id,
-                queue_item_id=queue_item.item_id,
-                queue_id=queue_item.queue_id,
-                graph_execution_state_id=queue_item.session.id,
-                node=invocation.model_dump(),
-                source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
-                error_type=e.__class__.__name__,
-                error=error,
-            )
-            pass
-        finally:
-            # If we have a on_after_run_node callback, call it
-            if self.on_after_run_node is not None:
-                self.on_after_run_node(invocation, queue_item)
diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py
index 9b2444dae4..47ed5da505 100644
--- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py
+++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py
@@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
     See :class:`Migration` for an example.
     """
 
-    def __call__(self, cursor: sqlite3.Cursor) -> None: ...
+    def __call__(self, cursor: sqlite3.Cursor) -> None:
+        ...
 
 
 class MigrationError(RuntimeError):
diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py
index 7ddcf14367..9a38c006a5 100644
--- a/invokeai/backend/training/textual_inversion_training.py
+++ b/invokeai/backend/training/textual_inversion_training.py
@@ -858,9 +858,9 @@ def do_textual_inversion_training(
         # Let's make sure we don't update any embedding weights besides the newly added token
         index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
         with torch.no_grad():
-            accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
-                orig_embeds_params[index_no_updates]
-            )
+            accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+                index_no_updates
+            ] = orig_embeds_params[index_no_updates]
 
         # Checks if the accelerator has performed an optimization step behind the scenes
         if accelerator.sync_gradients:
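
The externally visible API change in this diff is the pair of optional on_before_run_node / on_after_run_node hooks threaded through DefaultSessionProcessor.start into GraphProcessor. A minimal sketch of wiring them up, assuming a constructed Invoker is at hand (the log_before, log_after, and start_with_hooks names are illustrative, not part of the diff; note that the processor discards the callbacks' boolean return values):

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem


def log_before(invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool:
    # Called just before each node is invoked
    print(f"starting {invocation.id} ({invocation.get_type()}) in session {queue_item.session_id}")
    return True


def log_after(invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool:
    # Called in run_node's `finally` block, so it fires even on error or cancellation
    print(f"finished {invocation.id}")
    return True


def start_with_hooks(invoker: Invoker) -> DefaultSessionProcessor:
    # Illustrative helper: attach the new hooks when starting the processor
    processor = DefaultSessionProcessor()
    processor.start(
        invoker=invoker,
        on_before_run_node=log_before,
        on_after_run_node=log_after,
    )
    return processor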
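
The CanceledException comment describes a two-level cancellation protocol: the processor polls the cancel event between nodes, while a long-running invocation polls it from inside a step callback and raises to abort mid-execution. A sketch of such a callback, assuming the invocation has access to the shared cancel event (make_step_callback is illustrative, not InvokeAI API):

from threading import Event

from invokeai.app.services.session_processor.session_processor_common import CanceledException


def make_step_callback(cancel_event: Event):
    def on_step(step: int) -> None:
        # Called after every denoising step; bail out promptly once the user cancels
        if cancel_event.is_set():
            raise CanceledException

    return on_step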
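
The sqlite_migrator_common.py hunk only moves the Protocol's ellipsis body onto its own line; behavior is unchanged. For reference, any plain function with a matching signature satisfies MigrateCallback via structural typing (the table and column names here are made up):

import sqlite3


def migrate_example(cursor: sqlite3.Cursor) -> None:
    # A matching __call__ signature is all the protocol requires
    cursor.execute("ALTER TABLE example_table ADD COLUMN example_col TEXT;")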
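
The textual inversion hunk likewise only reflows a masked assignment. The trick it preserves is worth spelling out: after each optimizer step, every embedding row except the placeholder token's is restored to its original value, so only the new token's embedding actually trains. A toy illustration with hypothetical sizes:

import torch

vocab_size, embed_dim, placeholder_token_id = 8, 4, 5
weight = torch.randn(vocab_size, embed_dim)  # embeddings after an optimizer step
orig_embeds_params = weight.clone() + 1.0  # stand-in for the pre-step weights
index_no_updates = torch.arange(vocab_size) != placeholder_token_id  # True = frozen row
with torch.no_grad():
    weight[index_no_updates] = orig_embeds_params[index_no_updates]  # undo updates to frozen rows
assert torch.equal(weight[index_no_updates], orig_embeds_params[index_no_updates])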