mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Release session if applying ti or lora
This commit is contained in:
parent
bfdc8c80f3
commit
1ea9ba84f5
@ -96,7 +96,8 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
if loras or ti_list:
|
||||
text_encoder.release_session()
|
||||
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
|
||||
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
|
||||
|
||||
@ -127,7 +128,6 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
|
||||
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
|
||||
text_encoder.release_session()
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
@ -255,6 +255,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
#loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
||||
|
||||
if loras:
|
||||
unet.release_session()
|
||||
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
||||
# TODO:
|
||||
_, _, h, w = latents.shape
|
||||
@ -303,7 +305,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
# call the callback, if provided
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
unet.release_session()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -360,7 +361,6 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
||||
|
||||
vae.release_session()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -387,7 +387,7 @@ def _calc_model_by_data(model) -> int:
|
||||
|
||||
|
||||
def _calc_onnx_model_by_data(model) -> int:
|
||||
tensor_size = model.tensors.size()
|
||||
tensor_size = model.tensors.size() * 2 # The session doubles this
|
||||
mem = tensor_size # in bytes
|
||||
return mem
|
||||
|
||||
@ -608,9 +608,9 @@ class IAIOnnxRuntimeModel:
|
||||
# self.io_binding = self.session.io_binding()
|
||||
|
||||
def release_session(self):
|
||||
# self.session = None
|
||||
# import gc
|
||||
# gc.collect()
|
||||
self.session = None
|
||||
import gc
|
||||
gc.collect()
|
||||
return
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user