fixup unit tests and remove debugging statements

This commit is contained in:
Lincoln Stein
2024-06-02 18:19:29 -04:00
parent e26360f85b
commit 589a7959c0
11 changed files with 61 additions and 186 deletions

View File

@ -4,7 +4,6 @@ from logging import Logger
import torch
import invokeai.backend.util.devices # horrible hack
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db

View File

@ -99,6 +99,7 @@ class CompelInvocation(BaseInvocation):
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(self.prompt)
@ -113,6 +114,7 @@ class CompelInvocation(BaseInvocation):
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,

View File

@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def reset_stats(self, graph_execution_state_id: str):
self._stats.pop(graph_execution_state_id)
self._cache_stats.pop(graph_execution_state_id)
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)

View File

@ -76,8 +76,6 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)

View File

@ -1,7 +1,7 @@
import traceback
from contextlib import suppress
from queue import Queue
from threading import BoundedSemaphore, Thread, Lock
from threading import BoundedSemaphore, Lock, Thread
from threading import Event as ThreadEvent
from typing import Optional, Set
@ -61,7 +61,9 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
self._process_lock = Lock()
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None) -> None:
def start(
self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
) -> None:
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
@ -214,7 +216,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
self._services.performance_statistics.reset_stats(queue_item.session.id)
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
@ -384,7 +386,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
worker.start()
def stop(self, *args, **kwargs) -> None:
self._stop_event.set()
@ -465,7 +466,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
# Run the graph
# self.session_runner.run(queue_item=self._queue_item)
except Exception as e:
except Exception:
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval)
continue
@ -494,7 +495,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
# Run the session on the reserved GPU
self.session_runner.run(queue_item=queue_item)
except Exception as e:
except Exception:
continue
finally:
self._active_queue_items.remove(queue_item)

View File

@ -239,6 +239,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
def __hash__(self) -> int:
return self.item_id
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass

View File

@ -325,7 +325,6 @@ class ConditioningInterface(InvocationContextInterface):
Returns:
The loaded conditioning data.
"""
return self._services.conditioning.load(name)