From 01ff1cff587861941f9822781bcfe80290bcc880 Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Sat, 12 Nov 2022 10:11:50 -0800
Subject: [PATCH] model_cache: let offload_model work with DiffusionPipeline,
 sorta.

---
 ldm/invoke/model_cache.py       | 12 ++++++++----
 ldm/modules/encoders/modules.py | 16 ++++++++++++----
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 8ad8e3913b..181c4d39b4 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -4,6 +4,7 @@ They are moved between GPU and CPU as necessary. If CPU memory falls below
 a preset minimum, the least recently used model will be cleared and loaded
 from disk when next needed.
 '''
+import warnings
 from pathlib import Path
 
 import torch
@@ -413,10 +414,13 @@ class ModelCache(object):
 
     def _model_to_cpu(self,model):
         if self.device != 'cpu':
-            model.cond_stage_model.device = 'cpu'
-            model.first_stage_model.to('cpu')
-            model.cond_stage_model.to('cpu')
-            model.model.to('cpu')
+            try:
+                model.cond_stage_model.device = 'cpu'
+                model.first_stage_model.to('cpu')
+                model.cond_stage_model.to('cpu')
+                model.model.to('cpu')
+            except AttributeError as e:
+                warnings.warn(f"TODO: clean up legacy model-management: {e}")
             return model.to('cpu')
         else:
             return model
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index ca9a027f13..cf46051933 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,5 +1,7 @@
 import math
 import os.path
+from typing import Optional
+
 import torch
 import torch.nn as nn
 from functools import partial
@@ -235,13 +237,15 @@ class SpatialRescaler(nn.Module):
 
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+    tokenizer: CLIPTokenizer
+    transformer: CLIPTextModel
 
     def __init__(
         self,
-        version='openai/clip-vit-large-patch14',
-        max_length=77,
-        tokenizer=None,
-        transformer=None,
+        version:str='openai/clip-vit-large-patch14',
+        max_length:int=77,
+        tokenizer:Optional[CLIPTokenizer]=None,
+        transformer:Optional[CLIPTextModel]=None,
     ):
         super().__init__()
         cache = os.path.join(Globals.root,'models',version)
@@ -464,6 +468,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def device(self):
         return self.transformer.device
 
+    @device.setter
+    def device(self, device):
+        self.transformer.to(device=device)
+
 
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
     fragment_weights_key = "fragment_weights"
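
Reader's note, not part of the patch: the device setter added to FrozenCLIPEmbedder is what lets
the legacy assignment model.cond_stage_model.device = 'cpu' in _model_to_cpu actually move
weights, since FrozenCLIPEmbedder is an nn.Module that exposes device as a property backed by its
transformer. The sketch below reproduces just that pattern; TinyEncoder and its nn.Linear stand-in
are invented for illustration and are not InvokeAI code.

# Sketch only; TinyEncoder and its Linear layer are stand-ins, not part of the patch.
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(4, 4)  # stand-in for CLIPTextModel

    @property
    def device(self) -> torch.device:
        # Report where the wrapped module's parameters currently live.
        return next(self.transformer.parameters()).device

    @device.setter
    def device(self, device):
        # Plain attribute assignment moves the wrapped module's weights.
        self.transformer.to(device=device)

encoder = TinyEncoder()
encoder.device = 'cpu'   # goes through the property setter
print(encoder.device)    # prints: cpu

The property setter still fires on an nn.Module subclass because nn.Module.__setattr__ falls back
to object.__setattr__ for values that are not parameters, modules, or buffers, and
object.__setattr__ honors data descriptors such as property.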