From fd7b84241988a57d7b9dae24798cc0bb475dd92c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 1 Aug 2023 17:44:09 -0400 Subject: [PATCH 01/21] add execution stat reporting after each invocation --- invokeai/app/services/invocation_stats.py | 115 ++++++++++++++++++++++ invokeai/app/services/processor.py | 43 ++++---- 2 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 invokeai/app/services/invocation_stats.py diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py new file mode 100644 index 0000000000..8d41b60d49 --- /dev/null +++ b/invokeai/app/services/invocation_stats.py @@ -0,0 +1,115 @@ +# Copyright 2023 Lincoln D. Stein +"""Utility to collect execution time and GPU usage stats on invocations in flight""" + +""" +Usage: +statistics = InvocationStats() # keep track of performance metrics +... +with statistics.collect_stats(invocation, graph_execution_state): + outputs = invocation.invoke( + InvocationContext( + services=self.__invoker.services, + graph_execution_state_id=graph_execution_state.id, + ) + ) +... +statistics.log_stats() + +Typical output: +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Node Calls Seconds +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> main_model_loader 1 0.006s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> clip_skip 1 0.005s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> compel 2 0.351s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> rand_int 1 0.001s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> range_of_size 1 0.001s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> iterate 1 0.001s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> noise 1 0.002s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> t2l 1 3.117s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> l2i 1 0.377s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> TOTAL: 3.865s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Max VRAM used for execution: 3.12G. +[2023-08-01 17:34:44,586]::[InvokeAI]::INFO --> Current VRAM utilization 2.31G. +""" + +import time +from typing import Dict, List + +import torch + +from .graph import GraphExecutionState +from .invocation_queue import InvocationQueueItem +from ..invocations.baseinvocation import BaseInvocation + +import invokeai.backend.util.logging as logger + +class InvocationStats(): + """Accumulate performance information about a running graph. Collects time spent in each node, + as well as the maximum and current VRAM utilisation for CUDA systems""" + + def __init__(self): + self._stats: Dict[str, int] = {} + + class StatsContext(): + def __init__(self, invocation: BaseInvocation, collector): + self.invocation = invocation + self.collector = collector + self.start_time = 0 + + def __enter__(self): + self.start_time = time.time() + + def __exit__(self, *args): + self.collector.log_time(self.invocation.type, time.time() - self.start_time) + + def collect_stats(self, + invocation: BaseInvocation, + graph_execution_state: GraphExecutionState, + ) -> StatsContext: + """ + Return a context object that will capture the statistics. + :param invocation: BaseInvocation object from the current graph. + :param graph_execution_state: GraphExecutionState object from the current session. 
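For readers unfamiliar with the pattern, the StatsContext introduced above is an ordinary timing context manager: record a start time on __enter__, report the elapsed time on __exit__. A minimal, self-contained sketch of the same idea (the names below are illustrative and not part of the patch):

    import time

    class Timer:
        """Report elapsed wall-clock time for a block of work to a callback."""

        def __init__(self, label: str, on_done):
            self.label = label
            self.on_done = on_done      # called as on_done(label, seconds)
            self.start = 0.0

        def __enter__(self):
            self.start = time.time()
            return self

        def __exit__(self, *exc):
            self.on_done(self.label, time.time() - self.start)
            return False                # never swallow exceptions

    # Usage:
    #   with Timer("t2l", lambda label, secs: print(f"{label}: {secs:.3f}s")):
    #       run_node()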
+ """ + if len(graph_execution_state.executed)==0: # new graph is starting + self.reset_stats() + self._current_graph_state = graph_execution_state + sc = self.StatsContext(invocation, self) + return self.StatsContext(invocation, self) + + def reset_stats(self): + """Zero the statistics. Ordinarily called internally.""" + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + self._stats: Dict[str, List[int, float]] = {} + + + def log_time(self, invocation_type: str, time_used: float): + """ + Add timing information on execution of a node. Usually + used internally. + :param invocation_type: String literal type of the node + :param time_used: Floating point seconds used by node's exection + """ + if not self._stats.get(invocation_type): + self._stats[invocation_type] = [0, 0.0] + self._stats[invocation_type][0] += 1 + self._stats[invocation_type][1] += time_used + + def log_stats(self): + """ + Send the statistics to the system logger at the info level. + Stats will only be printed if when the execution of the graph + is complete. + """ + if self._current_graph_state.is_complete(): + logger.info('Node Calls Seconds') + for node_type, (calls, time_used) in self._stats.items(): + logger.info(f'{node_type:<20} {calls:>5} {time_used:4.3f}s') + + total_time = sum([ticks for _,ticks in self._stats.values()]) + logger.info(f'TOTAL: {total_time:4.3f}s') + if torch.cuda.is_available(): + logger.info('Max VRAM used for execution: '+'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9)) + logger.info('Current VRAM utilization '+'%4.2fG' % (torch.cuda.memory_allocated() / 1e9)) + diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index 50fe217e05..a43e2878ac 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -5,6 +5,7 @@ from threading import Event, Thread, BoundedSemaphore from ..invocations.baseinvocation import InvocationContext from .invocation_queue import InvocationQueueItem from .invoker import InvocationProcessorABC, Invoker +from .invocation_stats import InvocationStats from ..models.exceptions import CanceledException import invokeai.backend.util.logging as logger @@ -35,6 +36,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): def __process(self, stop_event: Event): try: self.__threadLimit.acquire() + statistics = InvocationStats() # keep track of performance metrics while not stop_event.is_set(): try: queue_item: InvocationQueueItem = self.__invoker.services.queue.get() @@ -83,30 +85,32 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # Invoke try: - outputs = invocation.invoke( - InvocationContext( - services=self.__invoker.services, - graph_execution_state_id=graph_execution_state.id, + with statistics.collect_stats(invocation, graph_execution_state): + outputs = invocation.invoke( + InvocationContext( + services=self.__invoker.services, + graph_execution_state_id=graph_execution_state.id, + ) ) - ) - # Check queue to see if this is canceled, and skip if so - if self.__invoker.services.queue.is_canceled(graph_execution_state.id): - continue + # Check queue to see if this is canceled, and skip if so + if self.__invoker.services.queue.is_canceled(graph_execution_state.id): + continue - # Save outputs and history - graph_execution_state.complete(invocation.id, outputs) + # Save outputs and history + graph_execution_state.complete(invocation.id, outputs) - # Save the state changes - self.__invoker.services.graph_execution_manager.set(graph_execution_state) + # Save the state 
changes + self.__invoker.services.graph_execution_manager.set(graph_execution_state) - # Send complete event - self.__invoker.services.events.emit_invocation_complete( - graph_execution_state_id=graph_execution_state.id, - node=invocation.dict(), - source_node_id=source_node_id, - result=outputs.dict(), - ) + # Send complete event + self.__invoker.services.events.emit_invocation_complete( + graph_execution_state_id=graph_execution_state.id, + node=invocation.dict(), + source_node_id=source_node_id, + result=outputs.dict(), + ) + statistics.log_stats() except KeyboardInterrupt: pass @@ -161,3 +165,4 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor finally: self.__threadLimit.release() + From 8a4e5f73aa6e914494885438a560a7d6694b6ce2 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 1 Aug 2023 19:39:42 -0400 Subject: [PATCH 02/21] reset stats on exception --- invokeai/app/services/invocation_stats.py | 38 +++++++++++------------ invokeai/app/services/processor.py | 4 +-- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index 8d41b60d49..24a5662647 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -43,14 +43,15 @@ from ..invocations.baseinvocation import BaseInvocation import invokeai.backend.util.logging as logger -class InvocationStats(): + +class InvocationStats: """Accumulate performance information about a running graph. Collects time spent in each node, as well as the maximum and current VRAM utilisation for CUDA systems""" def __init__(self): self._stats: Dict[str, int] = {} - - class StatsContext(): + + class StatsContext: def __init__(self, invocation: BaseInvocation, collector): self.invocation = invocation self.collector = collector @@ -61,17 +62,18 @@ class InvocationStats(): def __exit__(self, *args): self.collector.log_time(self.invocation.type, time.time() - self.start_time) - - def collect_stats(self, - invocation: BaseInvocation, - graph_execution_state: GraphExecutionState, - ) -> StatsContext: + + def collect_stats( + self, + invocation: BaseInvocation, + graph_execution_state: GraphExecutionState, + ) -> StatsContext: """ Return a context object that will capture the statistics. :param invocation: BaseInvocation object from the current graph. :param graph_execution_state: GraphExecutionState object from the current session. """ - if len(graph_execution_state.executed)==0: # new graph is starting + if len(graph_execution_state.executed) == 0: # new graph is starting self.reset_stats() self._current_graph_state = graph_execution_state sc = self.StatsContext(invocation, self) @@ -83,7 +85,6 @@ class InvocationStats(): torch.cuda.reset_peak_memory_stats() self._stats: Dict[str, List[int, float]] = {} - def log_time(self, invocation_type: str, time_used: float): """ Add timing information on execution of a node. Usually @@ -95,7 +96,7 @@ class InvocationStats(): self._stats[invocation_type] = [0, 0.0] self._stats[invocation_type][0] += 1 self._stats[invocation_type][1] += time_used - + def log_stats(self): """ Send the statistics to the system logger at the info level. @@ -103,13 +104,12 @@ class InvocationStats(): is complete. 
""" if self._current_graph_state.is_complete(): - logger.info('Node Calls Seconds') + logger.info("Node Calls Seconds") for node_type, (calls, time_used) in self._stats.items(): - logger.info(f'{node_type:<20} {calls:>5} {time_used:4.3f}s') - - total_time = sum([ticks for _,ticks in self._stats.values()]) - logger.info(f'TOTAL: {total_time:4.3f}s') + logger.info(f"{node_type:<20} {calls:>5} {time_used:4.3f}s") + + total_time = sum([ticks for _, ticks in self._stats.values()]) + logger.info(f"TOTAL: {total_time:4.3f}s") if torch.cuda.is_available(): - logger.info('Max VRAM used for execution: '+'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9)) - logger.info('Current VRAM utilization '+'%4.2fG' % (torch.cuda.memory_allocated() / 1e9)) - + logger.info("Max VRAM used for execution: " + "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)) + logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index a43e2878ac..e9511aa283 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -116,6 +116,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass except CanceledException: + statistics.reset_stats() pass except Exception as e: @@ -137,7 +138,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): error_type=e.__class__.__name__, error=error, ) - + statistics.reset_stats() pass # Check queue to see if this is canceled, and skip if so @@ -165,4 +166,3 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor finally: self.__threadLimit.release() - From 4d22cafdad0d83523ce43b6f564d07918f4fe02a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 1 Aug 2023 22:06:27 -0400 Subject: [PATCH 03/21] Installer should download fp16 models if user has specified 'auto' in config - Closes #4127 --- invokeai/backend/install/model_install_backend.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index c0a7244367..449c234144 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -13,6 +13,7 @@ import requests from diffusers import DiffusionPipeline from diffusers import logging as dlogging import onnx +import torch from huggingface_hub import hf_hub_url, HfFolder, HfApi from omegaconf import OmegaConf from tqdm import tqdm @@ -23,6 +24,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.util import download_with_resume +from invokeai.backend.util.devices import torch_dtype, choose_torch_device from ..util.logging import InvokeAILogger warnings.filterwarnings("ignore") @@ -416,13 +418,17 @@ class ModelInstall(object): does a save_pretrained() to the indicated staging area. 
""" _, name = repo_id.split("/") - revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"] + precision = torch_dtype(choose_torch_device()) + revisions = ["fp16", "main"] if precision == torch.float16 else ["main"] model = None for revision in revisions: try: - model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None) - except: # most errors are due to fp16 not being present. Fix this to catch other errors - pass + model = DiffusionPipeline.from_pretrained( + repo_id, revision=revision, safety_checker=None, torch_dtype=precision + ) + except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors + if "fp16" not in str(e): + print(e) if model: break if not model: From ed76250dbaa7ea27d0a0d06051337cace9f6a23e Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 2 Aug 2023 07:21:21 -0400 Subject: [PATCH 04/21] Stop checking for unet/model.onnx when a model_index.json is detected --- invokeai/backend/install/model_install_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index c0a7244367..ac032d4955 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -303,7 +303,7 @@ class ModelInstall(object): with TemporaryDirectory(dir=self.config.models_path) as staging: staging = Path(staging) - if "model_index.json" in files and "unet/model.onnx" not in files: + if "model_index.json" in files: location = self._download_hf_pipeline(repo_id, staging) # pipeline elif "unet/model.onnx" in files: location = self._download_hf_model(repo_id, files, staging) From 8fc75a71ee20dd96ceb59f9a4a904125d71e8ad6 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 2 Aug 2023 18:10:52 -0400 Subject: [PATCH 05/21] integrate correctly into app API and add features - Create abstract base class InvocationStatsServiceBase - Store InvocationStatsService in the InvocationServices object - Collect and report stats on simultaneous graph execution independently for each graph id - Track VRAM usage for each node - Handle cancellations and other exceptions gracefully --- invokeai/app/api/dependencies.py | 3 +- invokeai/app/services/invocation_services.py | 3 + invokeai/app/services/invocation_stats.py | 222 ++++++++++++++----- invokeai/app/services/processor.py | 23 +- 4 files changed, 182 insertions(+), 69 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index d609ce3be2..b25009c8c9 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,7 +2,6 @@ from typing import Optional from logging import Logger -import os from invokeai.app.services.board_image_record_storage import ( SqliteBoardImageRecordStorage, ) @@ -30,6 +29,7 @@ from ..services.invoker import Invoker from ..services.processor import DefaultInvocationProcessor from ..services.sqlite import SqliteItemStorage from ..services.model_manager_service import ModelManagerService +from ..services.invocation_stats import InvocationStatsService from .events import FastAPIEventService @@ -128,6 +128,7 @@ class ApiDependencies: graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), configuration=config, + performance_statistics=InvocationStatsService(graph_execution_manager), logger=logger, ) diff --git a/invokeai/app/services/invocation_services.py 
b/invokeai/app/services/invocation_services.py index 8af17c7643..d7d9aae024 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -32,6 +32,7 @@ class InvocationServices: logger: "Logger" model_manager: "ModelManagerServiceBase" processor: "InvocationProcessorABC" + performance_statistics: "InvocationStatsServiceBase" queue: "InvocationQueueABC" def __init__( @@ -47,6 +48,7 @@ class InvocationServices: logger: "Logger", model_manager: "ModelManagerServiceBase", processor: "InvocationProcessorABC", + performance_statistics: "InvocationStatsServiceBase", queue: "InvocationQueueABC", ): self.board_images = board_images @@ -61,4 +63,5 @@ class InvocationServices: self.logger = logger self.model_manager = model_manager self.processor = processor + self.performance_statistics = performance_statistics self.queue = queue diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index 24a5662647..aca1dba550 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -3,99 +3,196 @@ """ Usage: -statistics = InvocationStats() # keep track of performance metrics -... -with statistics.collect_stats(invocation, graph_execution_state): - outputs = invocation.invoke( - InvocationContext( - services=self.__invoker.services, - graph_execution_state_id=graph_execution_state.id, - ) - ) -... + +statistics = InvocationStatsService(graph_execution_manager) +with statistics.collect_stats(invocation, graph_execution_state.id): + ... execute graphs... statistics.log_stats() Typical output: -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Node Calls Seconds -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> main_model_loader 1 0.006s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> clip_skip 1 0.005s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> compel 2 0.351s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> rand_int 1 0.001s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> range_of_size 1 0.001s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> iterate 1 0.001s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> noise 1 0.002s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> t2l 1 3.117s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> l2i 1 0.377s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> TOTAL: 3.865s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Max VRAM used for execution: 3.12G. -[2023-08-01 17:34:44,586]::[InvokeAI]::INFO --> Current VRAM utilization 2.31G. 
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3 +[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used +[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G + +The abstract base class for this class is InvocationStatsServiceBase. An implementing class which +writes to the system log is stored in InvocationServices.performance_statistics. """ import time -from typing import Dict, List +from abc import ABC, abstractmethod +from contextlib import AbstractContextManager +from dataclasses import dataclass, field +from typing import Dict import torch -from .graph import GraphExecutionState -from .invocation_queue import InvocationQueueItem -from ..invocations.baseinvocation import BaseInvocation - import invokeai.backend.util.logging as logger +from ..invocations.baseinvocation import BaseInvocation +from .graph import GraphExecutionState +from .item_storage import ItemStorageABC -class InvocationStats: + +class InvocationStatsServiceBase(ABC): + "Abstract base class for recording node memory/time performance statistics" + + @abstractmethod + def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): + """ + Initialize the InvocationStatsService and reset counters to zero + :param graph_execution_manager: Graph execution manager for this session + """ + pass + + @abstractmethod + def collect_stats( + self, + invocation: BaseInvocation, + graph_execution_state_id: str, + ) -> AbstractContextManager: + """ + Return a context object that will capture the statistics on the execution + of invocaation. Use with: to place around the part of the code that executes the invocation. + :param invocation: BaseInvocation object from the current graph. + :param graph_execution_state: GraphExecutionState object from the current session. + """ + pass + + @abstractmethod + def reset_stats(self, graph_execution_state_id: str): + """ + Reset all statistics for the indicated graph + :param graph_execution_state_id + """ + pass + + @abstractmethod + def reset_all_stats(self): + """Zero all statistics""" + pass + + @abstractmethod + def update_invocation_stats( + self, + graph_id: str, + invocation_type: str, + time_used: float, + vram_used: float, + ): + """ + Add timing information on execution of a node. Usually + used internally. 
+ :param graph_id: ID of the graph that is currently executing + :param invocation_type: String literal type of the node + :param time_used: Time used by node's exection (sec) + :param vram_used: Maximum VRAM used during exection (GB) + """ + pass + + @abstractmethod + def log_stats(self): + """ + Write out the accumulated statistics to the log or somewhere else. + """ + pass + + +@dataclass +class NodeStats: + """Class for tracking execution stats of an invocation node""" + + calls: int = 0 + time_used: float = 0.0 # seconds + max_vram: float = 0.0 # GB + + +@dataclass +class NodeLog: + """Class for tracking node usage""" + + # {node_type => NodeStats} + nodes: Dict[str, NodeStats] = field(default_factory=dict) + + +class InvocationStatsService(InvocationStatsServiceBase): """Accumulate performance information about a running graph. Collects time spent in each node, as well as the maximum and current VRAM utilisation for CUDA systems""" - def __init__(self): - self._stats: Dict[str, int] = {} + def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): + self.graph_execution_manager = graph_execution_manager + # {graph_id => NodeLog} + self._stats: Dict[str, NodeLog] = {} class StatsContext: - def __init__(self, invocation: BaseInvocation, collector): + def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"): self.invocation = invocation self.collector = collector + self.graph_id = graph_id self.start_time = 0 def __enter__(self): self.start_time = time.time() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() def __exit__(self, *args): - self.collector.log_time(self.invocation.type, time.time() - self.start_time) + self.collector.update_invocation_stats( + self.graph_id, + self.invocation.type, + time.time() - self.start_time, + torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0, + ) def collect_stats( self, invocation: BaseInvocation, - graph_execution_state: GraphExecutionState, + graph_execution_state_id: str, ) -> StatsContext: """ Return a context object that will capture the statistics. :param invocation: BaseInvocation object from the current graph. :param graph_execution_state: GraphExecutionState object from the current session. """ - if len(graph_execution_state.executed) == 0: # new graph is starting - self.reset_stats() - self._current_graph_state = graph_execution_state - sc = self.StatsContext(invocation, self) - return self.StatsContext(invocation, self) + if not self._stats.get(graph_execution_state_id): # first time we're seeing this + self._stats[graph_execution_state_id] = NodeLog() + return self.StatsContext(invocation, graph_execution_state_id, self) - def reset_stats(self): - """Zero the statistics. Ordinarily called internally.""" - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - self._stats: Dict[str, List[int, float]] = {} + def reset_all_stats(self): + """Zero all statistics""" + self._stats = {} - def log_time(self, invocation_type: str, time_used: float): + def reset_stats(self, graph_execution_id: str): + """Zero the statistics for the indicated graph.""" + try: + self._stats.pop(graph_execution_id) + except KeyError: + logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}") + + def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float): """ Add timing information on execution of a node. Usually used internally. 
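Taken together, the reworked service boils down to a nested dict of counters keyed first by graph id and then by node type, with the peak VRAM read back from torch after each node finishes. A minimal sketch of that bookkeeping, assuming CUDA may or may not be available (function and field names are illustrative only):

    import time
    from dataclasses import dataclass
    from typing import Callable, Dict

    import torch

    @dataclass
    class NodeCounters:
        calls: int = 0
        seconds: float = 0.0
        max_vram_gb: float = 0.0

    # {graph_id: {node_type: NodeCounters}}
    _stats: Dict[str, Dict[str, NodeCounters]] = {}

    def record(graph_id: str, node_type: str, work: Callable[[], object]):
        """Run work(), then fold its wall time and peak VRAM into the per-graph stats."""
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        start = time.time()
        result = work()
        elapsed = time.time() - start
        vram = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0

        counters = _stats.setdefault(graph_id, {}).setdefault(node_type, NodeCounters())
        counters.calls += 1
        counters.seconds += elapsed
        counters.max_vram_gb = max(counters.max_vram_gb, vram)
        return result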
+ :param graph_id: ID of the graph that is currently executing :param invocation_type: String literal type of the node :param time_used: Floating point seconds used by node's exection """ - if not self._stats.get(invocation_type): - self._stats[invocation_type] = [0, 0.0] - self._stats[invocation_type][0] += 1 - self._stats[invocation_type][1] += time_used + if not self._stats[graph_id].nodes.get(invocation_type): + self._stats[graph_id].nodes[invocation_type] = NodeStats() + stats = self._stats[graph_id].nodes[invocation_type] + stats.calls += 1 + stats.time_used += time_used + stats.max_vram = max(stats.max_vram, vram_used) def log_stats(self): """ @@ -103,13 +200,24 @@ class InvocationStats: Stats will only be printed if when the execution of the graph is complete. """ - if self._current_graph_state.is_complete(): - logger.info("Node Calls Seconds") - for node_type, (calls, time_used) in self._stats.items(): - logger.info(f"{node_type:<20} {calls:>5} {time_used:4.3f}s") + completed = set() + for graph_id, node_log in self._stats.items(): + current_graph_state = self.graph_execution_manager.get(graph_id) + if not current_graph_state.is_complete(): + continue - total_time = sum([ticks for _, ticks in self._stats.values()]) - logger.info(f"TOTAL: {total_time:4.3f}s") + total_time = 0 + logger.info(f"Graph stats: {graph_id}") + logger.info("Node Calls Seconds VRAM Used") + for node_type, stats in self._stats[graph_id].nodes.items(): + logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:4.3f}s {stats.max_vram:4.2f}G") + total_time += stats.time_used + + logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:4.3f}s") if torch.cuda.is_available(): - logger.info("Max VRAM used for execution: " + "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)) logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) + + completed.add(graph_id) + + for graph_id in completed: + del self._stats[graph_id] diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index e9511aa283..41170a304b 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -1,15 +1,15 @@ import time import traceback -from threading import Event, Thread, BoundedSemaphore - -from ..invocations.baseinvocation import InvocationContext -from .invocation_queue import InvocationQueueItem -from .invoker import InvocationProcessorABC, Invoker -from .invocation_stats import InvocationStats -from ..models.exceptions import CanceledException +from threading import BoundedSemaphore, Event, Thread import invokeai.backend.util.logging as logger +from ..invocations.baseinvocation import InvocationContext +from ..models.exceptions import CanceledException +from .invocation_queue import InvocationQueueItem +from .invocation_stats import InvocationStatsServiceBase +from .invoker import InvocationProcessorABC, Invoker + class DefaultInvocationProcessor(InvocationProcessorABC): __invoker_thread: Thread @@ -36,7 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC): def __process(self, stop_event: Event): try: self.__threadLimit.acquire() - statistics = InvocationStats() # keep track of performance metrics + statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics + while not stop_event.is_set(): try: queue_item: InvocationQueueItem = self.__invoker.services.queue.get() @@ -85,7 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # Invoke try: - with statistics.collect_stats(invocation, 
graph_execution_state): + with statistics.collect_stats(invocation, graph_execution_state.id): outputs = invocation.invoke( InvocationContext( services=self.__invoker.services, @@ -116,7 +117,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass except CanceledException: - statistics.reset_stats() + statistics.reset_stats(graph_execution_state.id) pass except Exception as e: @@ -138,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): error_type=e.__class__.__name__, error=error, ) - statistics.reset_stats() + statistics.reset_stats(graph_execution_state.id) pass # Check queue to see if this is canceled, and skip if so From 3fc789a7eeb9efa6255e331ab7dda03097c451d6 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 2 Aug 2023 18:31:10 -0400 Subject: [PATCH 06/21] fix unit tests --- tests/nodes/test_graph_execution_state.py | 9 ++++++--- tests/nodes/test_invoker.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index e0ee120b54..248bc6fee1 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -16,6 +16,7 @@ from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.app.services.graph import ( Graph, CollectInvocation, @@ -41,6 +42,9 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations + graph_execution_manager = SqliteItemStorage[GraphExecutionState]( + filename=sqlite_memory, table_name="graph_executions" + ) return InvocationServices( model_manager=None, # type: ignore events=TestEventService(), @@ -51,9 +55,8 @@ def mock_services() -> InvocationServices: board_images=None, # type: ignore queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), - graph_execution_manager=SqliteItemStorage[GraphExecutionState]( - filename=sqlite_memory, table_name="graph_executions" - ), + graph_execution_manager=graph_execution_manager, + performance_statistics=InvocationStatsService(graph_execution_manager), processor=DefaultInvocationProcessor(), configuration=None, # type: ignore ) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 8eba6d468f..5985c7e8bb 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -11,6 +11,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invoker import Invoker from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.app.services.graph import ( Graph, GraphExecutionState, @@ -34,6 +35,9 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations + graph_execution_manager = SqliteItemStorage[GraphExecutionState]( + filename=sqlite_memory, table_name="graph_executions" + ) return InvocationServices( model_manager=None, # type: ignore events=TestEventService(), @@ 
-44,10 +48,9 @@ def mock_services() -> InvocationServices: board_images=None, # type: ignore queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), - graph_execution_manager=SqliteItemStorage[GraphExecutionState]( - filename=sqlite_memory, table_name="graph_executions" - ), + graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), + performance_statistics=InvocationStatsService(graph_execution_manager), configuration=None, # type: ignore ) From 921ccad04d2b903617cd4dd6d15c4f7fa8fe9c66 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 2 Aug 2023 18:41:43 -0400 Subject: [PATCH 07/21] added stats service to the cli_app startup --- invokeai/app/cli_app.py | 2 ++ scripts/dream.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index bad95bb559..4558c9219f 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -37,6 +37,7 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService +from invokeai.app.services.invocation_stats import InvocationStatsService from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage @@ -311,6 +312,7 @@ def invoke_cli(): graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), + performance_statistics=InvocationStatsService(graph_execution_manager), logger=logger, configuration=config, ) diff --git a/scripts/dream.py b/scripts/dream.py index 12176db41e..d2735046a4 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -2,10 +2,12 @@ # Copyright (c) 2022 Lincoln D. 
Stein (https://github.com/lstein) import warnings -from invokeai.frontend.CLI import invokeai_command_line_interface as main warnings.warn( "dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API", DeprecationWarning, ) -main() + +from invokeai.app.cli_app import invoke_cli + +invoke_cli() From d2bddf7f9161d8c18dda60ebebeedc35a0dd1a78 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 08:47:56 -0400 Subject: [PATCH 08/21] tweak formatting to accommodate longer runtimes --- invokeai/app/services/invocation_stats.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index aca1dba550..50320a6611 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -208,12 +208,12 @@ class InvocationStatsService(InvocationStatsServiceBase): total_time = 0 logger.info(f"Graph stats: {graph_id}") - logger.info("Node Calls Seconds VRAM Used") + logger.info("Node Calls Seconds VRAM Used") for node_type, stats in self._stats[graph_id].nodes.items(): - logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:4.3f}s {stats.max_vram:4.2f}G") + logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G") total_time += stats.time_used - logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:4.3f}s") + logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s") if torch.cuda.is_available(): logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) From ab5d938a1d47bd74e6b54d90f72f2905103638f2 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 19:23:52 -0400 Subject: [PATCH 09/21] use variant instead of revision --- invokeai/backend/install/model_install_backend.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 449c234144..691a461652 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -419,18 +419,24 @@ class ModelInstall(object): """ _, name = repo_id.split("/") precision = torch_dtype(choose_torch_device()) - revisions = ["fp16", "main"] if precision == torch.float16 else ["main"] + variants = ["fp16",None] if precision == torch.float16 else [None,"fp16"] + model = None - for revision in revisions: + for variant in variants: try: model = DiffusionPipeline.from_pretrained( - repo_id, revision=revision, safety_checker=None, torch_dtype=precision + repo_id, + variant=variant, + torch_dtype=precision, + safety_checker=None, ) except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors if "fp16" not in str(e): print(e) + if model: break + if not model: logger.error(f"Diffusers model {repo_id} could not be downloaded. 
Skipping.") return None From 446fb4a43839211a30bb940ed6532dfcec0e99a3 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 19:24:23 -0400 Subject: [PATCH 10/21] blackify --- invokeai/backend/install/model_install_backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 691a461652..b2d09c1ffe 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -419,7 +419,7 @@ class ModelInstall(object): """ _, name = repo_id.split("/") precision = torch_dtype(choose_torch_device()) - variants = ["fp16",None] if precision == torch.float16 else [None,"fp16"] + variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"] model = None for variant in variants: @@ -428,7 +428,7 @@ class ModelInstall(object): repo_id, variant=variant, torch_dtype=precision, - safety_checker=None, + safety_checker=None, ) except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors if "fp16" not in str(e): @@ -436,7 +436,7 @@ class ModelInstall(object): if model: break - + if not model: logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.") return None From cfc3a20565810a72cb572019f1ad9ce99b4e1177 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 3 Aug 2023 15:47:59 -0400 Subject: [PATCH 11/21] autoAddBoardId should always be defined --- .../middleware/listenerMiddleware/listeners/imageUploaded.ts | 2 +- .../frontend/web/src/features/gallery/store/gallerySlice.ts | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index f488259eb7..6dc2d482a9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -40,7 +40,7 @@ export const addImageUploadedFulfilledListener = () => { // default action - just upload and alert user if (postUploadAction?.type === 'TOAST') { const { toastOptions } = postUploadAction; - if (!autoAddBoardId) { + if (!autoAddBoardId || autoAddBoardId === 'none') { dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions })); } else { // Add this image to the board diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 3b0dd233f1..bc7acff6f4 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -41,6 +41,10 @@ export const gallerySlice = createSlice({ state.galleryView = 'images'; }, autoAddBoardIdChanged: (state, action: PayloadAction) => { + if (!action.payload) { + state.autoAddBoardId = 'none'; + return; + } state.autoAddBoardId = action.payload; }, galleryViewChanged: (state, action: PayloadAction) => { From 1ac14a1e43f5943435a8a28a1bf57d059b532af4 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 31 Jul 2023 23:18:02 +0300 Subject: [PATCH 12/21] add sdxl lora support --- invokeai/app/invocations/compel.py | 30 +- invokeai/app/invocations/latent.py | 2 +- invokeai/app/invocations/model.py | 97 ++++ invokeai/app/invocations/sdxl.py | 16 +- 
invokeai/backend/model_management/__init__.py | 1 + invokeai/backend/model_management/lora.py | 436 +------------- .../backend/model_management/model_cache.py | 2 - .../backend/model_management/models/lora.py | 537 +++++++++++++++++- 8 files changed, 683 insertions(+), 438 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index ada7a06a57..bbe372ff57 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -173,7 +173,7 @@ class CompelInvocation(BaseInvocation): class SDXLPromptInvocationBase: - def run_clip_raw(self, context, clip_field, prompt, get_pooled): + def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix): tokenizer_info = context.services.model_manager.get_model( **clip_field.tokenizer.dict(), context=context, @@ -210,8 +210,8 @@ class SDXLPromptInvocationBase: # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') - with ModelPatcher.apply_lora_text_encoder( - text_encoder_info.context.model, _lora_loader() + with ModelPatcher.apply_lora( + text_encoder_info.context.model, _lora_loader(), lora_prefix ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( tokenizer, ti_manager, @@ -247,7 +247,7 @@ class SDXLPromptInvocationBase: return c, c_pooled, None - def run_clip_compel(self, context, clip_field, prompt, get_pooled): + def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix): tokenizer_info = context.services.model_manager.get_model( **clip_field.tokenizer.dict(), context=context, @@ -284,8 +284,8 @@ class SDXLPromptInvocationBase: # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') - with ModelPatcher.apply_lora_text_encoder( - text_encoder_info.context.model, _lora_loader() + with ModelPatcher.apply_lora( + text_encoder_info.context.model, _lora_loader(), lora_prefix ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( tokenizer, ti_manager, @@ -357,11 +357,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False) + c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_") if self.style.strip() == "": - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True) + c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_") else: - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True) + c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -415,7 +415,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True) + # TODO: if there will appear lora for refiner - write proper prefix + c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -467,11 +468,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): 
@torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False) + c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_") if self.style.strip() == "": - c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True) + c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_") else: - c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True) + c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -525,7 +526,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True) + # TODO: if there will appear lora for refiner - write proper prefix + c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 3edbe86150..6e2e0838bc 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -14,7 +14,7 @@ from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.models import ModelType, SilenceWarnings -from ...backend.model_management.lora import ModelPatcher +from ...backend.model_management import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c19e5c5c9a..d215d500a6 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation): return output +class SDXLLoraLoaderOutput(BaseInvocationOutput): + """Model loader output""" + + # fmt: off + type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output" + + unet: Optional[UNetField] = Field(default=None, description="UNet submodel") + clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") + clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels") + # fmt: on + + +class SDXLLoraLoaderInvocation(BaseInvocation): + """Apply selected lora to unet and text_encoder.""" + + type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader" + + lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") + weight: float = Field(default=0.75, description="With what weight to apply lora") + + unet: Optional[UNetField] = Field(description="UNet model for applying lora") + clip: Optional[ClipField] = Field(description="Clip model for applying lora") + clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora") + + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "SDXL Lora Loader", + "tags": ["lora", "loader"], + "type_hints": {"lora": "lora_model"}, + }, + } + + def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: + if self.lora is 
None: + raise Exception("No LoRA provided") + + base_model = self.lora.base_model + lora_name = self.lora.model_name + + if not context.services.model_manager.model_exists( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + ): + raise Exception(f"Unkown lora name: {lora_name}!") + + if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): + raise Exception(f'Lora "{lora_name}" already applied to unet') + + if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): + raise Exception(f'Lora "{lora_name}" already applied to clip') + + if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras): + raise Exception(f'Lora "{lora_name}" already applied to clip2') + + output = SDXLLoraLoaderOutput() + + if self.unet is not None: + output.unet = copy.deepcopy(self.unet) + output.unet.loras.append( + LoraInfo( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + submodel=None, + weight=self.weight, + ) + ) + + if self.clip is not None: + output.clip = copy.deepcopy(self.clip) + output.clip.loras.append( + LoraInfo( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + submodel=None, + weight=self.weight, + ) + ) + + if self.clip2 is not None: + output.clip2 = copy.deepcopy(self.clip2) + output.clip2.loras.append( + LoraInfo( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + submodel=None, + weight=self.weight, + ) + ) + + return output + + class VAEModelField(BaseModel): """Vae model field""" diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 7dfceba853..faa6b59782 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union from pydantic import Field, validator -from ...backend.model_management import ModelType, SubModelType +from ...backend.model_management import ModelType, SubModelType, ModelPatcher from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext @@ -293,10 +293,22 @@ class SDXLTextToLatentsInvocation(BaseInvocation): num_inference_steps = self.steps + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"}), + context=context, + ) + yield (lora_info.context.model, lora.weight) + del lora_info + return + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None - with unet_info as unet: + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: + scheduler.set_timesteps(num_inference_steps, device=unet.device) timesteps = scheduler.timesteps diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index cf057f3a89..8e083c1045 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -13,3 +13,4 @@ from .models import ( DuplicateModelException, ) from .model_merge import ModelMerger, MergeInterpolationMethod +from .lora import ModelPatcher diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 4287072a65..56f7a648c9 100644 --- 
a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -23,421 +23,6 @@ from transformers import CLIPTextModel, CLIPTokenizer # TODO: rename and split this file -class LoRALayerBase: - # rank: Optional[int] - # alpha: Optional[float] - # bias: Optional[torch.Tensor] - # layer_key: str - - # @property - # def scale(self): - # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - - def __init__( - self, - layer_key: str, - values: dict, - ): - if "alpha" in values: - self.alpha = values["alpha"].item() - else: - self.alpha = None - - if "bias_indices" in values and "bias_values" in values and "bias_size" in values: - self.bias = torch.sparse_coo_tensor( - values["bias_indices"], - values["bias_values"], - tuple(values["bias_size"]), - ) - - else: - self.bias = None - - self.rank = None # set in layer implementation - self.layer_key = layer_key - - def forward( - self, - module: torch.nn.Module, - input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure - multiplier: float, - ): - if type(module) == torch.nn.Conv2d: - op = torch.nn.functional.conv2d - extra_args = dict( - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - else: - op = torch.nn.functional.linear - extra_args = {} - - weight = self.get_weight() - - bias = self.bias if self.bias is not None else 0 - scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - return ( - op( - *input_h, - (weight + bias).view(module.weight.shape), - None, - **extra_args, - ) - * multiplier - * scale - ) - - def get_weight(self): - raise NotImplementedError() - - def calc_size(self) -> int: - model_size = 0 - for val in [self.bias]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - if self.bias is not None: - self.bias = self.bias.to(device=device, dtype=dtype) - - -# TODO: find and debug lora/locon with bias -class LoRALayer(LoRALayerBase): - # up: torch.Tensor - # mid: Optional[torch.Tensor] - # down: torch.Tensor - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.up = values["lora_up.weight"] - self.down = values["lora_down.weight"] - if "lora_mid.weight" in values: - self.mid = values["lora_mid.weight"] - else: - self.mid = None - - self.rank = self.down.shape[0] - - def get_weight(self): - if self.mid is not None: - up = self.up.reshape(self.up.shape[0], self.up.shape[1]) - down = self.down.reshape(self.down.shape[0], self.down.shape[1]) - weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) - else: - weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.up, self.mid, self.down]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.up = self.up.to(device=device, dtype=dtype) - self.down = self.down.to(device=device, dtype=dtype) - - if self.mid is not None: - self.mid = self.mid.to(device=device, dtype=dtype) - - -class LoHALayer(LoRALayerBase): - # w1_a: torch.Tensor - # w1_b: torch.Tensor - # w2_a: torch.Tensor - # w2_b: torch.Tensor - # t1: 
Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.w1_a = values["hada_w1_a"] - self.w1_b = values["hada_w1_b"] - self.w2_a = values["hada_w2_a"] - self.w2_b = values["hada_w2_b"] - - if "hada_t1" in values: - self.t1 = values["hada_t1"] - else: - self.t1 = None - - if "hada_t2" in values: - self.t2 = values["hada_t2"] - else: - self.t2 = None - - self.rank = self.w1_b.shape[0] - - def get_weight(self): - if self.t1 is None: - weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) - - else: - rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) - rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) - weight = rebuild1 * rebuild2 - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - if self.t1 is not None: - self.t1 = self.t1.to(device=device, dtype=dtype) - - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class LoKRLayer(LoRALayerBase): - # w1: Optional[torch.Tensor] = None - # w1_a: Optional[torch.Tensor] = None - # w1_b: Optional[torch.Tensor] = None - # w2: Optional[torch.Tensor] = None - # w2_a: Optional[torch.Tensor] = None - # w2_b: Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - if "lokr_w1" in values: - self.w1 = values["lokr_w1"] - self.w1_a = None - self.w1_b = None - else: - self.w1 = None - self.w1_a = values["lokr_w1_a"] - self.w1_b = values["lokr_w1_b"] - - if "lokr_w2" in values: - self.w2 = values["lokr_w2"] - self.w2_a = None - self.w2_b = None - else: - self.w2 = None - self.w2_a = values["lokr_w2_a"] - self.w2_b = values["lokr_w2_b"] - - if "lokr_t2" in values: - self.t2 = values["lokr_t2"] - else: - self.t2 = None - - if "lokr_w1_b" in values: - self.rank = values["lokr_w1_b"].shape[0] - elif "lokr_w2_b" in values: - self.rank = values["lokr_w2_b"].shape[0] - else: - self.rank = None # unscaled - - def get_weight(self): - w1 = self.w1 - if w1 is None: - w1 = self.w1_a @ self.w1_b - - w2 = self.w2 - if w2 is None: - if self.t2 is None: - w2 = self.w2_a @ self.w2_b - else: - w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - w2 = w2.contiguous() - weight = torch.kron(w1, w2) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - if self.w1 is not None: - self.w1 = self.w1.to(device=device, dtype=dtype) - else: 
- self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - - if self.w2 is not None: - self.w2 = self.w2.to(device=device, dtype=dtype) - else: - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) - - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class LoRAModel: # (torch.nn.Module): - _name: str - layers: Dict[str, LoRALayer] - _device: torch.device - _dtype: torch.dtype - - def __init__( - self, - name: str, - layers: Dict[str, LoRALayer], - device: torch.device, - dtype: torch.dtype, - ): - self._name = name - self._device = device or torch.cpu - self._dtype = dtype or torch.float32 - self.layers = layers - - @property - def name(self): - return self._name - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> LoRAModel: - # TODO: try revert if exception? - for key, layer in self.layers.items(): - layer.to(device=device, dtype=dtype) - self._device = device - self._dtype = dtype - - def calc_size(self) -> int: - model_size = 0 - for _, layer in self.layers.items(): - model_size += layer.calc_size() - return model_size - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or torch.device("cpu") - dtype = dtype or torch.float32 - - if isinstance(file_path, str): - file_path = Path(file_path) - - model = cls( - device=device, - dtype=dtype, - name=file_path.stem, # TODO: - layers=dict(), - ) - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - state_dict = cls._group_state(state_dict) - - for layer_key, values in state_dict.items(): - # lora and locon - if "lora_down.weight" in values: - layer = LoRALayer(layer_key, values) - - # loha - elif "hada_w1_b" in values: - layer = LoHALayer(layer_key, values) - - # lokr - elif "lokr_w1_b" in values or "lokr_w1" in values: - layer = LoKRLayer(layer_key, values) - - else: - # TODO: diff/ia3/... 
format - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") - return - - # lower memory consumption by removing already parsed layer values - state_dict[layer_key].clear() - - layer.to(device=device, dtype=dtype) - model.layers[layer_key] = layer - - return model - - @staticmethod - def _group_state(state_dict: dict): - state_dict_groupped = dict() - - for key, value in state_dict.items(): - stem, leaf = key.split(".", 1) - if stem not in state_dict_groupped: - state_dict_groupped[stem] = dict() - state_dict_groupped[stem][leaf] = value - - return state_dict_groupped - - """ loras = [ (lora_model1, 0.7), @@ -516,6 +101,27 @@ class ModelPatcher: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModel, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te1_"): + yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder2( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModel, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te2_"): + yield + @classmethod @contextmanager def apply_lora( diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index b4c3e48a48..71e1ebc0d4 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -28,8 +28,6 @@ import torch import logging import invokeai.backend.util.logging as logger -from invokeai.app.services.config import get_invokeai_config -from .lora import LoRAModel, TextualInversionModel from .models import BaseModelType, ModelType, SubModelType, ModelBase # Maximum size of the cache, in gigs diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 642f8bbeec..0351bf2652 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -1,7 +1,9 @@ import os import torch from enum import Enum -from typing import Optional, Union, Literal +from typing import Optional, Dict, Union, Literal, Any +from pathlib import Path +from safetensors.torch import load_file from .base import ( ModelBase, ModelConfigBase, @@ -13,9 +15,6 @@ from .base import ( ModelNotFoundException, ) -# TODO: naming -from ..lora import LoRAModel as LoRAModelRaw - class LoRAModelFormat(str, Enum): LyCORIS = "lycoris" @@ -50,6 +49,7 @@ class LoRAModel(ModelBase): model = LoRAModelRaw.from_checkpoint( file_path=self.model_path, dtype=torch_dtype, + base_model=self.base_model, ) self.model_size = model.calc_size() @@ -87,3 +87,532 @@ class LoRAModel(ModelBase): raise NotImplementedError("Diffusers lora not supported") else: return model_path + +class LoRALayerBase: + # rank: Optional[int] + # alpha: Optional[float] + # bias: Optional[torch.Tensor] + # layer_key: str + + # @property + # def scale(self): + # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 + + def __init__( + self, + layer_key: str, + values: dict, + ): + if "alpha" in values: + self.alpha = values["alpha"].item() + else: + self.alpha = None + + if "bias_indices" in values and "bias_values" in values and "bias_size" in values: + self.bias = torch.sparse_coo_tensor( + values["bias_indices"], + values["bias_values"], + tuple(values["bias_size"]), + ) + + else: + self.bias = None + + self.rank = None # set in layer implementation + self.layer_key = 
layer_key + + def forward( + self, + module: torch.nn.Module, + input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure + multiplier: float, + ): + if type(module) == torch.nn.Conv2d: + op = torch.nn.functional.conv2d + extra_args = dict( + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + else: + op = torch.nn.functional.linear + extra_args = {} + + weight = self.get_weight() + + bias = self.bias if self.bias is not None else 0 + scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 + return ( + op( + *input_h, + (weight + bias).view(module.weight.shape), + None, + **extra_args, + ) + * multiplier + * scale + ) + + def get_weight(self): + raise NotImplementedError() + + def calc_size(self) -> int: + model_size = 0 + for val in [self.bias]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + if self.bias is not None: + self.bias = self.bias.to(device=device, dtype=dtype) + +# TODO: find and debug lora/locon with bias +class LoRALayer(LoRALayerBase): + # up: torch.Tensor + # mid: Optional[torch.Tensor] + # down: torch.Tensor + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.up = values["lora_up.weight"] + self.down = values["lora_down.weight"] + if "lora_mid.weight" in values: + self.mid = values["lora_mid.weight"] + else: + self.mid = None + + self.rank = self.down.shape[0] + + def get_weight(self): + if self.mid is not None: + up = self.up.reshape(self.up.shape[0], self.up.shape[1]) + down = self.down.reshape(self.down.shape[0], self.down.shape[1]) + weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) + else: + weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.up, self.mid, self.down]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.up = self.up.to(device=device, dtype=dtype) + self.down = self.down.to(device=device, dtype=dtype) + + if self.mid is not None: + self.mid = self.mid.to(device=device, dtype=dtype) + +class LoHALayer(LoRALayerBase): + # w1_a: torch.Tensor + # w1_b: torch.Tensor + # w2_a: torch.Tensor + # w2_b: torch.Tensor + # t1: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.w1_a = values["hada_w1_a"] + self.w1_b = values["hada_w1_b"] + self.w2_a = values["hada_w2_a"] + self.w2_b = values["hada_w2_b"] + + if "hada_t1" in values: + self.t1 = values["hada_t1"] + else: + self.t1 = None + + if "hada_t2" in values: + self.t2 = values["hada_t2"] + else: + self.t2 = None + + self.rank = self.w1_b.shape[0] + + def get_weight(self): + if self.t1 is None: + weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) + + else: + rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) + rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) + weight = rebuild1 * rebuild2 + + return weight + + def calc_size(self) -> int: + model_size = 
super().calc_size() + for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + if self.t1 is not None: + self.t1 = self.t1.to(device=device, dtype=dtype) + + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + +class LoKRLayer(LoRALayerBase): + # w1: Optional[torch.Tensor] = None + # w1_a: Optional[torch.Tensor] = None + # w1_b: Optional[torch.Tensor] = None + # w2: Optional[torch.Tensor] = None + # w2_a: Optional[torch.Tensor] = None + # w2_b: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + if "lokr_w1" in values: + self.w1 = values["lokr_w1"] + self.w1_a = None + self.w1_b = None + else: + self.w1 = None + self.w1_a = values["lokr_w1_a"] + self.w1_b = values["lokr_w1_b"] + + if "lokr_w2" in values: + self.w2 = values["lokr_w2"] + self.w2_a = None + self.w2_b = None + else: + self.w2 = None + self.w2_a = values["lokr_w2_a"] + self.w2_b = values["lokr_w2_b"] + + if "lokr_t2" in values: + self.t2 = values["lokr_t2"] + else: + self.t2 = None + + if "lokr_w1_b" in values: + self.rank = values["lokr_w1_b"].shape[0] + elif "lokr_w2_b" in values: + self.rank = values["lokr_w2_b"].shape[0] + else: + self.rank = None # unscaled + + def get_weight(self): + w1 = self.w1 + if w1 is None: + w1 = self.w1_a @ self.w1_b + + w2 = self.w2 + if w2 is None: + if self.t2 is None: + w2 = self.w2_a @ self.w2_b + else: + w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + weight = torch.kron(w1, w2) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + if self.w1 is not None: + self.w1 = self.w1.to(device=device, dtype=dtype) + else: + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + + if self.w2 is not None: + self.w2 = self.w2.to(device=device, dtype=dtype) + else: + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + +# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix +class LoRAModelRaw: # (torch.nn.Module): + _name: str + layers: Dict[str, LoRALayer] + _device: torch.device + _dtype: torch.dtype + + def __init__( + self, + name: str, + layers: Dict[str, LoRALayer], + device: torch.device, + dtype: torch.dtype, + ): + self._name = name + self._device = device or torch.cpu + self._dtype = dtype or torch.float32 + self.layers = layers + + @property + def name(self): + return 
self._name + + @property + def device(self): + return self._device + + @property + def dtype(self): + return self._dtype + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + # TODO: try revert if exception? + for key, layer in self.layers.items(): + layer.to(device=device, dtype=dtype) + self._device = device + self._dtype = dtype + + def calc_size(self) -> int: + model_size = 0 + for _, layer in self.layers.items(): + model_size += layer.calc_size() + return model_size + + @classmethod + def _convert_sdxl_compvis_keys(cls, state_dict): + new_state_dict = dict() + for full_key, value in state_dict.items(): + if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): + continue # clip same + + if not full_key.startswith("lora_unet_"): + raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}") + src_key = full_key.replace("lora_unet_", "") + try: + dst_key = None + while "_" in src_key: + if src_key in SDXL_UNET_COMPVIS_MAP: + dst_key = SDXL_UNET_COMPVIS_MAP[src_key] + break + src_key = "_".join(src_key.split('_')[:-1]) + + if dst_key is None: + raise Exception(f"Unknown sdxl lora key - {full_key}") + new_key = full_key.replace(src_key, dst_key) + except: + print(SDXL_UNET_COMPVIS_MAP) + raise + new_state_dict[new_key] = value + return new_state_dict + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + base_model: Optional[BaseModelType] = None, + ): + device = device or torch.device("cpu") + dtype = dtype or torch.float32 + + if isinstance(file_path, str): + file_path = Path(file_path) + + model = cls( + device=device, + dtype=dtype, + name=file_path.stem, # TODO: + layers=dict(), + ) + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + state_dict = cls._group_state(state_dict) + + if base_model == BaseModelType.StableDiffusionXL: + state_dict = cls._convert_sdxl_compvis_keys(state_dict) + + for layer_key, values in state_dict.items(): + # lora and locon + if "lora_down.weight" in values: + layer = LoRALayer(layer_key, values) + + # loha + elif "hada_w1_b" in values: + layer = LoHALayer(layer_key, values) + + # lokr + elif "lokr_w1_b" in values or "lokr_w1" in values: + layer = LoKRLayer(layer_key, values) + + else: + # TODO: diff/ia3/... format + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") + return + + # lower memory consumption by removing already parsed layer values + state_dict[layer_key].clear() + + layer.to(device=device, dtype=dtype) + model.layers[layer_key] = layer + + return model + + @staticmethod + def _group_state(state_dict: dict): + state_dict_groupped = dict() + + for key, value in state_dict.items(): + stem, leaf = key.split(".", 1) + if stem not in state_dict_groupped: + state_dict_groupped[stem] = dict() + state_dict_groupped[stem][leaf] = value + + return state_dict_groupped + + +def make_sdxl_unet_conversion_map(): + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." 
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + +#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} +SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} From 1d5d187ba10f4d2dab3d2bb1213e41cae37d701c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 10:26:52 -0400 Subject: [PATCH 13/21] model probe detects sdxl lora models --- invokeai/app/invocations/sdxl.py | 4 +-- invokeai/backend/model_management/lora.py | 1 - .../backend/model_management/model_probe.py | 25 ++++++++++++++++--- .../backend/model_management/models/lora.py | 17 ++++++++++--- scripts/probe-model.py | 6 +++-- 5 files changed, 39 insertions(+), 14 deletions(-) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index faa6b59782..aaa616a378 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -306,9 +306,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None - with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ - unet_info as unet: - + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet: scheduler.set_timesteps(num_inference_steps, device=unet.device) timesteps = scheduler.timesteps diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 56f7a648c9..0a0ab3d629 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -101,7 +101,6 @@ class ModelPatcher: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield - @classmethod @contextmanager def apply_sdxl_lora_text_encoder( diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index c3964d760c..21462cf6e6 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint + + # SD-2 models are very hard to probe. 
These probes are brittle and likely to fail in the future + # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be + # misclassified as SD-1 + key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" + if key in checkpoint and checkpoint[key].shape[0] == 320: + return BaseModelType.StableDiffusion2 + + key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight" + if key in checkpoint: + return BaseModelType.StableDiffusionXL + key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" - key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" + key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" + lora_token_vector_length = ( checkpoint[key1].shape[1] if key1 in checkpoint - else checkpoint[key2].shape[0] + else checkpoint[key2].shape[1] if key2 in checkpoint - else 768 + else checkpoint[key3].shape[0] + if key3 in checkpoint + else None ) + if lora_token_vector_length == 768: return BaseModelType.StableDiffusion1 elif lora_token_vector_length == 1024: return BaseModelType.StableDiffusion2 else: - return None + raise InvalidModelException(f"Unknown LoRA type") class TextualInversionCheckpointProbe(CheckpointProbeBase): diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 0351bf2652..0870e78469 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -88,6 +88,7 @@ class LoRAModel(ModelBase): else: return model_path + class LoRALayerBase: # rank: Optional[int] # alpha: Optional[float] @@ -173,6 +174,7 @@ class LoRALayerBase: if self.bias is not None: self.bias = self.bias.to(device=device, dtype=dtype) + # TODO: find and debug lora/locon with bias class LoRALayer(LoRALayerBase): # up: torch.Tensor @@ -225,6 +227,7 @@ class LoRALayer(LoRALayerBase): if self.mid is not None: self.mid = self.mid.to(device=device, dtype=dtype) + class LoHALayer(LoRALayerBase): # w1_a: torch.Tensor # w1_b: torch.Tensor @@ -292,6 +295,7 @@ class LoHALayer(LoRALayerBase): if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) + class LoKRLayer(LoRALayerBase): # w1: Optional[torch.Tensor] = None # w1_a: Optional[torch.Tensor] = None @@ -386,6 +390,7 @@ class LoKRLayer(LoRALayerBase): if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) + # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix class LoRAModelRaw: # (torch.nn.Module): _name: str @@ -439,7 +444,7 @@ class LoRAModelRaw: # (torch.nn.Module): new_state_dict = dict() for full_key, value in state_dict.items(): if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): - continue # clip same + continue # clip same if not full_key.startswith("lora_unet_"): raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}") @@ -450,7 +455,7 @@ class LoRAModelRaw: # (torch.nn.Module): if src_key in SDXL_UNET_COMPVIS_MAP: dst_key = SDXL_UNET_COMPVIS_MAP[src_key] break - src_key = "_".join(src_key.split('_')[:-1]) + src_key = "_".join(src_key.split("_")[:-1]) if dst_key is None: raise Exception(f"Unknown sdxl lora key - {full_key}") @@ -614,5 +619,9 @@ def make_sdxl_unet_conversion_map(): return unet_conversion_map -#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf 
in make_sdxl_unet_conversion_map()} -SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} + +# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} +SDXL_UNET_COMPVIS_MAP = { + f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") + for sd, hf in make_sdxl_unet_conversion_map() +} diff --git a/scripts/probe-model.py b/scripts/probe-model.py index 7281dafc3f..4cf2c50263 100755 --- a/scripts/probe-model.py +++ b/scripts/probe-model.py @@ -9,8 +9,10 @@ parser = argparse.ArgumentParser(description="Probe model type") parser.add_argument( "model_path", type=Path, + nargs="+", ) args = parser.parse_args() -info = ModelProbe().probe(args.model_path) -print(info) +for path in args.model_path: + info = ModelProbe().probe(path) + print(f"{path}: {info}") From cff91f06d36376e7936db97761df9aaea4395d53 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 3 Aug 2023 19:04:44 +0300 Subject: [PATCH 14/21] Add lora apply in sdxl l2l node --- invokeai/app/invocations/sdxl.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index aaa616a378..5bcd85db28 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -553,9 +553,19 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): context=context, ) + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"}), + context=context, + ) + yield (lora_info.context.model, lora.weight) + del lora_info + return + do_classifier_free_guidance = True cross_attention_kwargs = None - with unet_info as unet: + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet: # apply denoising_start num_inference_steps = self.steps scheduler.set_timesteps(num_inference_steps, device=unet.device) From 0d3c27f46c3276d86a51191ad3d036a2fe51f13f Mon Sep 17 00:00:00 2001 From: StAlKeR7779 Date: Fri, 4 Aug 2023 03:07:21 +0300 Subject: [PATCH 15/21] Fix typo Co-authored-by: Ryan Dick --- invokeai/app/invocations/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index d215d500a6..0d21f8f0ce 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -307,7 +307,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): model_name=lora_name, model_type=ModelType.Lora, ): - raise Exception(f"Unkown lora name: {lora_name}!") + raise Exception(f"Unknown lora name: {lora_name}!") if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): raise Exception(f'Lora "{lora_name}" already applied to unet') From 2f8b928486eb8d4e7c94a7eca122e1ab08fbc0e4 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 1 Aug 2023 17:02:57 +0300 Subject: [PATCH 16/21] Add support for diff/full lora layers --- invokeai/backend/model_management/lora.py | 46 +++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 4287072a65..14a78693ea 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -325,6 +325,43 @@ class 
LoKRLayer(LoRALayerBase): self.t2 = self.t2.to(device=device, dtype=dtype) +class FullLayer(LoRALayerBase): + # weight: torch.Tensor + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.weight = values["diff"] + + if len(values.keys()) > 1: + _keys = list(values.keys()) + _keys.remove("diff") + raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + + self.rank = None # unscaled + + def get_weight(self): + return self.weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + + class LoRAModel: # (torch.nn.Module): _name: str layers: Dict[str, LoRALayer] @@ -412,10 +449,13 @@ class LoRAModel: # (torch.nn.Module): elif "lokr_w1_b" in values or "lokr_w1" in values: layer = LoKRLayer(layer_key, values) + elif "diff" in values: + layer = FullLayer(layer_key, values) + else: - # TODO: diff/ia3/... format - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") - return + # TODO: ia3/... format + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") + raise Exception("Unknown lora format!") # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() From 7d0cc6ec3f97a4d0c8b61c6af6e3bc4e29c4f1b9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 3 Aug 2023 11:18:22 +1000 Subject: [PATCH 17/21] chore: black --- invokeai/backend/model_management/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 14a78693ea..9f196d659d 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -359,7 +359,7 @@ class FullLayer(LoRALayerBase): ): super().to(device=device, dtype=dtype) - self.weight = self.weight.to(device=device, dtype=dtype) + self.weight = self.weight.to(device=device, dtype=dtype) class LoRAModel: # (torch.nn.Module): From f0613bb0ef1642e3c34ba7bc3146294d5de15ad2 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 4 Aug 2023 19:53:27 +0300 Subject: [PATCH 18/21] Fix merge conflict resolve - restore full/diff layer support --- .../backend/model_management/models/lora.py | 49 +++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 0870e78469..1983c05503 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -391,6 +391,43 @@ class LoKRLayer(LoRALayerBase): self.t2 = self.t2.to(device=device, dtype=dtype) +class FullLayer(LoRALayerBase): + # weight: torch.Tensor + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.weight = values["diff"] + + if len(values.keys()) > 1: + _keys = list(values.keys()) + _keys.remove("diff") + raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + + self.rank = None # unscaled + + def get_weight(self): + return self.weight + + def calc_size(self) -> int: + model_size = 
super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + + # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix class LoRAModelRaw: # (torch.nn.Module): _name: str @@ -510,10 +547,13 @@ class LoRAModelRaw: # (torch.nn.Module): elif "lokr_w1_b" in values or "lokr_w1" in values: layer = LoKRLayer(layer_key, values) + elif "diff" in values: + layer = FullLayer(layer_key, values) + else: - # TODO: diff/ia3/... format - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") - return + # TODO: ia3/... format + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") + raise Exception("Unknown lora format!") # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() @@ -536,6 +576,8 @@ class LoRAModelRaw: # (torch.nn.Module): return state_dict_groupped +# code from +# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 def make_sdxl_unet_conversion_map(): unet_conversion_map_layer = [] @@ -620,7 +662,6 @@ def make_sdxl_unet_conversion_map(): return unet_conversion_map -# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} SDXL_UNET_COMPVIS_MAP = { f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() From 04229082d61f10efd70b2ff289b03c46ae4571a8 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 1 Aug 2023 18:04:10 +0300 Subject: [PATCH 19/21] Provide ti name from model manager, not from ti itself --- invokeai/app/invocations/compel.py | 15 +++++++----- invokeai/app/invocations/onnx.py | 13 +++------- invokeai/backend/model_management/lora.py | 30 +++++++++++------------ 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index bbe372ff57..c11ebd3f56 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -108,14 +108,15 @@ class CompelInvocation(BaseInvocation): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: - ti_list.append( + ti_list.append(( + name, context.services.model_manager.get_model( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, context=context, ).context.model - ) + )) except ModelNotFoundException: # print(e) # import traceback @@ -196,14 +197,15 @@ class SDXLPromptInvocationBase: for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): name = trigger[1:-1] try: - ti_list.append( + ti_list.append(( + name, context.services.model_manager.get_model( model_name=name, base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, context=context, ).context.model - ) + )) except ModelNotFoundException: # print(e) # import traceback @@ -270,14 +272,15 @@ class SDXLPromptInvocationBase: for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): name = trigger[1:-1] try: - ti_list.append( + ti_list.append(( + name, context.services.model_manager.get_model( model_name=name, 
base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, context=context, ).context.model - ) + )) except ModelNotFoundException: # print(e) # import traceback diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 2bec128b87..dec5b939a0 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation): **self.clip.text_encoder.dict(), ) with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack: - # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras] loras = [ (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras @@ -75,20 +74,14 @@ class ONNXPromptInvocation(BaseInvocation): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: - ti_list.append( - # stack.enter_context( - # context.services.model_manager.get_model( - # model_name=name, - # base_model=self.clip.text_encoder.base_model, - # model_type=ModelType.TextualInversion, - # ) - # ) + ti_list.append(( + name, context.services.model_manager.get_model( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, ).context.model - ) + )) except Exception: # print(e) # import traceback diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 7ccf5e57ae..e8e2b3f51f 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -164,7 +164,7 @@ class ModelPatcher: cls, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - ti_list: List[Any], + ti_list: List[Tuple[str, Any]], ) -> Tuple[CLIPTokenizer, TextualInversionManager]: init_tokens_count = None new_tokens_added = None @@ -174,27 +174,27 @@ class ModelPatcher: ti_manager = TextualInversionManager(ti_tokenizer) init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings - def _get_trigger(ti, index): - trigger = ti.name + def _get_trigger(ti_name, index): + trigger = ti_name if index > 0: trigger += f"-!pad-{i}" return f"<{trigger}>" # modify tokenizer new_tokens_added = 0 - for ti in ti_list: + for ti_name, ti in ti_list: for i in range(ti.embedding.shape[0]): - new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i)) + new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) # modify text_encoder text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added) model_embeddings = text_encoder.get_input_embeddings() - for ti in ti_list: + for ti_name, ti in ti_list: ti_tokens = [] for i in range(ti.embedding.shape[0]): embedding = ti.embedding[i] - trigger = _get_trigger(ti, i) + trigger = _get_trigger(ti_name, i) token_id = ti_tokenizer.convert_tokens_to_ids(trigger) if token_id == ti_tokenizer.unk_token_id: @@ -239,7 +239,6 @@ class ModelPatcher: class TextualInversionModel: - name: str embedding: torch.Tensor # [n, 768]|[n, 1280] @classmethod @@ -253,7 +252,6 @@ class TextualInversionModel: file_path = Path(file_path) result = cls() # TODO: - result.name = file_path.stem # TODO: if file_path.suffix == ".safetensors": state_dict = load_file(file_path.absolute().as_posix(), device="cpu") @@ -430,7 +428,7 @@ class ONNXModelPatcher: cls, tokenizer: CLIPTokenizer, text_encoder: IAIOnnxRuntimeModel, - ti_list: List[Any], + 
ti_list: List[Tuple[str, Any]], ) -> Tuple[CLIPTokenizer, TextualInversionManager]: from .models.base import IAIOnnxRuntimeModel @@ -443,17 +441,17 @@ class ONNXModelPatcher: ti_tokenizer = copy.deepcopy(tokenizer) ti_manager = TextualInversionManager(ti_tokenizer) - def _get_trigger(ti, index): - trigger = ti.name + def _get_trigger(ti_name, index): + trigger = ti_name if index > 0: trigger += f"-!pad-{i}" return f"<{trigger}>" # modify tokenizer new_tokens_added = 0 - for ti in ti_list: + for ti_name, ti in ti_list: for i in range(ti.embedding.shape[0]): - new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i)) + new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) # modify text_encoder orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] @@ -463,11 +461,11 @@ class ONNXModelPatcher: axis=0, ) - for ti in ti_list: + for ti_name, ti in ti_list: ti_tokens = [] for i in range(ti.embedding.shape[0]): embedding = ti.embedding[i].detach().numpy() - trigger = _get_trigger(ti, i) + trigger = _get_trigger(ti_name, i) token_id = ti_tokenizer.convert_tokens_to_ids(trigger) if token_id == ti_tokenizer.unk_token_id: From 6ad565d84c22243838cd5fca1c579267ac3f4de5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 19:01:05 -0400 Subject: [PATCH 20/21] folded in changes from 4099 --- invokeai/app/invocations/compel.py | 60 ++++++++++--------- invokeai/app/invocations/onnx.py | 18 +++--- .../backend/model_management/model_cache.py | 2 +- .../backend/model_management/model_manager.py | 6 +- 4 files changed, 47 insertions(+), 39 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index c11ebd3f56..7c3ce7a819 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -108,15 +108,17 @@ class CompelInvocation(BaseInvocation): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: - ti_list.append(( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model - )) + ti_list.append( + ( + name, + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + context=context, + ).context.model, + ) + ) except ModelNotFoundException: # print(e) # import traceback @@ -197,15 +199,17 @@ class SDXLPromptInvocationBase: for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): name = trigger[1:-1] try: - ti_list.append(( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=clip_field.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model - )) + ti_list.append( + ( + name, + context.services.model_manager.get_model( + model_name=name, + base_model=clip_field.text_encoder.base_model, + model_type=ModelType.TextualInversion, + context=context, + ).context.model, + ) + ) except ModelNotFoundException: # print(e) # import traceback @@ -272,15 +276,17 @@ class SDXLPromptInvocationBase: for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): name = trigger[1:-1] try: - ti_list.append(( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=clip_field.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model - )) + ti_list.append( + ( + name, + 
context.services.model_manager.get_model( + model_name=name, + base_model=clip_field.text_encoder.base_model, + model_type=ModelType.TextualInversion, + context=context, + ).context.model, + ) + ) except ModelNotFoundException: # print(e) # import traceback diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index dec5b939a0..fe9a64552e 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -74,14 +74,16 @@ class ONNXPromptInvocation(BaseInvocation): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: - ti_list.append(( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model - )) + ti_list.append( + ( + name, + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ).context.model, + ) + ) except Exception: # print(e) # import traceback diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 71e1ebc0d4..2b8d020269 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -186,7 +186,7 @@ class ModelCache(object): cache_entry = self._cached_models.get(key, None) if cache_entry is None: self.logger.info( - f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}" + f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}" ) # this will remove older cached models until diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 832a96e18f..3fd59d1533 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -670,7 +670,7 @@ class ModelManager(object): # TODO: if path changed and old_model.path inside models folder should we delete this too? 
# remove conversion cache as config changed - old_model_path = self.app_config.root_path / old_model.path + old_model_path = self.resolve_model_path(old_model.path) old_model_cache = self._get_model_cache_path(old_model_path) if old_model_cache.exists(): if old_model_cache.is_dir(): @@ -780,7 +780,7 @@ class ModelManager(object): model_type, **submodel, ) - checkpoint_path = self.app_config.root_path / info["path"] + checkpoint_path = self.resolve_model_path(info["path"]) old_diffusers_path = self.resolve_model_path(model.location) new_diffusers_path = ( dest_directory or self.app_config.models_path / base_model.value / model_type.value @@ -992,7 +992,7 @@ class ModelManager(object): model_manager=self, prediction_type_helper=ask_user_for_prediction_type, ) - known_paths = {config.root_path / x["path"] for x in self.list_models()} + known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()} directories = { config.root_path / x for x in [ From 1b158f62c4d4a98da1d5aac2288ad724b4b8a87a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 19:26:42 -0400 Subject: [PATCH 21/21] resolve vae overrides correctly --- invokeai/backend/model_management/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 3fd59d1533..a3b0d4e04a 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -472,7 +472,7 @@ class ModelManager(object): if submodel_type is not None and hasattr(model_config, submodel_type): override_path = getattr(model_config, submodel_type) if override_path: - model_path = self.app_config.root_path / override_path + model_path = self.resolve_path(override_path) model_type = submodel_type submodel_type = None model_class = MODEL_CLASSES[base_model][model_type]
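
Taken together, the pieces added across this series can be exercised roughly as follows. This is a minimal sketch, not code from the patches: the import paths are assumptions based on the file layout above, the checkpoint path and LoRA weight are placeholders, and unet/text_encoder/text_encoder_2 stand in for already-loaded SDXL submodels (in the application these come from the model manager, as in the sdxl.py invocations, rather than being loaded directly).

import torch

# Import paths assumed from the module layout shown in these patches.
from invokeai.backend.model_management.lora import ModelPatcher
from invokeai.backend.model_management.models import BaseModelType
from invokeai.backend.model_management.models.lora import LoRAModelRaw

# Load a kohya/compvis-format SDXL LoRA checkpoint. Passing base_model makes
# from_checkpoint run the compvis->diffusers key conversion (SDXL_UNET_COMPVIS_MAP).
lora = LoRAModelRaw.from_checkpoint(
    file_path="/path/to/sdxl_lora.safetensors",  # placeholder path
    dtype=torch.float16,
    base_model=BaseModelType.StableDiffusionXL,
)
loras = [(lora, 0.75)]  # list of (lora model, weight) pairs

# Patch the UNet and both SDXL text encoders for the duration of the block;
# unet, text_encoder and text_encoder_2 are assumed to be loaded elsewhere.
with ModelPatcher.apply_lora_unet(unet, loras), \
     ModelPatcher.apply_sdxl_lora_text_encoder(text_encoder, loras), \
     ModelPatcher.apply_sdxl_lora_text_encoder2(text_encoder_2, loras):
    pass  # run conditioning and denoising with the patched models here

The two SDXL text-encoder context managers differ from apply_lora_text_encoder only in the key prefix they hand to apply_lora ("lora_te1_" and "lora_te2_" instead of "lora_te_"), which is how a single SDXL LoRA file addresses the two text encoders separately.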