Merge branch 'main' into nodes-stuff

This commit is contained in:
blessedcoolant
2023-07-19 02:37:50 +12:00
6 changed files with 47 additions and 16 deletions

View File

@ -168,13 +168,14 @@ class TextToLatentsInvocation(BaseInvocation):
self,
context: InvocationContext,
scheduler,
unet,
) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].embeds
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@ -196,7 +197,7 @@ class TextToLatentsInvocation(BaseInvocation):
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
generator=torch.Generator(device=uc.device).manual_seed(0),
generator=torch.Generator(device=unet.device).manual_seed(0),
)
return conditioning_data
@ -335,6 +336,8 @@ class TextToLatentsInvocation(BaseInvocation):
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
@ -342,7 +345,7 @@ class TextToLatentsInvocation(BaseInvocation):
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
@ -363,6 +366,7 @@ class TextToLatentsInvocation(BaseInvocation):
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
@ -426,6 +430,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
@ -433,7 +440,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
@ -465,6 +472,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
@ -506,6 +514,7 @@ class LatentsToImageInvocation(BaseInvocation):
)
with vae_info as vae:
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
@ -601,13 +610,17 @@ class ResizeLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO:
device=choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8),
latents.to(device), size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
@ -643,14 +656,18 @@ class ScaleLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO:
device=choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode,
latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
@ -742,6 +759,6 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = latents.to(dtype=orig_dtype)
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, latents)
latents = latents.to("cpu")
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)