Apply LoRA by patching model weights instead of hooks

commit 5cebf67ee4 (parent 1ba94a92b3)
Author: Sergey Borisov
Date: 2023-06-26 03:57:33 +03:00
4 changed files with 52 additions and 38 deletions
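
The hunks below drop the ExitStack that previously entered each LoRA's model context and instead read the already-loaded model off the model-info object, leaving ModelPatcher.apply_lora_unet to merge and restore the weights. As a rough illustration of the patch-based approach the title refers to — a minimal per-layer sketch, not InvokeAI's actual implementation; the helper name, interface, and shapes are assumptions:

from contextlib import contextmanager

import torch


@contextmanager
def apply_lora_layer(module: torch.nn.Linear, up: torch.Tensor, down: torch.Tensor, scale: float):
    """Hypothetical helper: merge scale * (up @ down) into module.weight, restore on exit."""
    original = module.weight.detach().clone()  # keep unpatched weights for restore
    try:
        with torch.no_grad():
            module.weight += scale * (up @ down)  # patch weights once, no forward hooks
        yield module
    finally:
        with torch.no_grad():
            module.weight.copy_(original)  # undo the patch when the context exits


# Usage: patch a 768->768 projection with a rank-4 LoRA at weight 0.75.
layer = torch.nn.Linear(768, 768)
up, down = torch.randn(768, 4), torch.randn(4, 768)
with apply_lora_layer(layer, up, down, 0.75):
    _ = layer(torch.randn(1, 768))  # runs with patched weights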

@@ -362,8 +362,7 @@ class TextToLatentsInvocation(BaseInvocation):
             self.dispatch_progress(context, source_node_id, state)
 
         unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
-        with unet_info as unet,\
-             ExitStack() as stack:
+        with unet_info as unet:
 
             scheduler = get_scheduler(
                 context=context,
@@ -374,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
             pipeline = self.create_pipeline(unet, scheduler)
             conditioning_data = self.get_conditioning_data(context, scheduler)
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
 
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
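
Both TextToLatents and LatentsToLatents receive the same substitution: rather than entering each LoRA's model-info context through the ExitStack (which kept the model checked out of the model manager for the whole with-block), the new code reads the already-loaded model object from the info's .context.model and hands it to ModelPatcher, which now owns the patch-and-restore lifecycle itself.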
@@ -438,8 +437,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             **self.unet.unet.dict(),
         )
 
-        with unet_info as unet,\
-             ExitStack() as stack:
+        with unet_info as unet:
 
             scheduler = get_scheduler(
                 context=context,
@@ -468,7 +466,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                 device=unet.device,
             )
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
 
             with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
                 result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
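
ModelPatcher.apply_lora_unet, presumably updated in one of the other changed files, is the context manager that now performs the actual weight patching. A hedged sketch of how such an applier could be composed from the per-layer helper above — the lora.layers() interface and the get_submodule traversal are assumptions for illustration, not InvokeAI's internals:

from contextlib import ExitStack, contextmanager


@contextmanager
def apply_loras(unet, loras):
    """loras: (lora_model, weight) pairs, as built by the list comprehensions above."""
    with ExitStack() as stack:
        for lora, weight in loras:
            # Assumed interface: each LoRA exposes (submodule_key, up, down) triples.
            for key, up, down in lora.layers():
                target = unet.get_submodule(key)  # resolve the dotted module path
                stack.enter_context(apply_lora_layer(target, up, down, weight))
        yield unet  # every layer is patched; weights are restored as the stack unwinds

Note that the ExitStack here lives inside the patcher, scoped to the denoising call, rather than holding every LoRA's model-manager context open as the removed code did.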