chore: Black linting

This commit is contained in:
blessedcoolant 2023-08-13 21:28:39 +12:00
parent 3ff9961bda
commit 561951ad98
4 changed files with 52 additions and 36 deletions

View File

@ -294,11 +294,17 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=False)
c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=False
)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True)
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
)
else:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)

View File

@ -320,7 +320,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
if scheduler.config.get("cpu_only", False):
device = torch.device("cpu")
num_inference_steps = steps
scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = scheduler.timesteps
@ -344,7 +344,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# calculate step count based on scheduler order
num_inference_steps = len(timesteps)
if scheduler.order == 2:
num_inference_steps += (num_inference_steps % 2)
num_inference_steps += num_inference_steps % 2
num_inference_steps = num_inference_steps // 2
return num_inference_steps, timesteps, init_timestep

View File

@ -202,8 +202,8 @@ class ControlNetData:
@dataclass
class ConditioningData:
unconditioned_embeddings: Any # TODO: type
text_embeddings: Any # TODO: type
unconditioned_embeddings: Any # TODO: type
text_embeddings: Any # TODO: type
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
@ -389,19 +389,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
batched_t = init_timestep.repeat(batch_size)
if noise is not None:
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_t)
if mask is not None:
if is_inpainting_model(self.unet):
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
# (that's why there's a mask!) but it seems to really want that blanked out.
#masked_latents = latents * torch.where(mask < 0.5, 1, 0) TODO: inpaint/outpaint/infill
# masked_latents = latents * torch.where(mask < 0.5, 1, 0) TODO: inpaint/outpaint/infill
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, mask, orig_latents
)
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(self._unet_forward, mask, orig_latents)
else:
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
if noise is None:
@ -413,7 +411,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
).to(device=orig_latents.device, dtype=orig_latents.dtype)
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
@ -549,11 +549,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input,
timestep=t, # TODO: debug how handled batched and non batched timesteps
timestep=t, # TODO: debug how handled batched and non batched timesteps
step_index=step_index,
total_step_count=total_step_count,
conditioning_data=conditioning_data,
# extra:
down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s)
mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s)

View File

@ -202,15 +202,21 @@ class InvokeAIDiffuserComponent:
else:
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
"text_embeds": torch.cat([
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
], dim=0),
"time_ids": torch.cat([
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
], dim=0),
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
],
dim=0,
),
}
(
encoder_hidden_states,
@ -260,7 +266,7 @@ class InvokeAIDiffuserComponent:
self,
sample: torch.Tensor,
timestep: torch.Tensor,
conditioning_data, # TODO: type
conditioning_data, # TODO: type
step_index: int,
total_step_count: int,
**kwargs,
@ -380,20 +386,25 @@ class InvokeAIDiffuserComponent:
added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
"text_embeds": torch.cat([
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
], dim=0),
"time_ids": torch.cat([
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
], dim=0),
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
],
dim=0,
),
}
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
)
both_results = self.model_forward_callback(
x_twice,