References to context.services.model_manager.store.get_model can only accept keys, remove invalid assertion

This commit is contained in:
Brandon Rising 2024-02-14 09:51:11 -05:00 committed by Brandon Rising
parent 5cc73ec5dd
commit aa5d124d70
2 changed files with 13 additions and 13 deletions

View File

@ -681,7 +681,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
# get the unet's config so that we can pass the base to dispatch_progress()
unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump())
unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
self.dispatch_progress(context, source_node_id, state, unet_config.base)
@ -709,7 +709,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
):
assert isinstance(unet, torch.Tensor)
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)

View File

@ -303,7 +303,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
assert hasattr(cache_record.model, "device")
if hasattr(cache_record.model, "device"):
if cache_record.model.device == self.storage_device:
in_ram_models += 1
else: