chore: black

This commit is contained in:
psychedelicious 2023-08-14 13:02:33 +10:00
parent 46a8eed33e
commit 9d3cd85bdd
3 changed files with 27 additions and 18 deletions

View File

@ -212,7 +212,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=0, # TODO: refactor this node
seed=0, # TODO: refactor this node
)
def torch2numpy(latent: torch.Tensor):

View File

@ -429,13 +429,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_count=len(self.scheduler.timesteps),
):
if callback is not None:
callback(PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
))
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
@ -469,15 +471,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
if callback is not None:
callback(PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
attention_map_saver=attention_map_saver,
))
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
attention_map_saver=attention_map_saver,
)
)
return latents, attention_map_saver

View File

@ -3,4 +3,9 @@ Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo
from .shared_invokeai_diffusion import (
InvokeAIDiffuserComponent,
PostprocessingSettings,
BasicConditioningInfo,
SDXLConditioningInfo,
)