VRAM Optimizations, sdxl on 8gb (#3818)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [x] Optimization
- [ ] Documentation Update

      
## Description

Various fixes to consume less memory and make run sdxl on 8gb vram.
Most changes due to moving all output tensors to cpu, so that cached
tensors not consume vram.
This commit is contained in:
blessedcoolant 2023-07-19 02:36:58 +12:00 committed by GitHub
commit ae5cb63f3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 47 additions and 16 deletions

View File

@ -148,6 +148,8 @@ class CompelInvocation(BaseInvocation):
cross_attention_control_args=options.get(
"cross_attention_control", None),)
c = c.detach().to("cpu")
conditioning_data = ConditioningFieldData(
conditionings=[
BasicConditioningInfo(
@ -230,6 +232,10 @@ class SDXLPromptInvocationBase:
del tokenizer_info
del text_encoder_info
c = c.detach().to("cpu")
if c_pooled is not None:
c_pooled = c_pooled.detach().to("cpu")
return c, c_pooled, None
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
@ -306,6 +312,10 @@ class SDXLPromptInvocationBase:
del tokenizer_info
del text_encoder_info
c = c.detach().to("cpu")
if c_pooled is not None:
c_pooled = c_pooled.detach().to("cpu")
return c, c_pooled, ec
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

View File

@ -146,13 +146,13 @@ class InpaintInvocation(BaseInvocation):
source_node_id=source_node_id,
)
def get_conditioning(self, context):
def get_conditioning(self, context, unet):
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].embeds
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
return (uc, c, extra_conditioning_info)
@ -213,7 +213,6 @@ class InpaintInvocation(BaseInvocation):
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
conditioning = self.get_conditioning(context)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
@ -221,6 +220,8 @@ class InpaintInvocation(BaseInvocation):
)
with self.load_model_old_way(context, scheduler) as model:
conditioning = self.get_conditioning(context, model.context.model.unet)
outputs = Inpaint(model).generate(
conditioning=conditioning,
scheduler=scheduler,

View File

@ -167,13 +167,14 @@ class TextToLatentsInvocation(BaseInvocation):
self,
context: InvocationContext,
scheduler,
unet,
) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].embeds
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@ -195,7 +196,7 @@ class TextToLatentsInvocation(BaseInvocation):
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
generator=torch.Generator(device=uc.device).manual_seed(0),
generator=torch.Generator(device=unet.device).manual_seed(0),
)
return conditioning_data
@ -334,6 +335,8 @@ class TextToLatentsInvocation(BaseInvocation):
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
@ -341,7 +344,7 @@ class TextToLatentsInvocation(BaseInvocation):
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
@ -362,6 +365,7 @@ class TextToLatentsInvocation(BaseInvocation):
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
@ -424,6 +428,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
@ -431,7 +438,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
@ -463,6 +470,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
@ -503,6 +511,7 @@ class LatentsToImageInvocation(BaseInvocation):
)
with vae_info as vae:
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
@ -590,13 +599,17 @@ class ResizeLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO:
device=choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8),
latents.to(device), size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
@ -624,14 +637,18 @@ class ScaleLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO:
device=choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode,
latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
@ -722,6 +739,6 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = latents.to(dtype=orig_dtype)
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, latents)
latents = latents.to("cpu")
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)

View File

@ -48,7 +48,7 @@ def get_noise(
dtype=torch_dtype(device),
device=noise_device_type,
generator=generator,
).to(device)
).to("cpu")
return noise_tensor

View File

@ -305,7 +305,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=self.steps) as progress_bar:
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@ -351,7 +351,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=self.steps) as progress_bar:
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@ -415,6 +415,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
#################
latents = latents.to("cpu")
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
@ -651,6 +652,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
#################
latents = latents.to("cpu")
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'

View File

@ -104,7 +104,8 @@ class ModelCache(object):
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
'''
self.model_infos: Dict[str, ModelBase] = dict()
self.lazy_offloading = lazy_offloading
# allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype=precision
self.max_cache_size: float=max_cache_size
self.max_vram_cache_size: float=max_vram_cache_size