VRAM Optimizations, sdxl on 8gb (#3818)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [x] Optimization
- [ ] Documentation Update

      
## Description

Various fixes to reduce memory consumption and allow SDXL to run on 8 GB of VRAM.
Most of the changes move node output tensors to the CPU, so that cached
tensors no longer consume VRAM.
Commit `ae5cb63f3c`, committed by blessedcoolant on 2023-07-19: 6 changed files with 47 additions and 16 deletions.
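
The recurring pattern across the files below: produce a tensor on the GPU, detach it and park it on the CPU before caching, then move it back onto the UNet's device and dtype only while it is actually needed. A minimal sketch of that pattern, not the actual invocation code (the `cache` dict and the `encode_prompt`/`denoise` callables are illustrative; `unet` is assumed to be a diffusers-style model exposing `.device` and `.dtype`):

```python
import torch

# Cached tensors live on the CPU, so holding them costs no VRAM.
cache: dict[str, torch.Tensor] = {}

def encode_and_cache(key: str, encode_prompt, prompt: str) -> None:
    # Build the conditioning on the GPU, then detach it from the autograd
    # graph and move it to the CPU before caching.
    c = encode_prompt(prompt)
    cache[key] = c.detach().to("cpu")

def denoise_with_cached(key: str, unet, denoise, latents: torch.Tensor) -> torch.Tensor:
    # Bring the cached conditioning (and the latents) onto the UNet's device
    # and dtype only for the duration of the denoising call.
    c = cache[key].to(device=unet.device, dtype=unet.dtype)
    latents = latents.to(device=unet.device, dtype=unet.dtype)
    result = denoise(latents, c)
    # Hand the result back on the CPU so caching it downstream is VRAM-free.
    result = result.to("cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
```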

**File 1 of 6:** prompt conditioning invocations (`CompelInvocation`, `SDXLPromptInvocationBase`)

```diff
@@ -148,6 +148,8 @@ class CompelInvocation(BaseInvocation):
                 cross_attention_control_args=options.get(
                     "cross_attention_control", None),)
 
+        c = c.detach().to("cpu")
+
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 BasicConditioningInfo(
@@ -230,6 +232,10 @@ class SDXLPromptInvocationBase:
         del tokenizer_info
         del text_encoder_info
 
+        c = c.detach().to("cpu")
+        if c_pooled is not None:
+            c_pooled = c_pooled.detach().to("cpu")
+
         return c, c_pooled, None
 
     def run_clip_compel(self, context, clip_field, prompt, get_pooled):
@@ -306,6 +312,10 @@ class SDXLPromptInvocationBase:
         del tokenizer_info
         del text_encoder_info
 
+        c = c.detach().to("cpu")
+        if c_pooled is not None:
+            c_pooled = c_pooled.detach().to("cpu")
+
         return c, c_pooled, ec
 
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
```

**File 2 of 6:** inpainting (`InpaintInvocation`)

```diff
@@ -146,13 +146,13 @@ class InpaintInvocation(BaseInvocation):
             source_node_id=source_node_id,
         )
 
-    def get_conditioning(self, context):
+    def get_conditioning(self, context, unet):
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].embeds
+        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
         extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
 
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].embeds
+        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
 
         return (uc, c, extra_conditioning_info)
@@ -213,7 +213,6 @@ class InpaintInvocation(BaseInvocation):
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
-        conditioning = self.get_conditioning(context)
         scheduler = get_scheduler(
             context=context,
             scheduler_info=self.unet.scheduler,
@@ -221,6 +220,8 @@ class InpaintInvocation(BaseInvocation):
         )
 
         with self.load_model_old_way(context, scheduler) as model:
+            conditioning = self.get_conditioning(context, model.context.model.unet)
+
             outputs = Inpaint(model).generate(
                 conditioning=conditioning,
                 scheduler=scheduler,
```

**File 3 of 6:** latents invocations (`TextToLatentsInvocation`, `LatentsToLatentsInvocation`, `LatentsToImageInvocation`, `ResizeLatentsInvocation`, `ScaleLatentsInvocation`, `ImageToLatentsInvocation`)

```diff
@@ -167,13 +167,14 @@ class TextToLatentsInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         scheduler,
+        unet,
     ) -> ConditioningData:
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].embeds
+        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
         extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
 
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].embeds
+        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
@@ -195,7 +196,7 @@ class TextToLatentsInvocation(BaseInvocation):
             eta=0.0,  # ddim_eta
             # for ancestral and sde schedulers
-            generator=torch.Generator(device=uc.device).manual_seed(0),
+            generator=torch.Generator(device=unet.device).manual_seed(0),
         )
         return conditioning_data
@@ -334,6 +335,8 @@ class TextToLatentsInvocation(BaseInvocation):
             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
             unet_info as unet:
 
+            noise = noise.to(device=unet.device, dtype=unet.dtype)
+
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -341,7 +344,7 @@ class TextToLatentsInvocation(BaseInvocation):
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
@@ -362,6 +365,7 @@ class TextToLatentsInvocation(BaseInvocation):
             )
 
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+            result_latents = result_latents.to("cpu")
             torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -424,6 +428,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
             unet_info as unet:
 
+            noise = noise.to(device=unet.device, dtype=unet.dtype)
+            latent = latent.to(device=unet.device, dtype=unet.dtype)
+
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -431,7 +438,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
@@ -463,6 +470,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )
 
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+            result_latents = result_latents.to("cpu")
             torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -503,6 +511,7 @@ class LatentsToImageInvocation(BaseInvocation):
         )
 
         with vae_info as vae:
+            latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
@@ -590,13 +599,17 @@ class ResizeLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
+        # TODO:
+        device=choose_torch_device()
+
         resized_latents = torch.nn.functional.interpolate(
-            latents, size=(self.height // 8, self.width // 8),
+            latents.to(device), size=(self.height // 8, self.width // 8),
             mode=self.mode, antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
@@ -624,14 +637,18 @@ class ScaleLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
+        # TODO:
+        device=choose_torch_device()
+
         # resizing
         resized_latents = torch.nn.functional.interpolate(
-            latents, scale_factor=self.scale_factor, mode=self.mode,
+            latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
             antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
@@ -722,6 +739,6 @@ class ImageToLatentsInvocation(BaseInvocation):
         latents = latents.to(dtype=orig_dtype)
 
         name = f"{context.graph_execution_state_id}__{self.id}"
-        # context.services.latents.set(name, latents)
+        latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents)
```
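
The `ResizeLatentsInvocation` and `ScaleLatentsInvocation` changes above follow the same shape: pick a compute device, run the interpolation there, then return the result on the CPU. A rough standalone equivalent, assuming NCHW latents and using `pick_device()` as a stand-in for InvokeAI's `choose_torch_device()` helper:

```python
import torch

def pick_device() -> torch.device:
    # Stand-in for choose_torch_device(): prefer CUDA when it is available.
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def resize_latents(latents: torch.Tensor, height: int, width: int, mode: str = "bilinear") -> torch.Tensor:
    device = pick_device()
    # Interpolate on the chosen device; antialias is only valid for
    # bilinear/bicubic modes.
    resized = torch.nn.functional.interpolate(
        latents.to(device),
        size=(height // 8, width // 8),
        mode=mode,
        antialias=mode in ["bilinear", "bicubic"],
    )
    # Return on the CPU so the saved latents do not occupy VRAM.
    resized = resized.to("cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return resized
```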

**File 4 of 6:** noise generation (`get_noise`)

```diff
@@ -48,7 +48,7 @@ def get_noise(
         dtype=torch_dtype(device),
         device=noise_device_type,
         generator=generator,
-    ).to(device)
+    ).to("cpu")
 
     return noise_tensor
```
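
Keeping the noise on the CPU changes where it is stored, not its values: the generator is still created and seeded on the chosen device, so existing seeds keep producing the same images, while queued noise no longer occupies VRAM. A hedged sketch of the idea in plain `torch` (the function name and the SD latent shape are illustrative):

```python
import torch

def get_noise_on_cpu(width: int, height: int, seed: int,
                     device: torch.device, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # The generator lives on `device`, so a given seed yields the same
    # sequence it did when the noise stayed on the GPU.
    generator = torch.Generator(device=device).manual_seed(seed)
    noise = torch.randn(
        (1, 4, height // 8, width // 8),  # illustrative latent shape
        generator=generator,
        device=device,
        dtype=dtype,
    )
    # Only the storage moves: the tensor is handed back on the CPU.
    return noise.to("cpu")
```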

**File 5 of 6:** SDXL invocations (`SDXLTextToLatentsInvocation`, `SDXLLatentsToLatentsInvocation`)

```diff
@@ -305,7 +305,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
             add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
 
-            with tqdm(total=self.steps) as progress_bar:
+            with tqdm(total=num_inference_steps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -351,7 +351,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
             add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
 
-            with tqdm(total=self.steps) as progress_bar:
+            with tqdm(total=num_inference_steps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     #latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -415,6 +415,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
             #################
 
+        latents = latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -651,6 +652,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
             #################
 
+        latents = latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
```

**File 6 of 6:** model cache (`ModelCache`)

```diff
@@ -104,7 +104,8 @@ class ModelCache(object):
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
         self.model_infos: Dict[str, ModelBase] = dict()
-        self.lazy_offloading = lazy_offloading
+        # allow lazy offloading only when vram cache enabled
+        self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self.precision: torch.dtype=precision
         self.max_cache_size: float=max_cache_size
         self.max_vram_cache_size: float=max_vram_cache_size
```
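
For context: lazy offloading keeps recently used models resident in VRAM until the space is actually needed, which only makes sense while the VRAM cache is enabled; with `max_vram_cache_size` at 0 it would just pin dead weight on the GPU, hence the new gate. A heavily simplified sketch of the behavior the flag controls (not the real `ModelCache` API):

```python
import torch

class TinyModelCache:
    """Simplified illustration of the lazy-offloading gate, not InvokeAI's ModelCache."""

    def __init__(self, lazy_offloading: bool = True, max_vram_cache_size: float = 0.0):
        # Lazy offloading is only allowed when a VRAM cache exists to hold the
        # models that stay on the GPU.
        self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0

    def release(self, model: torch.nn.Module) -> None:
        # Without lazy offloading, push the model back to the CPU as soon as
        # the caller is done with it and let CUDA reclaim the memory.
        if not self.lazy_offloading:
            model.to("cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
```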