mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): emit model loading events
- remove dependency on having access to a `node` during emits, would need a bit of additional args passed through the system and I don't think its necessary at this point. this also allowed us to drop an extraneous fetching/parsing of the session from db. - provide the invocation context to all `get_model()` calls, so the events are able to be emitted - test all model loading events in the app and confirm socket events are received
This commit is contained in:
@ -76,7 +76,7 @@ def get_scheduler(
|
||||
scheduler_name, SCHEDULER_MAP['ddim']
|
||||
)
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
**scheduler_info.dict()
|
||||
**scheduler_info.dict(), context=context,
|
||||
)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
@ -262,6 +262,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
@ -313,14 +314,14 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"})
|
||||
**lora.dict(exclude={"weight"}), context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict()
|
||||
**self.unet.unet.dict(), context=context,
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
@ -403,14 +404,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"})
|
||||
**lora.dict(exclude={"weight"}), context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict()
|
||||
**self.unet.unet.dict(), context=context,
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
@ -491,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
**self.vae.vae.dict(), context=context,
|
||||
)
|
||||
|
||||
with vae_info as vae:
|
||||
@ -636,7 +637,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
**self.vae.vae.dict(), context=context,
|
||||
)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
|
Reference in New Issue
Block a user