trying out JPPhoto's patch on vast.ai

damian 2023-01-26 17:27:33 +01:00
parent 8ed8bf52d0
commit 729752620b
2 changed files with 12 additions and 6 deletions

@@ -304,6 +304,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            textual_inversion_manager=self.textual_inversion_manager
        )
        self._enable_memory_efficient_attention()

    def _enable_memory_efficient_attention(self):
        """
        if xformers is available, use it, otherwise use sliced attention.
        """
        if is_xformers_available() and not Globals.disable_xformers:
            self.enable_xformers_memory_efficient_attention()
        else:
@@ -315,7 +322,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            else:
                self.enable_attention_slicing(slice_size='auto')

    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                              conditioning_data: ConditioningData,
                              *,
@@ -360,6 +366,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
        timesteps = self.scheduler.timesteps
        infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
        self._enable_memory_efficient_attention()
        result: PipelineIntermediateState = infer_latents_from_embeddings(
            latents, timesteps, conditioning_data,
            noise=noise,
@@ -380,8 +387,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        additional_guidance = []
        extra_conditioning_info = conditioning_data.extra
        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
                                                             step_count=len(self.scheduler.timesteps),
                                                             do_attention_map_saving=False):
                                                             step_count=len(self.scheduler.timesteps)
                                                             ):
            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                            latents=latents)
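
For context, a minimal standalone sketch of the fallback logic that _enable_memory_efficient_attention wraps, written against a plain diffusers StableDiffusionPipeline; is_xformers_available comes from diffusers.utils, and the disable_xformers argument stands in for the Globals.disable_xformers flag used above.

from diffusers import StableDiffusionPipeline
from diffusers.utils import is_xformers_available


def enable_memory_efficient_attention(pipe: StableDiffusionPipeline,
                                      disable_xformers: bool = False) -> None:
    """Prefer xformers attention when installed and enabled; otherwise fall back to sliced attention."""
    if is_xformers_available() and not disable_xformers:
        # fused, memory-efficient attention kernels from the xformers package
        pipe.enable_xformers_memory_efficient_attention()
    else:
        # 'auto' lets diffusers split the attention computation into memory-friendly slices
        pipe.enable_attention_slicing(slice_size='auto')

With this patch, the pipeline calls its equivalent helper again at the start of image_from_embeddings, so the attention backend is selected per generation rather than only once at setup.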

@@ -60,13 +60,12 @@ class InvokeAIDiffuserComponent:
    @contextmanager
    def custom_attention_context(self,
                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
                                 step_count: int,
                                 do_attention_map_saving: bool):
                                 step_count: int):
        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
        old_attn_processor = None
        if do_swap:
            old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
                                                                    step_count=step_count)
                                                                    step_count=step_count)
        try:
            yield None
        finally:
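
For reference, a rough sketch of the setup/teardown shape the simplified context manager takes once do_attention_map_saving is gone. The component argument and the restore_default_cross_attention cleanup call are placeholders (the body of the finally: block is not shown in this hunk); only setup_cross_attention_control and wants_cross_attention_control come from the code above.

from contextlib import contextmanager
from typing import Any, Optional


@contextmanager
def custom_attention_context(component: Any,
                             extra_conditioning_info: Optional[Any],
                             step_count: int):
    """Install cross-attention control for one generation, then restore the previous state."""
    do_swap = (extra_conditioning_info is not None
               and extra_conditioning_info.wants_cross_attention_control)
    old_attn_processor = None
    if do_swap:
        # set up cross-attention control and remember whatever processor it replaces
        old_attn_processor = component.setup_cross_attention_control(extra_conditioning_info,
                                                                     step_count=step_count)
    try:
        yield None
    finally:
        if old_attn_processor is not None:
            # placeholder cleanup: the real finally block would restore the saved processor here
            component.restore_default_cross_attention(old_attn_processor)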