mirror of https://github.com/invoke-ai/InvokeAI

working but filled with debug statements
@@ -4,6 +4,8 @@ from logging import Logger
 
+import torch
 
+import invokeai.backend.util.devices # horrible hack
 
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
@@ -100,6 +102,9 @@ class ApiDependencies:
             download_queue=download_queue_service,
             events=events,
         )
+        # horrible hack - remove
+        invokeai.backend.util.devices.RAM_CACHE = model_manager.load.ram_cache
+
         names = SimpleNameService()
         session_processor = DefaultSessionProcessor()
         session_queue = SqliteSessionQueue(db=db)
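The hunk above parks the model manager's RAM cache on the invokeai.backend.util.devices module so code deep in the backend can reach it without threading it through every call signature; the author flags it for removal. A minimal sketch of that module-attribute injection pattern, with illustrative names (registry, wire_up, and get_ram_cache are not InvokeAI API):

    # Sketch of the module-attribute injection pattern; "registry",
    # "wire_up" and "get_ram_cache" are illustrative names, not InvokeAI API.
    import types

    registry = types.SimpleNamespace(RAM_CACHE=None)

    def wire_up(ram_cache) -> None:
        # Plays the role of: invokeai.backend.util.devices.RAM_CACHE = ...
        registry.RAM_CACHE = ram_cache

    def get_ram_cache():
        # Deeply nested backend code can reach the cache without it being
        # passed through every constructor and function signature.
        if registry.RAM_CACHE is None:
            raise RuntimeError("RAM cache has not been initialized")
        return registry.RAM_CACHE

The cost of the pattern is a hidden global: any module can read or replace RAM_CACHE, and tests have to remember to reset it, which is presumably why the comment says "remove".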
@@ -4,7 +4,7 @@ import math
 from contextlib import ExitStack
 from functools import singledispatchmethod
 from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
-
+import threading
 import einops
 import numpy as np
 import numpy.typing as npt
@@ -393,6 +393,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
             # flip all bits to have noise different from initial
             generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
         )
+
+        if conditioning_data.unconditioned_embeddings.embeds.device != conditioning_data.text_embeddings.embeds.device:
+            print(f'DEBUG; ERROR uc={conditioning_data.unconditioned_embeddings.embeds.device} c={conditioning_data.text_embeddings.embeds.device} unet={unet.device}, tid={threading.current_thread().ident}')
+
+
         return conditioning_data

     def create_pipeline(
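The added check only reports the problem: it prints when the unconditioned and text embeddings have drifted onto different devices, tagging the thread id since several worker threads share this code path. For contrast, a hedged sketch that would move both tensors onto the UNet's device rather than log (assuming .embeds is a plain torch.Tensor, as the debug print treats it; this is not the committed fix):

    # Sketch: move both embedding tensors onto the UNet's device instead of
    # only logging the mismatch. Assumes .embeds is a plain torch.Tensor, as
    # the debug print treats it; this is not the committed fix.
    import torch

    def align_conditioning_devices(conditioning_data, device: torch.device):
        uc = conditioning_data.unconditioned_embeddings.embeds
        c = conditioning_data.text_embeddings.embeds
        if uc.device != device:
            conditioning_data.unconditioned_embeddings.embeds = uc.to(device)
        if c.device != device:
            conditioning_data.text_embeddings.embeds = c.to(device)
        return conditioning_data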
@@ -1,5 +1,6 @@
 from queue import Queue
 from typing import TYPE_CHECKING, Optional, TypeVar
+import threading
 
 from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
@@ -18,8 +19,8 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
     def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
         super().__init__()
         self._underlying_storage = underlying_storage
-        self._cache: dict[str, T] = {}
-        self._cache_ids = Queue[str]()
+        self._cache: dict[int, dict[str, T]] = {}
+        self._cache_ids: dict[int, Queue[str]] = {}
         self._max_cache_size = max_cache_size

     def start(self, invoker: "Invoker") -> None:
@@ -54,12 +55,27 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
             del self._cache[name]
             self._on_deleted(name)

+    def _get_tid_cache(self) -> dict[str, T]:
+        tid = threading.current_thread().ident
+        if tid not in self._cache:
+            self._cache[tid] = {}
+        return self._cache[tid]
+
+    def _get_tid_cache_ids(self) -> Queue[str]:
+        tid = threading.current_thread().ident
+        if tid not in self._cache_ids:
+            self._cache_ids[tid] = Queue[str]()
+        return self._cache_ids[tid]
+
     def _get_cache(self, name: str) -> Optional[T]:
-        return None if name not in self._cache else self._cache[name]
+        cache = self._get_tid_cache()
+        return None if name not in cache else cache[name]

     def _set_cache(self, name: str, data: T):
-        if name not in self._cache:
-            self._cache[name] = data
-            self._cache_ids.put(name)
-            if self._cache_ids.qsize() > self._max_cache_size:
-                self._cache.pop(self._cache_ids.get())
+        cache = self._get_tid_cache()
+        if name not in cache:
+            cache[name] = data
+            cache_ids = self._get_tid_cache_ids()
+            cache_ids.put(name)
+            if cache_ids.qsize() > self._max_cache_size:
+                cache.pop(cache_ids.get())
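The rework above keys both the entry dict and the FIFO of insertion order by threading.current_thread().ident, so each worker thread sees and evicts only its own entries. A standalone sketch of the same pattern built on threading.local() instead of a tid-keyed dict (class and method names are illustrative, not InvokeAI API):

    # Standalone sketch of the per-thread cache pattern above, built on
    # threading.local() instead of a dict keyed by thread ident. Names are
    # illustrative, not InvokeAI API.
    import threading
    from queue import Queue

    class PerThreadCache:
        def __init__(self, max_size: int = 20):
            self._local = threading.local()  # separate state per thread
            self._max_size = max_size

        def _state(self) -> tuple[dict, Queue]:
            # Lazily create this thread's dict plus a FIFO of insertion
            # order, mirroring _get_tid_cache / _get_tid_cache_ids above.
            if not hasattr(self._local, "cache"):
                self._local.cache = {}
                self._local.ids = Queue()
            return self._local.cache, self._local.ids

        def get(self, name: str):
            cache, _ = self._state()
            return cache.get(name)  # None if only another thread stored it

        def set(self, name: str, data) -> None:
            cache, ids = self._state()
            if name not in cache:
                cache[name] = data
                ids.put(name)
                if ids.qsize() > self._max_size:
                    cache.pop(ids.get())  # evict this thread's oldest entry

One trade-off worth noting: a dict keyed by thread ident keeps a dead thread's entries alive, and CPython may recycle thread identifiers after a thread exits, whereas threading.local() state is released when its owning thread dies.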
@@ -175,7 +175,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
                 session = self._session_worker_queue.get()
-                if self._cancel_event.is_set():
-                    print("DEBUG: CANCEL")
+                if session.item_id in self._sessions_to_cancel:
                     continue

                 if profiler is not None:
@@ -183,7 +182,6 @@ class DefaultSessionProcessor(SessionProcessorBase):

                 # reserve a GPU for this session - may block
                 with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device() as gpu:
-                    print(f"DEBUG: session {session.item_id} has reserved gpu {gpu}")

                     # Prepare invocations and take the first
                     with self._process_lock:
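Here reserve_execution_device() hands the session a GPU through a context manager that may block until one is free and releases it when the block exits. A minimal sketch of such a blocking reservation pool (illustrative only, not the InvokeAI implementation):

    # Minimal sketch of a blocking device-reservation pool in the spirit of
    # reserve_execution_device(); illustrative only, not the InvokeAI
    # implementation.
    import threading
    from contextlib import contextmanager

    class DevicePool:
        def __init__(self, devices: list[str]):
            self._free = list(devices)  # e.g. ["cuda:0", "cuda:1"]
            self._cond = threading.Condition()

        @contextmanager
        def reserve(self):
            with self._cond:
                while not self._free:  # block until a device is released
                    self._cond.wait()
                device = self._free.pop()
            try:
                yield device
            finally:
                with self._cond:
                    self._free.append(device)
                    self._cond.notify()  # wake one waiting session

Usage mirrors the diff: "with pool.reserve() as gpu:" runs the session on the reserved device and returns it to the pool when the block ends, even on error.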