Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Disable lazy offloading when the VRAM cache is disabled, move result tensors to the CPU (so VRAM tensors do not pile up in the cache), and fix the text encoder not being freed (detach)
This commit is contained in:
Parent: ada9b06e48
Commit: bc11296a5e
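The recurring pattern in the conditioning changes below is to detach the encoder output and park it on the CPU before it is cached. A minimal sketch of that pattern, with an illustrative function name that is not part of the project's API:

    import torch

    def cache_conditioning(embeds: torch.Tensor) -> torch.Tensor:
        # Detach so the cached tensor no longer references the text encoder's
        # autograd graph (otherwise the encoder cannot be freed), then move it
        # to CPU so cached conditionings do not accumulate in VRAM.
        return embeds.detach().to("cpu")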
@@ -147,6 +147,8 @@ class CompelInvocation(BaseInvocation):
             cross_attention_control_args=options.get(
                 "cross_attention_control", None),)
 
+        c = c.detach().to("cpu")
+
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 BasicConditioningInfo(
@@ -229,6 +231,10 @@ class SDXLPromptInvocationBase:
         del tokenizer_info
         del text_encoder_info
 
+        c = c.detach().to("cpu")
+        if c_pooled is not None:
+            c_pooled = c_pooled.detach().to("cpu")
+
         return c, c_pooled, None
 
     def run_clip_compel(self, context, clip_field, prompt, get_pooled):
@@ -305,6 +311,10 @@ class SDXLPromptInvocationBase:
         del tokenizer_info
         del text_encoder_info
 
+        c = c.detach().to("cpu")
+        if c_pooled is not None:
+            c_pooled = c_pooled.detach().to("cpu")
+
         return c, c_pooled, ec
 
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@@ -167,13 +167,14 @@ class TextToLatentsInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         scheduler,
+        unet,
     ) -> ConditioningData:
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].embeds
+        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
         extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
 
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].embeds
+        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
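Because conditionings are now cached on the CPU, get_conditioning_data takes the unet and moves the embeddings onto its device and dtype only at use time. A rough equivalent, assuming a generic torch.nn.Module rather than the project's model wrapper:

    import torch

    def load_conditioning(embeds: torch.Tensor, unet: torch.nn.Module) -> torch.Tensor:
        # Cached conditionings live on CPU; bring them onto the UNet's device
        # and dtype only at the moment they are needed for denoising.
        param = next(unet.parameters())
        return embeds.to(device=param.device, dtype=param.dtype)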
@@ -195,7 +196,7 @@ class TextToLatentsInvocation(BaseInvocation):
             eta=0.0,  # ddim_eta
 
             # for ancestral and sde schedulers
-            generator=torch.Generator(device=uc.device).manual_seed(0),
+            generator=torch.Generator(device=unet.device).manual_seed(0),
         )
         return conditioning_data
 
@@ -340,7 +341,7 @@ class TextToLatentsInvocation(BaseInvocation):
         )
 
         pipeline = self.create_pipeline(unet, scheduler)
-        conditioning_data = self.get_conditioning_data(context, scheduler)
+        conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
         control_data = self.prep_control_data(
             model=pipeline, context=context, control_input=self.control,
@@ -361,6 +362,7 @@ class TextToLatentsInvocation(BaseInvocation):
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
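Moving the result latents to CPU before torch.cuda.empty_cache() lets the allocator actually reclaim the VRAM, since the GPU copy is no longer referenced when the cache is flushed. A hedged sketch of that ordering, where save_fn stands in for the latents service and is not the real API:

    import torch

    def store_result(latents: torch.Tensor, save_fn) -> torch.Tensor:
        # Drop the GPU copy first: once nothing references the VRAM allocation,
        # torch.cuda.empty_cache() can return it to the device instead of
        # keeping it parked in the caching allocator.
        latents = latents.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        save_fn(latents)
        return latents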
@@ -430,7 +432,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         )
 
         pipeline = self.create_pipeline(unet, scheduler)
-        conditioning_data = self.get_conditioning_data(context, scheduler)
+        conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
         control_data = self.prep_control_data(
             model=pipeline, context=context, control_input=self.control,
@@ -462,6 +464,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -502,6 +505,7 @@ class LatentsToImageInvocation(BaseInvocation):
         )
 
         with vae_info as vae:
+            latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
 
@@ -589,13 +593,17 @@ class ResizeLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
+        # TODO:
+        device=choose_torch_device()
+
         resized_latents = torch.nn.functional.interpolate(
-            latents, size=(self.height // 8, self.width // 8),
+            latents.to(device), size=(self.height // 8, self.width // 8),
             mode=self.mode, antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
@@ -623,14 +631,18 @@ class ScaleLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
+        # TODO:
+        device=choose_torch_device()
+
         # resizing
         resized_latents = torch.nn.functional.interpolate(
-            latents, scale_factor=self.scale_factor, mode=self.mode,
+            latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
             antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
@@ -721,6 +733,6 @@ class ImageToLatentsInvocation(BaseInvocation):
             latents = latents.to(dtype=orig_dtype)
 
         name = f"{context.graph_execution_state_id}__{self.id}"
-        # context.services.latents.set(name, latents)
+        latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents)
@@ -48,7 +48,7 @@ def get_noise(
         dtype=torch_dtype(device),
         device=noise_device_type,
         generator=generator,
-    ).to(device)
+    ).to("cpu")
 
     return noise_tensor
 
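get_noise now hands the tensor back on the CPU regardless of the device used to generate it, so the noise can be stored without pinning VRAM; the consuming node moves it where it is needed. A rough sketch under those assumptions (the shape and signature here are illustrative, not the repository's):

    import torch

    def make_noise(width: int, height: int, seed: int = 0) -> torch.Tensor:
        # Generate on the preferred device for reproducible RNG, then return
        # the tensor on CPU so it can be saved without holding VRAM.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        generator = torch.Generator(device=device).manual_seed(seed)
        return torch.randn(
            (1, 4, height // 8, width // 8),
            dtype=torch.float32,
            device=device,
            generator=generator,
        ).to("cpu")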
@@ -415,6 +415,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
 
         #################
 
+        latents = latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -651,6 +652,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
 
         #################
 
+        latents = latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -104,7 +104,8 @@ class ModelCache(object):
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
         self.model_infos: Dict[str, ModelBase] = dict()
-        self.lazy_offloading = lazy_offloading
+        # allow lazy offloading only when vram cache enabled
+        self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self.precision: torch.dtype=precision
         self.max_cache_size: float=max_cache_size
         self.max_vram_cache_size: float=max_vram_cache_size
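The ModelCache change ties lazy offloading to the presence of a VRAM cache: with max_vram_cache_size at 0, models are moved off the GPU immediately instead of lazily. A minimal sketch of just that guard, with the class and argument names simplified:

    class ModelCacheSketch:
        def __init__(self, lazy_offloading: bool = True, max_vram_cache_size: float = 0.0):
            # Lazy offloading only makes sense when a VRAM cache exists; with a
            # zero-sized VRAM cache, models are offloaded right away instead.
            self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
            self.max_vram_cache_size = max_vram_cache_size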