Release session if applying ti or lora

This commit is contained in:
Brandon Rising 2023-07-27 15:20:38 -04:00
parent bfdc8c80f3
commit 1ea9ba84f5
2 changed files with 8 additions and 8 deletions

View File

@ -96,7 +96,8 @@ class ONNXPromptInvocation(BaseInvocation):
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
if loras or ti_list:
text_encoder.release_session()
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
@ -127,7 +128,6 @@ class ONNXPromptInvocation(BaseInvocation):
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_encoder.release_session()
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
@ -255,6 +255,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
#loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
if loras:
unet.release_session()
with ONNXModelPatcher.apply_lora_unet(unet, loras):
# TODO:
_, _, h, w = latents.shape
@ -303,7 +305,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
# call the callback, if provided
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
unet.release_session()
torch.cuda.empty_cache()
@ -360,7 +361,6 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
image = image.transpose((0, 2, 3, 1))
image = VaeImageProcessor.numpy_to_pil(image)[0]
vae.release_session()
torch.cuda.empty_cache()

View File

@ -387,7 +387,7 @@ def _calc_model_by_data(model) -> int:
def _calc_onnx_model_by_data(model) -> int:
    """Estimate the memory used by an ONNX model, in bytes.

    Args:
        model: Object exposing ``tensors.size()``; that value is taken
            to be the tensor payload size in bytes.

    Returns:
        Twice the value reported by ``model.tensors.size()``.
    """
    # The session doubles this
    return model.tensors.size() * 2  # in bytes
@ -608,9 +608,9 @@ class IAIOnnxRuntimeModel:
# self.io_binding = self.session.io_binding()
def release_session(self):
    """Drop this object's session reference and run a garbage collection.

    ``self.session`` is set to ``None`` and ``gc.collect()`` is invoked
    right away. Presumably this is what frees the memory held by the
    runtime session -- TODO confirm against the session implementation.

    Returns:
        None.
    """
    # Kept function-local: the module's import block is not visible here.
    import gc

    self.session = None
    # Collect immediately instead of waiting for the next automatic pass.
    gc.collect()
def __call__(self, **kwargs):