add note about discriminated union and Body() issue; blackified

This commit is contained in:
Lincoln Stein 2023-11-12 16:50:05 -05:00
parent ef8dcf5fae
commit 8afe517204
11 changed files with 41 additions and 18 deletions

View File

@ -755,7 +755,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end, denoising_end=self.denoising_end,
) )
(result_latents, result_attention_map_saver,) = pipeline.latents_from_embeddings( (
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
latents=latents, latents=latents,
timesteps=timesteps, timesteps=timesteps,
init_timestep=init_timestep, init_timestep=init_timestep,

View File

@ -268,7 +268,14 @@ AnyModelConfig = Union[
T2IConfig, T2IConfig,
] ]
# Preferred alternative is a discriminated Union, but it breaks FastAPI when applied to a route. AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[ # AnyModelConfig = Annotated[
# Union[ # Union[
# _MainModelConfig, # _MainModelConfig,
@ -284,8 +291,6 @@ AnyModelConfig = Union[
# Field(discriminator="type"), # Field(discriminator="type"),
# ] # ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelConfigFactory(object): class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects.""" """Class for parsing config dicts into StableDiffusion Config obects."""

View File

@ -175,7 +175,10 @@ class InvokeAIDiffuserComponent:
dim=0, dim=0,
), ),
} }
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch( (
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds, conditioning_data.text_embeddings.embeds,
) )
@ -237,7 +240,10 @@ class InvokeAIDiffuserComponent:
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control: if wants_cross_attention_control:
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning( (
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,
@ -245,14 +251,20 @@ class InvokeAIDiffuserComponent:
**kwargs, **kwargs,
) )
elif self.sequential_guidance: elif self.sequential_guidance:
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially( (
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,
**kwargs, **kwargs,
) )
else: else:
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning( (
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,

View File

@ -470,7 +470,10 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
(h, w,) = ( (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )