mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Release session if applying ti or lora
This commit is contained in:
parent
bfdc8c80f3
commit
1ea9ba84f5
@ -96,7 +96,8 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
#import traceback
|
#import traceback
|
||||||
#print(traceback.format_exc())
|
#print(traceback.format_exc())
|
||||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||||
|
if loras or ti_list:
|
||||||
|
text_encoder.release_session()
|
||||||
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
|
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
|
||||||
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
|
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
|
||||||
|
|
||||||
@ -127,7 +128,6 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
|
|
||||||
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||||
|
|
||||||
text_encoder.release_session()
|
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
|
|
||||||
@ -255,6 +255,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
#loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
#loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
||||||
|
|
||||||
|
if loras:
|
||||||
|
unet.release_session()
|
||||||
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
||||||
# TODO:
|
# TODO:
|
||||||
_, _, h, w = latents.shape
|
_, _, h, w = latents.shape
|
||||||
@ -303,7 +305,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
#if callback is not None and i % callback_steps == 0:
|
#if callback is not None and i % callback_steps == 0:
|
||||||
# callback(i, t, latents)
|
# callback(i, t, latents)
|
||||||
unet.release_session()
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -360,7 +361,6 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
|||||||
image = image.transpose((0, 2, 3, 1))
|
image = image.transpose((0, 2, 3, 1))
|
||||||
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
||||||
|
|
||||||
vae.release_session()
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@ -387,7 +387,7 @@ def _calc_model_by_data(model) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def _calc_onnx_model_by_data(model) -> int:
|
def _calc_onnx_model_by_data(model) -> int:
|
||||||
tensor_size = model.tensors.size()
|
tensor_size = model.tensors.size() * 2 # The session doubles this
|
||||||
mem = tensor_size # in bytes
|
mem = tensor_size # in bytes
|
||||||
return mem
|
return mem
|
||||||
|
|
||||||
@ -608,9 +608,9 @@ class IAIOnnxRuntimeModel:
|
|||||||
# self.io_binding = self.session.io_binding()
|
# self.io_binding = self.session.io_binding()
|
||||||
|
|
||||||
def release_session(self):
|
def release_session(self):
|
||||||
# self.session = None
|
self.session = None
|
||||||
# import gc
|
import gc
|
||||||
# gc.collect()
|
gc.collect()
|
||||||
return
|
return
|
||||||
|
|
||||||
def __call__(self, **kwargs):
|
def __call__(self, **kwargs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user