mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move graph processor into session_processor_default
This commit is contained in:
parent
afa4df1991
commit
7d5a88b69d
@ -1,14 +1,19 @@
|
||||
import traceback
|
||||
from threading import BoundedSemaphore, Thread, Event as ThreadEvent
|
||||
from typing import Optional
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.shared.graph_processor import GraphProcessor
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
@ -16,8 +21,160 @@ from .session_processor_base import SessionProcessorBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
|
||||
class GraphProcessor:
|
||||
"""Process a graph of invocations"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
services: InvocationServices,
|
||||
cancel_event: ThreadEvent,
|
||||
profiler: Union[Profiler, None] = None,
|
||||
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
):
|
||||
self.services = services
|
||||
self.profiler = profiler
|
||||
self.cancel_event = cancel_event
|
||||
self.on_before_run_node = on_before_run_node
|
||||
self.on_after_run_node = on_after_run_node
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
"""Run the graph"""
|
||||
if not queue_item.session:
|
||||
raise ValueError("Queue item has no session")
|
||||
# If profiling is enabled, start the profiler
|
||||
if self.profiler is not None:
|
||||
self.profiler.start(profile_id=queue_item.session_id)
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
|
||||
# Prepare the next node
|
||||
invocation = queue_item.session.next()
|
||||
if invocation is None:
|
||||
# If there are no more invocations, complete the graph
|
||||
break
|
||||
# Build invocation context (the node-facing API
|
||||
self.run_node(invocation, queue_item)
|
||||
self.complete(queue_item)
|
||||
|
||||
def complete(self, queue_item: SessionQueueItem):
|
||||
"""Complete the graph"""
|
||||
self.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self.profiler:
|
||||
profile_path = self.profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self.services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self.services.performance_statistics.reset_stats()
|
||||
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run a single node in the graph"""
|
||||
# If we have a on_before_run_node callback, call it
|
||||
if self.on_before_run_node is not None:
|
||||
self.on_before_run_node(invocation, queue_item)
|
||||
try:
|
||||
data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
queue_item=queue_item,
|
||||
)
|
||||
|
||||
# Send starting event
|
||||
self.services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=data.source_invocation_id,
|
||||
)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self.services,
|
||||
cancel_event=self.cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
outputs = invocation.invoke_internal(context=context, services=self.services)
|
||||
|
||||
# Save outputs and history
|
||||
queue_item.session.complete(invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=data.source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
queue_item.session.set_node_error(invocation.id, error)
|
||||
self.services.logger.error(
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
finally:
|
||||
# If we have a on_after_run_node callback, call it
|
||||
if self.on_after_run_node is not None:
|
||||
self.on_after_run_node(invocation, queue_item)
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
||||
def start(
|
||||
self,
|
||||
invoker: Invoker,
|
||||
thread_limit: int = 1,
|
||||
polling_interval: int = 1,
|
||||
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
@ -49,6 +206,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
profiler=self._profiler,
|
||||
on_before_run_node=on_before_run_node,
|
||||
on_after_run_node=on_after_run_node,
|
||||
)
|
||||
|
||||
self._thread = Thread(
|
||||
@ -154,3 +313,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.clear()
|
||||
self._queue_item = None
|
||||
self._thread_semaphore.release()
|
||||
self._invoker.services.logger.debug("Session processor stopped")
|
||||
|
@ -1,161 +0,0 @@
|
||||
import traceback
|
||||
|
||||
from contextlib import suppress
|
||||
from threading import Event
|
||||
from typing import Callable, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
|
||||
class GraphProcessor:
|
||||
"""Process a graph of invocations"""
|
||||
def __init__(
|
||||
self,
|
||||
services: InvocationServices,
|
||||
cancel_event: Event,
|
||||
profiler: Union[Profiler, None] = None,
|
||||
on_before_run_node: Union[Callable[[BaseInvocation,SessionQueueItem], bool], None] = None,
|
||||
on_after_run_node: Union[Callable[[BaseInvocation,SessionQueueItem], bool], None] = None,
|
||||
):
|
||||
self.services = services
|
||||
self.profiler = profiler
|
||||
self.cancel_event = cancel_event
|
||||
self.on_before_run_node = on_before_run_node
|
||||
self.on_after_run_node = on_after_run_node
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
"""Run the graph"""
|
||||
if not queue_item.session:
|
||||
raise ValueError("Queue item has no session")
|
||||
# If profiling is enabled, start the profiler
|
||||
if self.profiler is not None:
|
||||
self.profiler.start(profile_id=queue_item.session_id)
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
|
||||
# Prepare the next node
|
||||
invocation = queue_item.session.next()
|
||||
if invocation is None:
|
||||
# If there are no more invocations, complete the graph
|
||||
break
|
||||
# Build invocation context (the node-facing API
|
||||
self.run_node(invocation, queue_item)
|
||||
self.complete(queue_item)
|
||||
|
||||
def complete(self, queue_item: SessionQueueItem):
|
||||
"""Complete the graph"""
|
||||
self.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self.profiler:
|
||||
profile_path = self.profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self.services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self.services.performance_statistics.reset_stats()
|
||||
|
||||
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run a single node in the graph"""
|
||||
# If we have a on_before_run_node callback, call it
|
||||
if self.on_before_run_node is not None:
|
||||
self.on_before_run_node(invocation, queue_item)
|
||||
try:
|
||||
data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
queue_item=queue_item,
|
||||
)
|
||||
|
||||
# Send starting event
|
||||
self.services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=data.source_invocation_id,
|
||||
)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
with self.services.performance_statistics.collect_stats(
|
||||
invocation, queue_item.session_id
|
||||
):
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self.services,
|
||||
cancel_event=self.cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
outputs = invocation.invoke_internal(
|
||||
context=context, services=self.services
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
queue_item.session.complete(invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=data.source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
self.cancel_event.set()
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
queue_item.session.set_node_error(invocation.id, error)
|
||||
self.services.logger.error(
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
finally:
|
||||
# If we have a on_after_run_node callback, call it
|
||||
if self.on_after_run_node is not None:
|
||||
self.on_after_run_node(invocation, queue_item)
|
@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
|
@ -858,9 +858,9 @@ def do_textual_inversion_training(
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
Loading…
Reference in New Issue
Block a user