Working, but still filled with debug statements (to be removed)

This commit is contained in:
Lincoln Stein
2024-04-01 18:44:24 -04:00
parent 3d69372785
commit 9adb15f86c
8 changed files with 72 additions and 44 deletions

View File

@ -4,6 +4,8 @@ from logging import Logger
import torch
import invokeai.backend.util.devices # horrible hack
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
@ -100,6 +102,9 @@ class ApiDependencies:
download_queue=download_queue_service,
events=events,
)
# horrible hack - remove
invokeai.backend.util.devices.RAM_CACHE = model_manager.load.ram_cache
names = SimpleNameService()
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)

View File

@ -4,7 +4,7 @@ import math
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
import threading
import einops
import numpy as np
import numpy.typing as npt
@ -393,6 +393,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
# flip all bits to have noise different from initial
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
)
if conditioning_data.unconditioned_embeddings.embeds.device != conditioning_data.text_embeddings.embeds.device:
print(f'DEBUG; ERROR uc={conditioning_data.unconditioned_embeddings.embeds.device} c={conditioning_data.text_embeddings.embeds.device} unet={unet.device}, tid={threading.current_thread().ident}')
return conditioning_data
def create_pipeline(

View File

@ -1,5 +1,6 @@
from queue import Queue
from typing import TYPE_CHECKING, Optional, TypeVar
import threading
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
@ -18,8 +19,8 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
    """Wrap *underlying_storage* with a small per-thread in-memory cache.

    Args:
        underlying_storage: The serializer that actually persists objects;
            all cache misses fall through to it.
        max_cache_size: Maximum number of entries kept per thread before the
            oldest (FIFO order) entry is evicted.
    """
    super().__init__()
    self._underlying_storage = underlying_storage
    # Caches are keyed by thread ident so concurrent worker threads never
    # share — or evict — each other's entries.
    # (Diff residue removed: the pre-change flat `dict[str, T]` / `Queue[str]`
    # initializations were fused into this span alongside the new ones.)
    self._cache: dict[int, dict[str, T]] = {}
    self._cache_ids: dict[int, Queue[str]] = {}
    self._max_cache_size = max_cache_size
def start(self, invoker: "Invoker") -> None:
@ -54,12 +55,27 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
del self._cache[name]
self._on_deleted(name)
def _get_tid_cache(self) -> dict[str, T]:
    """Return this thread's private object cache, creating it on first use.

    Entries are keyed by the current thread's ident so worker threads never
    see (or evict) one another's cached objects.
    """
    tid = threading.current_thread().ident
    return self._cache.setdefault(tid, {})
def _get_tid_cache_ids(self) -> Queue[str]:
    """Return this thread's FIFO of cached names, creating it on first use."""
    tid = threading.current_thread().ident
    try:
        return self._cache_ids[tid]
    except KeyError:
        # First access from this thread: lazily create its eviction queue.
        ids = Queue[str]()
        self._cache_ids[tid] = ids
        return ids
def _get_cache(self, name: str) -> Optional[T]:
    """Look up *name* in the current thread's cache; return None on a miss.

    (Diff residue removed: the pre-change body that indexed the flat
    ``self._cache`` directly was fused into this span alongside the new
    per-thread body.)
    """
    # dict.get collapses the membership test + lookup into one operation
    # and returns None on a miss, exactly as the explicit check did.
    return self._get_tid_cache().get(name)
def _set_cache(self, name: str, data: T):
    """Insert *data* under *name* into the current thread's cache.

    Maintains FIFO eviction: once this thread holds more than
    ``self._max_cache_size`` entries, the oldest cached name is dequeued
    and its entry removed.

    (Diff residue removed: the pre-change body operating on the flat
    ``self._cache`` / ``self._cache_ids`` was fused into this span
    alongside the new per-thread body.)
    """
    cache = self._get_tid_cache()
    if name in cache:
        # Already cached; FIFO position is deliberately not refreshed,
        # matching the pre-existing behavior.
        return
    cache[name] = data
    cache_ids = self._get_tid_cache_ids()
    cache_ids.put(name)
    if cache_ids.qsize() > self._max_cache_size:
        # Evict the oldest entry for this thread only.
        cache.pop(cache_ids.get())

View File

@ -175,7 +175,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
session = self._session_worker_queue.get()
if self._cancel_event.is_set():
if session.item_id in self._sessions_to_cancel:
print("DEBUG: CANCEL")
continue
if profiler is not None:
@ -183,7 +182,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
# reserve a GPU for this session - may block
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device() as gpu:
print(f"DEBUG: session {session.item_id} has reserved gpu {gpu}")
# Prepare invocations and take the first
with self._process_lock: