Fix Inpainting Issues (#3744)

- fix: Inpaint not working with some schedulers (resolves #3732)
- fix: LoRAs not working at all while inpainting
blessedcoolant 2023-07-13 23:42:44 +12:00 committed by GitHub
commit d4ec8873f7
2 changed files with 47 additions and 45 deletions

@@ -154,18 +154,20 @@ class InpaintInvocation(BaseInvocation):
     @contextmanager
     def load_model_old_way(self, context, scheduler):
+        def _lora_loader():
+            for lora in self.unet.loras:
+                lora_info = context.services.model_manager.get_model(
+                    **lora.dict(exclude={"weight"}))
+                yield (lora_info.context.model, lora.weight)
+                del lora_info
+            return
+
         unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
         vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
 
-        #unet = unet_info.context.model
-        #vae = vae_info.context.model
-
-        with ExitStack() as stack:
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
-
         with vae_info as vae,\
-             unet_info as unet,\
-             ModelPatcher.apply_lora_unet(unet, loras):
+             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+             unet_info as unet:
 
             device = context.services.model_manager.mgr.cache.execution_device
             dtype = context.services.model_manager.mgr.cache.precision
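The LoRA regression came from the removed ExitStack pattern: the `with ExitStack() as stack:` block contained only the list comprehension, so every LoRA context was closed again as soon as the list was built, and `ModelPatcher.apply_lora_unet` received already-released models. The fix replaces the eager list with the `_lora_loader()` generator, which keeps each LoRA loaded exactly while the patcher consumes it, and applies the patch to `unet_info.context.model` before the unet context is entered. A minimal sketch of the difference, using hypothetical stand-ins (`LoraHandle`, `eager_load`, `lora_loader`) rather than InvokeAI's model manager:

from contextlib import ExitStack

class LoraHandle:
    """Hypothetical stand-in for a model-manager handle; loading is
    modelled as a context manager, as in the real code."""
    def __init__(self, name: str):
        self.name = name
        self.loaded = False
    def __enter__(self):
        self.loaded = True
        return self
    def __exit__(self, *exc):
        self.loaded = False

def eager_load(handles):
    # Old pattern: the ExitStack closes at the end of this `with` block,
    # so every handle is unloaded again before the caller sees the list.
    with ExitStack() as stack:
        return [(stack.enter_context(h), 1.0) for h in handles]

def lora_loader(handles):
    # New pattern: a generator that yields each model while its context
    # is still open; the consumer sees a loaded model, which is released
    # only when iteration advances past it.
    for h in handles:
        with h as model:
            yield (model, 1.0)

handles = [LoraHandle("lora_a"), LoraHandle("lora_b")]
print([m.loaded for m, _ in eager_load(handles)])   # [False, False] -- the bug
print([m.loaded for m, _ in lora_loader(handles)])  # [True, True]   -- the fix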

@@ -127,7 +127,7 @@ class AddsMaskGuidance:
     def _t_for_field(self, field_name: str, t):
         if field_name == "pred_original_sample":
-            return torch.zeros_like(t, dtype=t.dtype)  # it represents t=0
+            return self.scheduler.timesteps[-1]
         return t
 
     def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
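This is the scheduler fix for #3732. `pred_original_sample` is the model's estimate of the fully denoised image, so the mask should be applied as if no noise remains; the old code encoded that as a literal timestep of 0. But several schedulers look up per-step state by matching `t` against their `timesteps` table, and 0 is not guaranteed to appear in that table, so the lookup failed. `self.scheduler.timesteps[-1]` is the last scheduled timestep: always a valid key and the closest value in the schedule to t=0. A quick illustration, assuming the diffusers library (exact values depend on the installed version and config):

import torch
from diffusers import DPMSolverMultistepScheduler

sched = DPMSolverMultistepScheduler()  # default config
sched.set_timesteps(10)

# The timestep table does not necessarily contain 0 ...
t = torch.tensor(0)
print((sched.timesteps == t).nonzero().numel())  # 0 -> no matching index

# ... so lookups keyed on a literal t=0 fail, while the last scheduled
# timestep is always a valid, near-zero substitute.
print(sched.timesteps[-1])  # smallest timestep in the schedule, nonzero here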