Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Make it run (with artifacts).
commit fb5db32bb0
parent 823c663e1b
@@ -118,7 +118,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         img, img_ids = self._prepare_latent_img_patches(x)

         # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "shnell" in str(flux_transformer_path)
+        is_schnell = "schnell" in str(flux_transformer_path)
         timesteps = get_schedule(
             num_steps=self.num_steps,
             image_seq_len=img.shape[1],
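Note: the hunk truncates the get_schedule call, so how the corrected is_schnell flag is consumed is not visible here. Below is a minimal, self-contained sketch of the kind of schedule it presumably selects, modelled on the BFL reference sampler; the helper name, the shift constants, and the example path are assumptions, not taken from this file.

import math

import torch

def shifted_schedule(num_steps: int, image_seq_len: int, shift: bool) -> list[float]:
    # Sketch: schnell checkpoints keep the plain linear sigma schedule, while dev
    # checkpoints shift it toward higher noise as the image token count grows.
    sigmas = torch.linspace(1, 0, num_steps + 1)
    if shift:
        # Interpolate the shift strength between short and long token sequences
        # (constants are illustrative).
        mu = 0.5 + (1.15 - 0.5) * (image_seq_len - 256) / (4096 - 256)
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
    return sigmas.tolist()

# The substring check the hunk fixes: "shnell" never matches a schnell checkpoint
# path, so the flag was always False before this commit.
flux_transformer_path = "models/flux1-schnell/transformer.safetensors"  # hypothetical path
is_schnell = "schnell" in flux_transformer_path
print(shifted_schedule(num_steps=4, image_seq_len=1024, shift=not is_schnell))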
@@ -174,6 +174,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
         img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        img_ids = img_ids.to(latent_img.device)

         return img, img_ids

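Note: the added .to(latent_img.device) keeps the positional ids on the same device as the packed latents. A minimal sketch of the surrounding helper follows; the shape unpacking and the 2x2 patch packing are not shown in the hunk and are filled in as assumptions.

import torch
from einops import rearrange, repeat

def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    bs, c, h, w = latent_img.shape
    # Flatten each 2x2 patch of the latent image into a single token (assumed step).
    img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    # Per-token positional ids: channel 0 stays zero, channels 1 and 2 hold the
    # row and column indices (these lines mirror the hunk above).
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
    # The line this commit adds: without it the ids stay on the CPU while the
    # latents may live on the GPU, which can trip a device-mismatch error later.
    img_ids = img_ids.to(latent_img.device)
    return img, img_ids

device = "cuda" if torch.cuda.is_available() else "cpu"
img, img_ids = prepare_latent_img_patches(torch.randn(1, 16, 64, 64, device=device))
print(img.shape, img_ids.shape)  # torch.Size([1, 1024, 64]) torch.Size([1, 1024, 3])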
@@ -239,13 +240,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             else:
                 raise ValueError(f"Unsupported quantization type: {self.quantization_type}")

-        assert isinstance(model, FluxTransformer2DModel)
+        assert isinstance(model, Flux)
         return model

     @staticmethod
     def _load_flux_vae(path: Path) -> AutoEncoder:
         # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-        ae_params = flux_configs["flux1-schnell"].ae_params
+        ae_params = flux_configs["flux-schnell"].ae_params
         with accelerate.init_empty_weights():
             ae = AutoEncoder(ae_params)

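Note: the hunk cuts off before the VAE weights are materialized. Below is a small, self-contained illustration of the accelerate.init_empty_weights pattern that _load_flux_vae relies on, assuming the method goes on to read the safetensors file at `path` and load it with load_state_dict; a stand-in nn.Linear replaces the real AutoEncoder so the snippet runs on its own.

import accelerate
import torch
from safetensors.torch import load_file, save_file

# Allocate the module on the meta device: no memory is used and the parameters
# are placeholders until a state dict is loaded.
with accelerate.init_empty_weights():
    model = torch.nn.Linear(8, 8)

# Stand-in for the VAE checkpoint that _load_flux_vae would read from `path`.
save_file({"weight": torch.randn(8, 8), "bias": torch.zeros(8)}, "weights.safetensors")
state_dict = load_file("weights.safetensors")

# assign=True swaps the meta-device placeholders for the loaded tensors instead of
# copying into them (copying into meta tensors would fail).
model.load_state_dict(state_dict, strict=True, assign=True)
print(model.weight.device)  # cpu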