Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Merge branch 'main' into bugfix/prevent-ti-frontend-crash
This commit is contained in:
commit 8e47ca8d57

.github/workflows/test-invoke-pip.yml (vendored): 1 change
@@ -8,6 +8,7 @@ on:
       - 'ready_for_review'
       - 'opened'
       - 'synchronize'
+  merge_group:
   workflow_dispatch:

 concurrency:
@@ -223,7 +223,7 @@ class Generate:
         self.model_name = model or fallback

         # for VRAM usage statistics
-        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
+        self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None
         transformers.logging.set_verbosity_error()

         # gets rid of annoying messages about random seed
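Note on the change above: torch.cuda.max_memory_allocated() with no argument reports on whichever CUDA device is current, which is not necessarily the device the model runs on; passing self.device scopes the statistic to the right card. A minimal sketch of the difference, assuming a CUDA device is available (the device choice and tensor are illustrative only):

import torch

device = torch.device('cuda:0')              # hypothetical device choice
x = torch.zeros(1024, 1024, device=device)   # ~4 MiB allocated on that device

# With no argument the query uses torch.cuda.current_device(), which may not
# be the device the pipeline is actually using.
print(torch.cuda.max_memory_allocated())        # peak for the current device
print(torch.cuda.max_memory_allocated(device))  # peak for the chosen device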
@@ -592,20 +592,24 @@ class Generate:
             self.print_cuda_stats()
         return results

-    def clear_cuda_cache(self):
+    def gather_cuda_stats(self):
         if self._has_cuda():
             self.max_memory_allocated = max(
                 self.max_memory_allocated,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
             self.memory_allocated = max(
                 self.memory_allocated,
-                torch.cuda.memory_allocated()
+                torch.cuda.memory_allocated(self.device)
             )
             self.session_peakmem = max(
                 self.session_peakmem,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
+
+    def clear_cuda_cache(self):
+        if self._has_cuda():
+            self.gather_cuda_stats()
             torch.cuda.empty_cache()

     def clear_cuda_stats(self):
@@ -614,6 +618,7 @@ class Generate:

     def print_cuda_stats(self):
         if self._has_cuda():
+            self.gather_cuda_stats()
             print(
                 '>> Max VRAM used for this generation:',
                 '%4.2fG.' % (self.max_memory_allocated / 1e9),
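print_cuda_stats() now refreshes the counters via gather_cuda_stats() before reporting. The figure is printed in decimal gigabytes (bytes / 1e9), as a quick worked example shows (the peak value below is made up):

peak_bytes = 2 * 1024 ** 3              # hypothetical peak of 2 GiB
print('>> Max VRAM used for this generation:',
      '%4.2fG.' % (peak_bytes / 1e9))   # prints 2.15G. (decimal GB, not GiB)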
@@ -301,10 +301,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             textual_inversion_manager=self.textual_inversion_manager
         )

-        self._enable_memory_efficient_attention()
-

-    def _enable_memory_efficient_attention(self):
+    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
@@ -317,7 +315,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
             pass
         else:
-            self.enable_attention_slicing(slice_size='max')
+            if self.device.type == 'cpu' or self.device.type == 'mps':
+                mem_free = psutil.virtual_memory().free
+            elif self.device.type == 'cuda':
+                mem_free, _ = torch.cuda.mem_get_info(self.device)
+            else:
+                raise ValueError(f"unrecognized device {self.device}")
+            # input tensor of [1, 4, h/8, w/8]
+            # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+            bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
+            max_size_required_for_baddbmm = \
+                16 * \
+                latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
+                bytes_per_element_needed_for_baddbmm_duplication
+            if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0):  # 3.3 / 4.0 is from old Invoke code
+                self.enable_attention_slicing(slice_size='max')
+            else:
+                self.disable_attention_slicing()
+

     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
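The new branch drives the slicing decision off a rough upper bound on the memory the attention baddbmm would need for the requested image. A worked example under assumed values (a 512x512 request, so latents of [1, 4, 64, 64], stored as fp16 at 2 bytes per element):

element_size = 2                             # assumed: fp16 latents
bytes_per_element = element_size + 4         # same heuristic as the hunk above
h8 = w8 = 64                                 # 512x512 image -> h/8 = w/8 = 64

# matches the [16, (h/8 * w/8), (h/8 * w/8)] output noted in the comments
max_size_required_for_baddbmm = 16 * (h8 * w8) * (h8 * w8) * bytes_per_element
print(max_size_required_for_baddbmm / 1e9)   # ~1.61 GB

# Slicing is enabled when this estimate exceeds mem_free * 3.3 / 4.0, i.e. in
# this example whenever less than roughly 1.95 GB of memory is free.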
@@ -377,6 +392,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                               noise: torch.Tensor,
                               run_id: str = None,
                               additional_guidance: List[Callable] = None):
+        self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
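Because the check now takes the latents that are about to be denoised, the slicing decision is re-evaluated on every generation instead of being fixed when the pipeline is constructed. A minimal sketch (shapes are illustrative; pipeline stands for an already constructed StableDiffusionGeneratorPipeline):

import torch

latents_512 = torch.zeros(1, 4, 64, 64, dtype=torch.float16)      # 512x512 request
latents_1024 = torch.zeros(1, 4, 128, 128, dtype=torch.float16)   # 1024x1024 request

# A larger request needs a much larger attention buffer, so the same pipeline
# may run unsliced for the first call and fall back to slice_size='max' for
# the second, depending on how much memory is free at the time.
# pipeline._adjust_memory_efficient_attention(latents_512)
# pipeline._adjust_memory_efficient_attention(latents_1024)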