mirror of https://github.com/invoke-ai/InvokeAI
Merge branch 'main' into nodes-stuff
commit 3f1d5000c0
@@ -148,6 +148,8 @@ class CompelInvocation(BaseInvocation):
             cross_attention_control_args=options.get(
                 "cross_attention_control", None),)
 
+        c = c.detach().to("cpu")
+
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 BasicConditioningInfo(
@@ -230,6 +232,10 @@ class SDXLPromptInvocationBase:
         del tokenizer_info
         del text_encoder_info
 
+        c = c.detach().to("cpu")
+        if c_pooled is not None:
+            c_pooled = c_pooled.detach().to("cpu")
+
         return c, c_pooled, None
 
     def run_clip_compel(self, context, clip_field, prompt, get_pooled):
@@ -306,6 +312,10 @@ class SDXLPromptInvocationBase:
         del tokenizer_info
         del text_encoder_info
 
+        c = c.detach().to("cpu")
+        if c_pooled is not None:
+            c_pooled = c_pooled.detach().to("cpu")
+
         return c, c_pooled, ec
 
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
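
The three hunks above follow one pattern: the conditioning tensors produced by the prompt encoders are detached from the autograd graph and moved to the CPU before they are cached, so stored conditionings no longer pin GPU memory. A minimal standalone sketch of that pattern, assuming a plain dict in place of InvokeAI's conditioning/latents service (encode_fn and the cache are illustrative, not project APIs):

import torch

conditioning_cache = {}  # hypothetical stand-in for context.services.latents

def encode_and_cache(encode_fn, tokens: torch.Tensor, name: str) -> torch.Tensor:
    with torch.no_grad():
        c = encode_fn(tokens)     # runs wherever the encoder lives (often CUDA)
    c = c.detach().to("cpu")      # drop autograd references, keep only a CPU copy
    conditioning_cache[name] = c
    return c
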
@@ -147,13 +147,13 @@ class InpaintInvocation(BaseInvocation):
             source_node_id=source_node_id,
         )
 
-    def get_conditioning(self, context):
+    def get_conditioning(self, context, unet):
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].embeds
+        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
         extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
 
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].embeds
+        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
 
         return (uc, c, extra_conditioning_info)
 
@@ -214,7 +214,6 @@ class InpaintInvocation(BaseInvocation):
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
-        conditioning = self.get_conditioning(context)
         scheduler = get_scheduler(
             context=context,
             scheduler_info=self.unet.scheduler,
@@ -222,6 +221,8 @@ class InpaintInvocation(BaseInvocation):
         )
 
         with self.load_model_old_way(context, scheduler) as model:
+            conditioning = self.get_conditioning(context, model.context.model.unet)
+
             outputs = Inpaint(model).generate(
                 conditioning=conditioning,
                 scheduler=scheduler,

@@ -168,13 +168,14 @@ class TextToLatentsInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         scheduler,
+        unet,
     ) -> ConditioningData:
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].embeds
+        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
         extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
 
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].embeds
+        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
@@ -196,7 +197,7 @@ class TextToLatentsInvocation(BaseInvocation):
             eta=0.0, # ddim_eta
 
             # for ancestral and sde schedulers
-            generator=torch.Generator(device=uc.device).manual_seed(0),
+            generator=torch.Generator(device=unet.device).manual_seed(0),
         )
         return conditioning_data
 
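
These hunks move the device decision to the point of use: get_conditioning_data now receives the loaded unet, places the CPU-cached embeddings onto unet.device with unet.dtype, and seeds the scheduler's noise generator on unet.device rather than on the embeddings' device. A rough sketch of the consumption side (unet here is any module exposing .device and .dtype, as diffusers UNets do; the helper names are illustrative):

import torch

def embeds_for_model(embeds_cpu: torch.Tensor, unet) -> torch.Tensor:
    # cached on CPU; move to the execution device and dtype only once the unet is loaded
    return embeds_cpu.to(device=unet.device, dtype=unet.dtype)

def seeded_generator(unet, seed: int = 0) -> torch.Generator:
    # keep the RNG on the same device the unet samples on
    return torch.Generator(device=unet.device).manual_seed(seed)
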
@@ -335,6 +336,8 @@ class TextToLatentsInvocation(BaseInvocation):
             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
             unet_info as unet:
 
+            noise = noise.to(device=unet.device, dtype=unet.dtype)
+
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -342,7 +345,7 @@ class TextToLatentsInvocation(BaseInvocation):
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
@@ -363,6 +366,7 @@ class TextToLatentsInvocation(BaseInvocation):
             )
 
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+            result_latents = result_latents.to("cpu")
             torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
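
The single added line here (and its twin in LatentsToLatentsInvocation below) applies the same storage convention to the diffusion result: move the finished latents to the CPU and let the CUDA caching allocator release its blocks before the tensor is saved, per the Hugging Face thread linked in the comment. A hedged sketch of that step in isolation (save_fn stands in for context.services.latents.save):

import torch

def offload_and_save(result_latents: torch.Tensor, save_fn, name: str) -> torch.Tensor:
    result_latents = result_latents.to("cpu")  # keep only a CPU copy for storage
    if torch.cuda.is_available():
        torch.cuda.empty_cache()               # return cached CUDA blocks to the driver
    save_fn(name, result_latents)
    return result_latents
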
@@ -426,6 +430,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
             unet_info as unet:
 
+            noise = noise.to(device=unet.device, dtype=unet.dtype)
+            latent = latent.to(device=unet.device, dtype=unet.dtype)
+
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -433,7 +440,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
@@ -465,6 +472,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )
 
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+            result_latents = result_latents.to("cpu")
             torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -506,6 +514,7 @@ class LatentsToImageInvocation(BaseInvocation):
         )
 
         with vae_info as vae:
+            latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
 
@@ -601,13 +610,17 @@ class ResizeLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
+        # TODO:
+        device=choose_torch_device()
+
         resized_latents = torch.nn.functional.interpolate(
-            latents, size=(self.height // 8, self.width // 8),
+            latents.to(device), size=(self.height // 8, self.width // 8),
             mode=self.mode, antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
@@ -643,14 +656,18 @@ class ScaleLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
+        # TODO:
+        device=choose_torch_device()
+
         # resizing
         resized_latents = torch.nn.functional.interpolate(
-            latents, scale_factor=self.scale_factor, mode=self.mode,
+            latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
             antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
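
ResizeLatentsInvocation and ScaleLatentsInvocation now pull the CPU-stored latents onto whatever device choose_torch_device() reports, run the interpolation there, and hand the result back to the CPU before saving. A minimal sketch of that round trip, assuming a local stand-in for choose_torch_device() (the helper below is illustrative, not the project's implementation):

import torch

def pick_device() -> torch.device:
    # stand-in for InvokeAI's choose_torch_device()
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def resize_latents(latents: torch.Tensor, height: int, width: int, mode: str = "bilinear") -> torch.Tensor:
    device = pick_device()
    resized = torch.nn.functional.interpolate(
        latents.to(device),
        size=(height // 8, width // 8),             # latent space is 1/8 of pixel space
        mode=mode,
        antialias=mode in ("bilinear", "bicubic"),  # antialias is only valid for these modes
    )
    resized = resized.to("cpu")                     # back to CPU for storage
    if device.type == "cuda":
        torch.cuda.empty_cache()
    return resized
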
@@ -742,6 +759,6 @@ class ImageToLatentsInvocation(BaseInvocation):
         latents = latents.to(dtype=orig_dtype)
 
         name = f"{context.graph_execution_state_id}__{self.id}"
-        # context.services.latents.set(name, latents)
+        latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents)

@@ -48,7 +48,7 @@ def get_noise(
         dtype=torch_dtype(device),
         device=noise_device_type,
         generator=generator,
-    ).to(device)
+    ).to("cpu")
 
     return noise_tensor
 
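
get_noise keeps creating the tensor with a generator tied to the requested device type, so per-device seeds stay reproducible, but it now returns the noise on the CPU, matching the convention that stored tensors live on the CPU until an invocation moves them next to its model. A hedged approximation of that shape (the signature and the latent layout constants are assumptions, not the project's exact function):

import torch

def make_noise(width: int, height: int, seed: int, device: torch.device,
               dtype: torch.dtype = torch.float32) -> torch.Tensor:
    generator = torch.Generator(device=device).manual_seed(seed)
    noise = torch.randn(
        (1, 4, height // 8, width // 8),  # SD latent layout: 4 channels at 1/8 resolution
        dtype=dtype,
        device=device,
        generator=generator,
    ).to("cpu")                           # hand back a CPU tensor for storage
    return noise
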
@@ -306,7 +306,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
             add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
 
-            with tqdm(total=self.steps) as progress_bar:
+            with tqdm(total=num_inference_steps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -352,7 +352,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
             add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
 
-            with tqdm(total=self.steps) as progress_bar:
+            with tqdm(total=num_inference_steps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     #latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -416,6 +416,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
 
             #################
 
+            latents = latents.to("cpu")
             torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -653,6 +654,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
 
             #################
 
+            latents = latents.to("cpu")
             torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'

@@ -104,7 +104,8 @@ class ModelCache(object):
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
         self.model_infos: Dict[str, ModelBase] = dict()
-        self.lazy_offloading = lazy_offloading
+        # allow lazy offloading only when vram cache enabled
+        self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self.precision: torch.dtype=precision
         self.max_cache_size: float=max_cache_size
         self.max_vram_cache_size: float=max_vram_cache_size
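
The ModelCache change makes lazy offloading conditional: it is only honoured when a VRAM cache is actually configured (max_vram_cache_size > 0); otherwise models are offloaded eagerly. A tiny sketch of just that guard (the class body is hypothetical, only the attribute logic mirrors the hunk):

class TinyModelCache:
    def __init__(self, lazy_offloading: bool = True, max_vram_cache_size: float = 0.0):
        # lazy offloading only makes sense when some VRAM is reserved for cached models
        self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
        self.max_vram_cache_size = max_vram_cache_size

    def should_offload_immediately(self) -> bool:
        # with the VRAM cache disabled, models are moved off the GPU right after use
        return not self.lazy_offloading
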