Merge branch 'main' into bugfix/prevent-ti-frontend-crash

commit 8e47ca8d57 by Lincoln Stein, 2023-02-12 23:56:41 -05:00, committed by GitHub
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
3 changed files with 31 additions and 9 deletions

View File

@@ -8,6 +8,7 @@ on:
       - 'ready_for_review'
       - 'opened'
       - 'synchronize'
+  merge_group:
   workflow_dispatch:
 
 concurrency:

View File

@@ -223,7 +223,7 @@ class Generate:
         self.model_name = model or fallback
 
         # for VRAM usage statistics
-        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
+        self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None
         transformers.logging.set_verbosity_error()
 
         # gets rid of annoying messages about random seed
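Aside: the reason for passing self.device is that torch.cuda's memory counters default to the current CUDA device, which is not necessarily the device the model lives on in a multi-GPU setup. A minimal standalone sketch of the device-scoped query (peak_vram_gb is a made-up helper name for illustration, not part of this commit):

import torch

def peak_vram_gb(device: torch.device) -> float:
    # torch.cuda.max_memory_allocated() with no argument reports the *current*
    # device; passing an explicit device pins the statistic to that GPU.
    if device.type != 'cuda':
        return 0.0
    return torch.cuda.max_memory_allocated(device) / 1e9

if torch.cuda.is_available():
    dev = torch.device('cuda:0')
    x = torch.randn(2048, 2048, device=dev)  # allocate something measurable
    print(f'peak on {dev}: {peak_vram_gb(dev):.2f} GB')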
@@ -592,20 +592,24 @@ class Generate:
         self.print_cuda_stats()
         return results
 
-    def clear_cuda_cache(self):
+    def gather_cuda_stats(self):
         if self._has_cuda():
             self.max_memory_allocated = max(
                 self.max_memory_allocated,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
             self.memory_allocated = max(
                 self.memory_allocated,
-                torch.cuda.memory_allocated()
+                torch.cuda.memory_allocated(self.device)
             )
             self.session_peakmem = max(
                 self.session_peakmem,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
+
+    def clear_cuda_cache(self):
+        if self._has_cuda():
+            self.gather_cuda_stats()
             torch.cuda.empty_cache()
 
     def clear_cuda_stats(self):
@@ -614,6 +618,7 @@ class Generate:
     def print_cuda_stats(self):
         if self._has_cuda():
+            self.gather_cuda_stats()
             print(
                 '>> Max VRAM used for this generation:',
                 '%4.2fG.' % (self.max_memory_allocated / 1e9),
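Net effect of the changes in this file: stat collection is factored out of clear_cuda_cache() into gather_cuda_stats(), so both cache clearing and stat printing record the per-device high-water mark before anything can reset it. A rough standalone sketch of that pattern, with a hypothetical VRAMTracker class standing in for the Generate object:

import torch

class VRAMTracker:
    def __init__(self, device: torch.device):
        self.device = device
        self.peak = 0

    def gather(self):
        # Record the per-device high-water mark before anything resets it.
        if self.device.type == 'cuda':
            self.peak = max(self.peak, torch.cuda.max_memory_allocated(self.device))

    def clear_cache(self):
        # Gather first, then hand cached blocks back to the allocator.
        if self.device.type == 'cuda':
            self.gather()
            torch.cuda.empty_cache()

    def report(self):
        self.gather()
        print('>> Max VRAM used for this generation: %4.2fG.' % (self.peak / 1e9))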

View File

@@ -301,10 +301,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             textual_inversion_manager=self.textual_inversion_manager
         )
-        self._enable_memory_efficient_attention()
 
-    def _enable_memory_efficient_attention(self):
+    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
         """
        if xformers is available, use it, otherwise use sliced attention.
         """
@@ -317,7 +315,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
                 pass
             else:
-                self.enable_attention_slicing(slice_size='max')
+                if self.device.type == 'cpu' or self.device.type == 'mps':
+                    mem_free = psutil.virtual_memory().free
+                elif self.device.type == 'cuda':
+                    mem_free, _ = torch.cuda.mem_get_info(self.device)
+                else:
+                    raise ValueError(f"unrecognized device {self.device}")
+                # input tensor of [1, 4, h/8, w/8]
+                # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+                bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
+                max_size_required_for_baddbmm = \
+                    16 * \
+                    latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
+                    bytes_per_element_needed_for_baddbmm_duplication
+                if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0):  # 3.3 / 4.0 is from old Invoke code
+                    self.enable_attention_slicing(slice_size='max')
+                else:
+                    self.disable_attention_slicing()
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
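To make the heuristic concrete: the attention baddbmm expands a [1, 4, h/8, w/8] latent into a [16, (h/8 * w/8), (h/8 * w/8)] tensor, so the bytes required grow with the fourth power of the image side. A back-of-the-envelope check, assuming fp16 latents (element_size() == 2); the numbers are illustrative only:

def baddbmm_bytes(height: int, width: int, element_size: int = 2) -> int:
    h8, w8 = height // 8, width // 8      # latent spatial dims
    bytes_per_element = element_size + 4  # mirrors the "+ 4" duplication margin above
    return 16 * (h8 * w8) ** 2 * bytes_per_element

for side in (512, 768):
    need = baddbmm_bytes(side, side)
    # slicing is enabled when need > mem_free * 3.3 / 4.0
    print(f'{side}x{side}: ~{need / 1e9:.2f} GB; slicing kicks in below '
          f'~{need / (3.3 / 4.0) / 1e9:.2f} GB free')

For a 512x512 image this works out to roughly 1.6 GB (slicing below about 2 GB free); for 768x768 it is roughly 8.2 GB (slicing below about 9.9 GB free).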
@@ -377,6 +392,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                 noise: torch.Tensor,
                                 run_id: str = None,
                                 additional_guidance: List[Callable] = None):
+        self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None: