Make it run (with artifacts).

Ryan Dick 2024-08-19 22:12:12 +00:00
parent 823c663e1b
commit fb5db32bb0


@@ -118,7 +118,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         img, img_ids = self._prepare_latent_img_patches(x)
 
         # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "shnell" in str(flux_transformer_path)
+        is_schnell = "schnell" in str(flux_transformer_path)
         timesteps = get_schedule(
             num_steps=self.num_steps,
             image_seq_len=img.shape[1],
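
Note: the one-line change above fixes a typo in the model-type check. The old substring "shnell" never matches a real path, so the schedule was computed with the wrong setting for schnell checkpoints. A minimal standalone sketch of the corrected check, using a hypothetical path (the real value is resolved by InvokeAI's model manager):

from pathlib import Path

# Hypothetical path for illustration only; the real one comes from the model manager.
flux_transformer_path = Path("/models/flux1-schnell/transformer")

# The old check looked for the misspelled "shnell" and therefore always evaluated to False.
is_schnell = "schnell" in str(flux_transformer_path)
print(is_schnell)  # True for the path above
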
@@ -174,6 +174,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
         img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        img_ids = img_ids.to(latent_img.device)
 
         return img, img_ids
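
Note: the added .to(latent_img.device) fixes a device mismatch. torch.zeros and torch.arange allocate on the CPU by default, while the latents typically sit on the GPU. A minimal sketch of the failure mode, with assumed shapes (the real tensors come from _prepare_latent_img_patches):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
latent_img = torch.zeros(1, 16, 64, 64, device=device)  # stand-in for the packed latents
h, w = latent_img.shape[2], latent_img.shape[3]

img_ids = torch.zeros(h // 2, w // 2, 3)  # allocated on the CPU by default
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]

# Without the explicit move, combining img_ids with CUDA tensors later in the
# transformer raises a RuntimeError about tensors being on different devices.
img_ids = img_ids.to(latent_img.device)
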
@@ -239,13 +240,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         else:
             raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
 
-        assert isinstance(model, FluxTransformer2DModel)
+        assert isinstance(model, Flux)
         return model
 
     @staticmethod
     def _load_flux_vae(path: Path) -> AutoEncoder:
         # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-        ae_params = flux_configs["flux1-schnell"].ae_params
+        ae_params = flux_configs["flux-schnell"].ae_params
         with accelerate.init_empty_weights():
             ae = AutoEncoder(ae_params)
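
Note: accelerate.init_empty_weights() constructs the module on the meta device, so no memory is allocated until real weights are loaded afterwards. A minimal sketch of that pattern with a toy module (the Linear layer and state dict are placeholders, not the real AutoEncoder or checkpoint):

import accelerate
import torch
from torch import nn

with accelerate.init_empty_weights():
    model = nn.Linear(4, 4)  # toy stand-in for AutoEncoder(ae_params)
assert model.weight.device.type == "meta"  # parameters exist but hold no storage yet

# Real weights are loaded afterwards; assign=True swaps the meta tensors for the
# loaded ones instead of copying into them (requires torch >= 2.1).
state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
model.load_state_dict(state_dict, assign=True)
assert model.weight.device.type == "cpu"
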