trying out JPPhoto's patch on vast.ai

This commit is contained in:
damian 2023-01-26 17:27:33 +01:00
parent 8ed8bf52d0
commit 729752620b
2 changed files with 12 additions and 6 deletions

View File

@ -304,6 +304,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
textual_inversion_manager=self.textual_inversion_manager textual_inversion_manager=self.textual_inversion_manager
) )
self._enable_memory_efficient_attention()
def _enable_memory_efficient_attention(self):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
if is_xformers_available() and not Globals.disable_xformers: if is_xformers_available() and not Globals.disable_xformers:
self.enable_xformers_memory_efficient_attention() self.enable_xformers_memory_efficient_attention()
else: else:
@ -315,7 +322,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
else: else:
self.enable_attention_slicing(slice_size='auto') self.enable_attention_slicing(slice_size='auto')
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
*, *,
@ -360,6 +366,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device) self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState) infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
self._enable_memory_efficient_attention()
result: PipelineIntermediateState = infer_latents_from_embeddings( result: PipelineIntermediateState = infer_latents_from_embeddings(
latents, timesteps, conditioning_data, latents, timesteps, conditioning_data,
noise=noise, noise=noise,
@ -380,8 +387,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance = [] additional_guidance = []
extra_conditioning_info = conditioning_data.extra extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info, with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps), step_count=len(self.scheduler.timesteps)
do_attention_map_saving=False): ):
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
latents=latents) latents=latents)

View File

@ -60,13 +60,12 @@ class InvokeAIDiffuserComponent:
@contextmanager @contextmanager
def custom_attention_context(self, def custom_attention_context(self,
extra_conditioning_info: Optional[ExtraConditioningInfo], extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int, step_count: int):
do_attention_map_saving: bool):
do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
old_attn_processor = None old_attn_processor = None
if do_swap: if do_swap:
old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info, old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
step_count=step_count) step_count=step_count)
try: try:
yield None yield None
finally: finally: