Merge branch 'main' into refactor/rename-performance-options

Lincoln Stein 2023-08-21 19:47:55 -04:00 committed by GitHub
commit 9d7dfeb857
4 changed files with 141 additions and 59 deletions

View File

@@ -71,6 +71,9 @@ class FieldDescriptions:
     safe_mode = "Whether or not to use safe mode"
     scribble_mode = "Whether or not to use scribble mode"
     scale_factor = "The factor by which to scale"
+    blend_alpha = (
+        "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
+    )
     num_1 = "The first number"
     num_2 = "The second number"
     mask = "The mask to use for the operation"

View File

@@ -233,7 +233,7 @@ class SDXLPromptInvocationBase:
             dtype_for_device_getter=torch_dtype,
             truncate_long_prompts=True,  # TODO:
             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
-            requires_pooled=True,
+            requires_pooled=get_pooled,
         )
         conjunction = Compel.parse_prompt_string(prompt)

View File

@@ -4,6 +4,7 @@ from contextlib import ExitStack
 from typing import List, Literal, Optional, Union

 import einops
+import numpy as np
 import torch
 import torchvision.transforms as T
 from diffusers.image_processor import VaeImageProcessor
@@ -720,3 +721,81 @@ class ImageToLatentsInvocation(BaseInvocation):
         latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents, seed=None)
+
+
+@title("Blend Latents")
+@tags("latents", "blend")
+class BlendLatentsInvocation(BaseInvocation):
+    """Blend two latents using a given alpha. Latents must have same size."""
+
+    type: Literal["lblend"] = "lblend"
+
+    # Inputs
+    latents_a: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    latents_b: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents_a = context.services.latents.get(self.latents_a.latents_name)
+        latents_b = context.services.latents.get(self.latents_b.latents_name)
+
+        if latents_a.shape != latents_b.shape:
+            raise ValueError("Latents to blend must be the same size.")
+
+        # TODO:
+        device = choose_torch_device()
+
+        def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+            """
+            Spherical linear interpolation
+            Args:
+                t (float/np.ndarray): Float value between 0.0 and 1.0
+                v0 (np.ndarray): Starting vector
+                v1 (np.ndarray): Final vector
+                DOT_THRESHOLD (float): Threshold for considering the two vectors as
+                    colinear. Not recommended to alter this.
+            Returns:
+                v2 (np.ndarray): Interpolation vector between v0 and v1
+            """
+            inputs_are_torch = False
+            if not isinstance(v0, np.ndarray):
+                inputs_are_torch = True
+                v0 = v0.detach().cpu().numpy()
+            if not isinstance(v1, np.ndarray):
+                inputs_are_torch = True
+                v1 = v1.detach().cpu().numpy()
+
+            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+            if np.abs(dot) > DOT_THRESHOLD:
+                v2 = (1 - t) * v0 + t * v1
+            else:
+                theta_0 = np.arccos(dot)
+                sin_theta_0 = np.sin(theta_0)
+                theta_t = theta_0 * t
+                sin_theta_t = np.sin(theta_t)
+                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+                s1 = sin_theta_t / sin_theta_0
+                v2 = s0 * v0 + s1 * v1
+
+            if inputs_are_torch:
+                v2 = torch.from_numpy(v2).to(device)
+
+            return v2
+
+        # blend
+        blended_latents = slerp(self.alpha, latents_a, latents_b)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        blended_latents = blended_latents.to("cpu")
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        # context.services.latents.set(name, resized_latents)
+        context.services.latents.save(name, blended_latents)
+        return build_latents_output(latents_name=name, latents=blended_latents)
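A quick standalone sanity check of the slerp helper above (an illustrative snippet, not part of this commit): with two orthogonal unit vectors, t=0 returns v0, t=1 returns v1, and intermediate values stay on the unit circle; nearly parallel inputs fall back to the linear branch via DOT_THRESHOLD.

import numpy as np

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    # numpy-only copy of the helper above, for illustration
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        return (1 - t) * v0 + t * v1  # vectors are nearly colinear: plain lerp
    theta_0 = np.arccos(dot)
    theta_t = theta_0 * t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1

v0 = np.array([1.0, 0.0])
v1 = np.array([0.0, 1.0])
print(slerp(0.0, v0, v1))  # [1. 0.] -> v0
print(slerp(1.0, v0, v1))  # [0. 1.] -> v1
print(slerp(0.5, v0, v1))  # ~[0.7071 0.7071], still unit length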

View File

@@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats

 GIG = 1073741824


+@dataclass
+class NodeStats:
+    """Class for tracking execution stats of an invocation node"""
+
+    calls: int = 0
+    time_used: float = 0.0  # seconds
+    max_vram: float = 0.0  # GB
+    cache_hits: int = 0
+    cache_misses: int = 0
+    cache_high_watermark: int = 0
+
+
+@dataclass
+class NodeLog:
+    """Class for tracking node usage"""
+
+    # {node_type => NodeStats}
+    nodes: Dict[str, NodeStats] = field(default_factory=dict)
+
+
 class InvocationStatsServiceBase(ABC):
     "Abstract base class for recording node memory/time performance statistics"

+    graph_execution_manager: ItemStorageABC["GraphExecutionState"]
+    # {graph_id => NodeLog}
+    _stats: Dict[str, NodeLog]
+    _cache_stats: Dict[str, CacheStats]
+    ram_used: float
+    ram_changed: float
+
     @abstractmethod
     def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
         """
@@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC):
         invocation_type: str,
         time_used: float,
         vram_used: float,
-        ram_used: float,
-        ram_changed: float,
     ):
         """
         Add timing information on execution of a node. Usually
@@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC):
         :param invocation_type: String literal type of the node
         :param time_used: Time used by node's execution (sec)
         :param vram_used: Maximum VRAM used during execution (GB)
-        :param ram_used: Current RAM available (GB)
-        :param ram_changed: Change in RAM usage over course of the run (GB)
         """
         pass
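With ram_used and ram_changed dropped from update_invocation_stats, per-node time/VRAM and process RAM now travel through separate calls. A hedged sketch of a caller under that split (local variable names here are illustrative, not taken from this commit); the update_mem_stats abstract method itself is added in the next hunk:

# after a node finishes executing
stats.update_mem_stats(
    ram_used=current_ram_gb,                       # process RAM currently in use, GB
    ram_changed=current_ram_gb - ram_at_start_gb,  # change over this invocation, GB
)
stats.update_invocation_stats(
    graph_id=graph_execution_state_id,
    invocation_type=invocation.type,
    time_used=elapsed_seconds,   # wall-clock time for this node, sec
    vram_used=peak_vram_gb,      # peak VRAM during this node, GB
)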
@@ -116,25 +139,19 @@
         """
         pass

-
-@dataclass
-class NodeStats:
-    """Class for tracking execution stats of an invocation node"""
-
-    calls: int = 0
-    time_used: float = 0.0  # seconds
-    max_vram: float = 0.0  # GB
-    cache_hits: int = 0
-    cache_misses: int = 0
-    cache_high_watermark: int = 0
-
-
-@dataclass
-class NodeLog:
-    """Class for tracking node usage"""
-
-    # {node_type => NodeStats}
-    nodes: Dict[str, NodeStats] = field(default_factory=dict)
-
+    @abstractmethod
+    def update_mem_stats(
+        self,
+        ram_used: float,
+        ram_changed: float,
+    ):
+        """
+        Update the collector with RAM memory usage info.
+
+        :param ram_used: How much RAM is currently in use.
+        :param ram_changed: How much RAM changed since last generation.
+        """
+        pass
+

 class InvocationStatsService(InvocationStatsServiceBase):
@@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
     class StatsContext:
         """Context manager for collecting statistics."""

-        invocation: BaseInvocation = None
-        collector: "InvocationStatsServiceBase" = None
-        graph_id: str = None
-        start_time: int = 0
-        ram_used: int = 0
-        model_manager: ModelManagerService = None
+        invocation: BaseInvocation
+        collector: "InvocationStatsServiceBase"
+        graph_id: str
+        start_time: float
+        ram_used: int
+        model_manager: ModelManagerService

         def __init__(
             self,
@@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
             self.invocation = invocation
             self.collector = collector
             self.graph_id = graph_id
-            self.start_time = 0
+            self.start_time = 0.0
             self.ram_used = 0
             self.model_manager = model_manager
@@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
             )
             self.collector.update_invocation_stats(
                 graph_id=self.graph_id,
-                invocation_type=self.invocation.type,
+                invocation_type=self.invocation.type,  # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
                 time_used=time.time() - self.start_time,
                 vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
             )
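For context, the StatsContext above is returned by collect_stats() (see the next hunk); a rough usage sketch, assuming a caller shaped like the session processor (variable names are illustrative, not taken from this commit):

with stats.collect_stats(invocation, graph_execution_state.id, model_manager):
    outputs = invocation.invoke(context)
# on exit, the context manager calls update_invocation_stats() with the
# elapsed time and peak VRAM recorded for this invocation type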
@@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         graph_execution_state_id: str,
         model_manager: ModelManagerService,
     ) -> StatsContext:
-        """
-        Return a context object that will capture the statistics.
-        :param invocation: BaseInvocation object from the current graph.
-        :param graph_execution_state: GraphExecutionState object from the current session.
-        """
         if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
             self._stats[graph_execution_state_id] = NodeLog()
             self._cache_stats[graph_execution_state_id] = CacheStats()
@@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         self._stats = {}

     def reset_stats(self, graph_execution_id: str):
-        """Zero the statistics for the indicated graph."""
         try:
             self._stats.pop(graph_execution_id)
         except KeyError:
@@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         ram_used: float,
         ram_changed: float,
     ):
-        """
-        Update the collector with RAM memory usage info.
-        :param ram_used: How much RAM is currently in use.
-        :param ram_changed: How much RAM changed since last generation.
-        """
         self.ram_used = ram_used
         self.ram_changed = ram_changed
@@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         time_used: float,
         vram_used: float,
     ):
-        """
-        Add timing information on execution of a node. Usually
-        used internally.
-        :param graph_id: ID of the graph that is currently executing
-        :param invocation_type: String literal type of the node
-        :param time_used: Time used by node's exection (sec)
-        :param vram_used: Maximum VRAM used during exection (GB)
-        :param ram_used: Current RAM available (GB)
-        :param ram_changed: Change in RAM usage over course of the run (GB)
-        """
         if not self._stats[graph_id].nodes.get(invocation_type):
             self._stats[graph_id].nodes[invocation_type] = NodeStats()
         stats = self._stats[graph_id].nodes[invocation_type]
@@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
         stats.max_vram = max(stats.max_vram, vram_used)

     def log_stats(self):
-        """
-        Send the statistics to the system logger at the info level.
-        Stats will only be printed when the execution of the graph
-        is complete.
-        """
         completed = set()
+        errored = set()
         for graph_id, node_log in self._stats.items():
-            current_graph_state = self.graph_execution_manager.get(graph_id)
+            try:
+                current_graph_state = self.graph_execution_manager.get(graph_id)
+            except Exception:
+                errored.add(graph_id)
+                continue
+
             if not current_graph_state.is_complete():
                 continue
@@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
         for graph_id in completed:
             del self._stats[graph_id]
             del self._cache_stats[graph_id]
+
+        for graph_id in errored:
+            del self._stats[graph_id]
+            del self._cache_stats[graph_id]
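For reference, the bookkeeping these changes maintain looks roughly like this (node types and numbers below are made up for illustration): _stats maps each graph execution ID to a NodeLog, whose nodes dict accumulates one NodeStats per invocation type.

_stats = {
    "graph-abc123": NodeLog(
        nodes={
            "lblend": NodeStats(calls=2, time_used=0.84, max_vram=3.1),
            "l2i": NodeStats(calls=1, time_used=1.52, max_vram=4.7),
        }
    )
}
# log_stats() reports and deletes entries for graphs that completed, and now
# also for graphs whose execution state could not be fetched (the errored set).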