Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov
2024-07-22 22:17:29 +03:00
parent 9a1420280e
commit 1b359b55cb
2 changed files with 13 additions and 13 deletions

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
class RescaleCFGExt(ExtensionBase):
def __init__(self, rescale_multiplier: float):
super().__init__()
self.rescale_multiplier = rescale_multiplier
self._rescale_multiplier = rescale_multiplier
@staticmethod
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
@ -28,9 +28,9 @@ class RescaleCFGExt(ExtensionBase):
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
def rescale_noise_pred(self, ctx: DenoiseContext):
if self.rescale_multiplier > 0:
if self._rescale_multiplier > 0:
ctx.noise_pred = self._rescale_cfg(
ctx.noise_pred,
ctx.positive_noise_pred,
self.rescale_multiplier,
self._rescale_multiplier,
)