mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Make it run (with artifacts).
This commit is contained in:
parent
823c663e1b
commit
fb5db32bb0
@ -118,7 +118,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
img, img_ids = self._prepare_latent_img_patches(x)
|
img, img_ids = self._prepare_latent_img_patches(x)
|
||||||
|
|
||||||
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||||
is_schnell = "shnell" in str(flux_transformer_path)
|
is_schnell = "schnell" in str(flux_transformer_path)
|
||||||
timesteps = get_schedule(
|
timesteps = get_schedule(
|
||||||
num_steps=self.num_steps,
|
num_steps=self.num_steps,
|
||||||
image_seq_len=img.shape[1],
|
image_seq_len=img.shape[1],
|
||||||
@ -174,6 +174,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
img_ids = img_ids.to(latent_img.device)
|
||||||
|
|
||||||
return img, img_ids
|
return img, img_ids
|
||||||
|
|
||||||
@ -239,13 +240,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
|
raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
|
||||||
|
|
||||||
assert isinstance(model, FluxTransformer2DModel)
|
assert isinstance(model, Flux)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_flux_vae(path: Path) -> AutoEncoder:
|
def _load_flux_vae(path: Path) -> AutoEncoder:
|
||||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||||
ae_params = flux_configs["flux1-schnell"].ae_params
|
ae_params = flux_configs["flux-schnell"].ae_params
|
||||||
with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
ae = AutoEncoder(ae_params)
|
ae = AutoEncoder(ae_params)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user