mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fixup unit tests and remove debugging statements
@@ -4,7 +4,6 @@ from logging import Logger

 import torch

-import invokeai.backend.util.devices  # horrible hack
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
@@ -99,6 +99,7 @@ class CompelInvocation(BaseInvocation):
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
+            device=TorchDevice.choose_torch_device(),
         )

         conjunction = Compel.parse_prompt_string(self.prompt)
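Note on the added device argument: it comes from the same TorchDevice helper the surrounding kwargs already reference. A minimal sketch of how the two helpers pair up, assuming choose_torch_device() takes no arguments (as in the hunk above) and choose_torch_dtype() accepts the chosen device (it is passed as dtype_for_device_getter, so it must accept one):

import torch
from invokeai.backend.util.devices import TorchDevice

# Pick the execution device (CUDA/MPS/CPU) and a dtype suited to it.
device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype(device)

# Tensors built for the pipeline can then target both consistently.
embeds = torch.zeros(1, 77, 768, device=device, dtype=dtype)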
@@ -113,6 +114,7 @@ class CompelInvocation(BaseInvocation):
         conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])

         conditioning_name = context.conditioning.save(conditioning_data)

         return ConditioningOutput(
             conditioning=ConditioningField(
                 conditioning_name=conditioning_name,
@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
         )
         self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)

-    def reset_stats(self):
-        self._stats = {}
-        self._cache_stats = {}
+    def reset_stats(self, graph_execution_state_id: str):
+        self._stats.pop(graph_execution_state_id)
+        self._cache_stats.pop(graph_execution_state_id)

     def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
         graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
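Note the behavioural shift above: reset_stats used to clear every session's statistics at once; it now evicts only the session being reset. A simplified stand-in for the pattern (not the real service):

class SessionScopedStats:
    """Stats keyed by graph execution state id, reset one session at a time."""

    def __init__(self) -> None:
        self._stats: dict[str, dict] = {}
        self._cache_stats: dict[str, dict] = {}

    def reset_stats(self, graph_execution_state_id: str) -> None:
        # pop() without a default raises KeyError for an unknown id; the real
        # service surfaces this as GESStatsNotFoundError, which is exactly what
        # the session-runner hunk below suppresses.
        self._stats.pop(graph_execution_state_id)
        self._cache_stats.pop(graph_execution_state_id)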
@@ -76,8 +76,6 @@ class ModelManagerService(ModelManagerServiceBase):

         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
-            max_vram_cache_size=app_config.vram,
-            lazy_offloading=app_config.lazy_offload,
             logger=logger,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
@@ -1,7 +1,7 @@
 import traceback
 from contextlib import suppress
 from queue import Queue
-from threading import BoundedSemaphore, Thread, Lock
+from threading import BoundedSemaphore, Lock, Thread
 from threading import Event as ThreadEvent
 from typing import Optional, Set

@@ -61,7 +61,9 @@ class DefaultSessionRunner(SessionRunnerBase):
         self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
         self._process_lock = Lock()

-    def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None) -> None:
+    def start(
+        self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
+    ) -> None:
         self._services = services
         self._cancel_event = cancel_event
         self._profiler = profiler
@@ -214,7 +216,7 @@ class DefaultSessionRunner(SessionRunnerBase):
         # we don't care about that - suppress the error.
         with suppress(GESStatsNotFoundError):
             self._services.performance_statistics.log_stats(queue_item.session.id)
-            self._services.performance_statistics.reset_stats()
+            self._services.performance_statistics.reset_stats(queue_item.session.id)

         for callback in self._on_after_run_session_callbacks:
             callback(queue_item=queue_item)
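The suppress(GESStatsNotFoundError) block is why the signature change is safe here: if the session's stats were already evicted, log_stats and the per-session reset both raise, and the runner simply moves on. A self-contained sketch of the same stdlib pattern, with KeyError standing in for GESStatsNotFoundError:

from contextlib import suppress

stats: dict[str, int] = {}

with suppress(KeyError):
    # Raises KeyError because the id is unknown; suppress() swallows it
    # and execution continues after the with-block.
    stats.pop("already-cleared-session-id")

print("still running")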
@@ -384,7 +386,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
             )
             worker.start()

-
     def stop(self, *args, **kwargs) -> None:
         self._stop_event.set()

@@ -465,7 +466,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                     # Run the graph
                     # self.session_runner.run(queue_item=self._queue_item)

-                except Exception as e:
+                except Exception:
                     # Wait for next polling interval or event to try again
                     poll_now_event.wait(self._polling_interval)
                     continue
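Here and in the next hunk, "except Exception as e:" becomes "except Exception:" because the bound exception was never used once the debugging statements that read it were removed. The rule of thumb, sketched below: bind the exception only when the handler actually uses it.

def risky() -> None:
    raise RuntimeError("boom")

try:
    risky()
except Exception:  # no binding: we only need the control flow (retry/skip)
    pass

try:
    risky()
except Exception as e:  # binding is used, so it stays
    print(f"failed: {e}")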
@ -494,7 +495,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
|
||||
# Run the session on the reserved GPU
|
||||
self.session_runner.run(queue_item=queue_item)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
continue
|
||||
finally:
|
||||
self._active_queue_items.remove(queue_item)
|
||||
|
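reserve_execution_device() is used above as a context manager that gives the worker exclusive use of a GPU while the session runs. A hypothetical, much-simplified stand-in for that shape (the real method lives on the model manager's RAM cache and does considerably more):

import threading
from contextlib import contextmanager
from typing import Iterator

_device_lock = threading.Lock()

@contextmanager
def reserve_execution_device() -> Iterator[str]:
    # Serialize sessions onto the device; released even if the session raises.
    with _device_lock:
        yield "cuda:0"  # hypothetical reserved device id

with reserve_execution_device() as device:
    print(f"running session on {device}")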
@@ -239,6 +239,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
     def __hash__(self) -> int:
         return self.item_id


 class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
     pass
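__hash__ matters because the processor keeps queue items in a set (self._active_queue_items above; Set is imported in the first session-processor hunk), and pydantic BaseModel instances are not hashable by default. Hashing by the integer item_id, with a matching __eq__ assumed, is enough for add/remove to work. Simplified stand-in:

class QueueItem:
    def __init__(self, item_id: int) -> None:
        self.item_id = item_id

    def __hash__(self) -> int:
        return self.item_id

    def __eq__(self, other: object) -> bool:
        return isinstance(other, QueueItem) and other.item_id == self.item_id

active: set[QueueItem] = set()
item = QueueItem(42)
active.add(item)
active.remove(item)  # works because hash and eq agree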
@@ -325,7 +325,6 @@ class ConditioningInterface(InvocationContextInterface):
         Returns:
             The loaded conditioning data.
         """
-
         return self._services.conditioning.load(name)
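The conditioning hunks in this commit are two ends of one round trip: CompelInvocation saves a ConditioningFieldData blob and receives a name, and ConditioningInterface.load(name) returns it later, presumably backed by the object serializer services imported in the first hunk. A toy stand-in for that named-store contract:

class NamedStore:
    def __init__(self) -> None:
        self._items: dict[str, object] = {}
        self._counter = 0

    def save(self, obj: object) -> str:
        self._counter += 1
        name = f"conditioning-{self._counter}"
        self._items[name] = obj
        return name

    def load(self, name: str) -> object:
        return self._items[name]

store = NamedStore()
name = store.save({"embeds": [0.0, 1.0]})
assert store.load(name) == {"embeds": [0.0, 1.0]}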