Mirror of https://github.com/invoke-ai/InvokeAI

Commit 9d7dfeb857: Merge branch 'main' into refactor/rename-performance-options
@@ -71,6 +71,9 @@ class FieldDescriptions:
     safe_mode = "Whether or not to use safe mode"
     scribble_mode = "Whether or not to use scribble mode"
     scale_factor = "The factor by which to scale"
+    blend_alpha = (
+        "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
+    )
     num_1 = "The first number"
     num_2 = "The second number"
     mask = "The mask to use for the operation"
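For reference (not part of the commit): the blend_alpha text above is only a field description. A minimal sketch of what the factor means, using a plain linear mix; the node added later in this diff interpolates spherically, but the endpoints behave the same way:

    import torch

    def linear_blend(a: torch.Tensor, b: torch.Tensor, alpha: float) -> torch.Tensor:
        # alpha = 0.0 returns input A unchanged, alpha = 1.0 returns input B, 0.5 is an even mix
        return (1.0 - alpha) * a + alpha * b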
@@ -233,7 +233,7 @@ class SDXLPromptInvocationBase:
            dtype_for_device_getter=torch_dtype,
            truncate_long_prompts=True,  # TODO:
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
-            requires_pooled=True,
+            requires_pooled=get_pooled,
         )

         conjunction = Compel.parse_prompt_string(prompt)
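For context, a hedged sketch of why this change matters, based on compel's general requires_pooled behavior rather than on code in this commit (variable names are taken from the hunk above):

    # Sketch only: when pooled embeddings are requested, the conditioning call
    # returns them alongside the sequence embeddings. SDXL only needs the pooled
    # output from one of its two text encoders, so the caller passes get_pooled
    # instead of hard-coding True.
    if get_pooled:
        conditioning, pooled = compel(prompt)
    else:
        conditioning = compel(prompt)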
@@ -4,6 +4,7 @@ from contextlib import ExitStack
 from typing import List, Literal, Optional, Union

 import einops
+import numpy as np
 import torch
 import torchvision.transforms as T
 from diffusers.image_processor import VaeImageProcessor
@@ -720,3 +721,81 @@ class ImageToLatentsInvocation(BaseInvocation):
         latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents, seed=None)
+
+
+@title("Blend Latents")
+@tags("latents", "blend")
+class BlendLatentsInvocation(BaseInvocation):
+    """Blend two latents using a given alpha. Latents must have same size."""
+
+    type: Literal["lblend"] = "lblend"
+
+    # Inputs
+    latents_a: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    latents_b: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents_a = context.services.latents.get(self.latents_a.latents_name)
+        latents_b = context.services.latents.get(self.latents_b.latents_name)
+
+        if latents_a.shape != latents_b.shape:
+            raise ValueError("Latents to blend must be the same size.")
+
+        # TODO:
+        device = choose_torch_device()
+
+        def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+            """
+            Spherical linear interpolation
+            Args:
+                t (float/np.ndarray): Float value between 0.0 and 1.0
+                v0 (np.ndarray): Starting vector
+                v1 (np.ndarray): Final vector
+                DOT_THRESHOLD (float): Threshold for considering the two vectors as
+                    colinear. Not recommended to alter this.
+            Returns:
+                v2 (np.ndarray): Interpolation vector between v0 and v1
+            """
+            inputs_are_torch = False
+            if not isinstance(v0, np.ndarray):
+                inputs_are_torch = True
+                v0 = v0.detach().cpu().numpy()
+            if not isinstance(v1, np.ndarray):
+                inputs_are_torch = True
+                v1 = v1.detach().cpu().numpy()
+
+            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+            if np.abs(dot) > DOT_THRESHOLD:
+                v2 = (1 - t) * v0 + t * v1
+            else:
+                theta_0 = np.arccos(dot)
+                sin_theta_0 = np.sin(theta_0)
+                theta_t = theta_0 * t
+                sin_theta_t = np.sin(theta_t)
+                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+                s1 = sin_theta_t / sin_theta_0
+                v2 = s0 * v0 + s1 * v1
+
+            if inputs_are_torch:
+                v2 = torch.from_numpy(v2).to(device)
+
+            return v2
+
+        # blend
+        blended_latents = slerp(self.alpha, latents_a, latents_b)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        blended_latents = blended_latents.to("cpu")
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        # context.services.latents.set(name, resized_latents)
+        context.services.latents.save(name, blended_latents)
+        return build_latents_output(latents_name=name, latents=blended_latents)
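A standalone sketch (not part of the commit) showing how the slerp helper above behaves on plain NumPy vectors; the torch branch and choose_torch_device are omitted, and the numbers are illustrative:

    import numpy as np

    def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
        # Same math as the nested helper above, restricted to NumPy inputs.
        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        if np.abs(dot) > DOT_THRESHOLD:
            return (1 - t) * v0 + t * v1  # nearly colinear: fall back to a linear mix
        theta_0 = np.arccos(dot)
        theta_t = theta_0 * t
        s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
        s1 = np.sin(theta_t) / np.sin(theta_0)
        return s0 * v0 + s1 * v1

    a = np.array([1.0, 0.0])
    b = np.array([0.0, 1.0])
    mid = slerp(0.5, a, b)            # -> [0.7071, 0.7071]
    print(mid, np.linalg.norm(mid))   # norm stays 1.0, unlike a plain 50/50 lerp (norm ~0.707)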
@@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
 GIG = 1073741824


+@dataclass
+class NodeStats:
+    """Class for tracking execution stats of an invocation node"""
+
+    calls: int = 0
+    time_used: float = 0.0  # seconds
+    max_vram: float = 0.0  # GB
+    cache_hits: int = 0
+    cache_misses: int = 0
+    cache_high_watermark: int = 0
+
+
+@dataclass
+class NodeLog:
+    """Class for tracking node usage"""
+
+    # {node_type => NodeStats}
+    nodes: Dict[str, NodeStats] = field(default_factory=dict)
+
+
 class InvocationStatsServiceBase(ABC):
     "Abstract base class for recording node memory/time performance statistics"

+    graph_execution_manager: ItemStorageABC["GraphExecutionState"]
+    # {graph_id => NodeLog}
+    _stats: Dict[str, NodeLog]
+    _cache_stats: Dict[str, CacheStats]
+    ram_used: float
+    ram_changed: float
+
     @abstractmethod
     def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
         """
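A hypothetical usage sketch of the two dataclasses added above (the node type and numbers are made up; NodeStats and NodeLog are the definitions from this hunk):

    log = NodeLog()
    stats = log.nodes.setdefault("lblend", NodeStats())  # one NodeStats entry per node type
    stats.calls += 1
    stats.time_used += 0.42                    # seconds spent in this invocation
    stats.max_vram = max(stats.max_vram, 1.5)  # peak VRAM in GB
    stats.cache_hits += 1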
@@ -94,8 +121,6 @@ class InvocationStatsServiceBase:
         invocation_type: str,
         time_used: float,
         vram_used: float,
-        ram_used: float,
-        ram_changed: float,
     ):
         """
         Add timing information on execution of a node. Usually
@@ -104,8 +129,6 @@ class InvocationStatsServiceBase:
         :param invocation_type: String literal type of the node
         :param time_used: Time used by node's execution (sec)
         :param vram_used: Maximum VRAM used during execution (GB)
-        :param ram_used: Current RAM available (GB)
-        :param ram_changed: Change in RAM usage over course of the run (GB)
         """
         pass

@@ -116,25 +139,19 @@ class InvocationStatsServiceBase:
         """
         pass

-
-@dataclass
-class NodeStats:
-    """Class for tracking execution stats of an invocation node"""
-
-    calls: int = 0
-    time_used: float = 0.0  # seconds
-    max_vram: float = 0.0  # GB
-    cache_hits: int = 0
-    cache_misses: int = 0
-    cache_high_watermark: int = 0
-
-
-@dataclass
-class NodeLog:
-    """Class for tracking node usage"""
-
-    # {node_type => NodeStats}
-    nodes: Dict[str, NodeStats] = field(default_factory=dict)
+    @abstractmethod
+    def update_mem_stats(
+        self,
+        ram_used: float,
+        ram_changed: float,
+    ):
+        """
+        Update the collector with RAM memory usage info.
+
+        :param ram_used: How much RAM is currently in use.
+        :param ram_changed: How much RAM changed since last generation.
+        """
+        pass


 class InvocationStatsService(InvocationStatsServiceBase):
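A hedged sketch of a caller feeding the new update_mem_stats hook; using psutil to measure process RAM is an assumption for illustration, and stats_service is a hypothetical instance of the service:

    import psutil

    GIG = 1073741824
    ram_before = psutil.virtual_memory().used / GIG

    # ... run an invocation ...

    ram_now = psutil.virtual_memory().used / GIG
    stats_service.update_mem_stats(ram_used=ram_now, ram_changed=ram_now - ram_before)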
@@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
     class StatsContext:
         """Context manager for collecting statistics."""

-        invocation: BaseInvocation = None
-        collector: "InvocationStatsServiceBase" = None
-        graph_id: str = None
-        start_time: int = 0
-        ram_used: int = 0
-        model_manager: ModelManagerService = None
+        invocation: BaseInvocation
+        collector: "InvocationStatsServiceBase"
+        graph_id: str
+        start_time: float
+        ram_used: int
+        model_manager: ModelManagerService

         def __init__(
             self,
@@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
             self.invocation = invocation
             self.collector = collector
             self.graph_id = graph_id
-            self.start_time = 0
+            self.start_time = 0.0
             self.ram_used = 0
             self.model_manager = model_manager

@@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
             )
             self.collector.update_invocation_stats(
                 graph_id=self.graph_id,
-                invocation_type=self.invocation.type,
+                invocation_type=self.invocation.type,  # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
                 time_used=time.time() - self.start_time,
                 vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
             )
@@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         graph_execution_state_id: str,
         model_manager: ModelManagerService,
     ) -> StatsContext:
-        """
-        Return a context object that will capture the statistics.
-        :param invocation: BaseInvocation object from the current graph.
-        :param graph_execution_state: GraphExecutionState object from the current session.
-        """
         if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
             self._stats[graph_execution_state_id] = NodeLog()
             self._cache_stats[graph_execution_state_id] = CacheStats()
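A hypothetical usage sketch of collect_stats as a context manager; stats_service, invocation, graph_execution_state_id, model_manager, and context are assumed to be in scope, matching the signature shown above:

    with stats_service.collect_stats(invocation, graph_execution_state_id, model_manager):
        outputs = invocation.invoke(context)
    # On exit, the StatsContext reports elapsed wall time and peak VRAM for this
    # invocation type via update_invocation_stats, as shown in the @@ -191,7 hunk above.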
@@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         self._stats = {}

     def reset_stats(self, graph_execution_id: str):
-        """Zero the statistics for the indicated graph."""
         try:
             self._stats.pop(graph_execution_id)
         except KeyError:
@@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         ram_used: float,
         ram_changed: float,
     ):
-        """
-        Update the collector with RAM memory usage info.
-
-        :param ram_used: How much RAM is currently in use.
-        :param ram_changed: How much RAM changed since last generation.
-        """
         self.ram_used = ram_used
         self.ram_changed = ram_changed

@@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         time_used: float,
         vram_used: float,
     ):
-        """
-        Add timing information on execution of a node. Usually
-        used internally.
-        :param graph_id: ID of the graph that is currently executing
-        :param invocation_type: String literal type of the node
-        :param time_used: Time used by node's exection (sec)
-        :param vram_used: Maximum VRAM used during exection (GB)
-        :param ram_used: Current RAM available (GB)
-        :param ram_changed: Change in RAM usage over course of the run (GB)
-        """
         if not self._stats[graph_id].nodes.get(invocation_type):
             self._stats[graph_id].nodes[invocation_type] = NodeStats()
         stats = self._stats[graph_id].nodes[invocation_type]
@@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
         stats.max_vram = max(stats.max_vram, vram_used)

     def log_stats(self):
-        """
-        Send the statistics to the system logger at the info level.
-        Stats will only be printed when the execution of the graph
-        is complete.
-        """
         completed = set()
+        errored = set()
         for graph_id, node_log in self._stats.items():
-            current_graph_state = self.graph_execution_manager.get(graph_id)
+            try:
+                current_graph_state = self.graph_execution_manager.get(graph_id)
+            except Exception:
+                errored.add(graph_id)
+                continue
+
             if not current_graph_state.is_complete():
                 continue

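A hypothetical illustration (not from the commit) of the failure the new try/except guards against: if a graph execution state is missing from storage, get() raises, and without the guard one bad graph would abort stats logging for every remaining graph; collecting the id in errored lets it be cleaned up afterwards:

    class MissingStateStore:
        # Stand-in for graph_execution_manager; raising on get() simulates a pruned state.
        def get(self, graph_id):
            raise KeyError(graph_id)

    store = MissingStateStore()
    errored = set()
    for graph_id in ("g1", "g2"):
        try:
            state = store.get(graph_id)
        except Exception:
            errored.add(graph_id)
            continue
    print(errored)  # {'g1', 'g2'}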
@@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
         for graph_id in completed:
             del self._stats[graph_id]
             del self._cache_stats[graph_id]
+
+        for graph_id in errored:
+            del self._stats[graph_id]
+            del self._cache_stats[graph_id]