model_cache: let offload_model work with DiffusionPipeline, sorta.

Kevin Turner 2022-11-12 10:11:50 -08:00
parent 95db6e80ee
commit 01ff1cff58
2 changed files with 20 additions and 8 deletions

View File

@@ -4,6 +4,7 @@ They are moved between GPU and CPU as necessary. If CPU memory falls
 below a preset minimum, the least recently used model will be
 cleared and loaded from disk when next needed.
 '''
+import warnings
 from pathlib import Path
 import torch
@@ -413,10 +414,13 @@ class ModelCache(object):
     def _model_to_cpu(self,model):
         if self.device != 'cpu':
-            model.cond_stage_model.device = 'cpu'
-            model.first_stage_model.to('cpu')
-            model.cond_stage_model.to('cpu')
-            model.model.to('cpu')
+            try:
+                model.cond_stage_model.device = 'cpu'
+                model.first_stage_model.to('cpu')
+                model.cond_stage_model.to('cpu')
+                model.model.to('cpu')
+            except AttributeError as e:
+                warnings.warn(f"TODO: clean up legacy model-management: {e}")
             return model.to('cpu')
         else:
             return model

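The try/except above is what makes offload work with DiffusionPipeline "sorta": a legacy CompVis-style model exposes cond_stage_model, first_stage_model, and model as attributes, while a diffusers DiffusionPipeline does not, so the first attribute access raises AttributeError and offloading falls back to the single top-level .to('cpu') after warning. A minimal sketch of that fallback pattern (LegacyModel and PipelineLike are hypothetical stand-ins, not code from this commit):

import warnings
import torch.nn as nn

class LegacyModel(nn.Module):
    # CompVis-style wrapper: has the three named submodules.
    def __init__(self):
        super().__init__()
        self.cond_stage_model = nn.Linear(4, 4)   # text-encoder stand-in
        self.first_stage_model = nn.Linear(4, 4)  # VAE stand-in
        self.model = nn.Linear(4, 4)              # UNet stand-in

class PipelineLike(nn.Module):
    # diffusers-style object: different attribute names, but .to() still works.
    def __init__(self):
        super().__init__()
        self.unet = nn.Linear(4, 4)

def model_to_cpu(model):
    try:
        # Legacy path: move each known submodule explicitly.
        model.first_stage_model.to('cpu')
        model.cond_stage_model.to('cpu')
        model.model.to('cpu')
    except AttributeError as e:
        # Pipeline path: no such attributes; rely on the blanket .to() below.
        warnings.warn(f"TODO: clean up legacy model-management: {e}")
    return model.to('cpu')

model_to_cpu(LegacyModel())   # offloads submodule by submodule
model_to_cpu(PipelineLike())  # warns, then offloads wholesale

Note that if the AttributeError fires partway through the block, any submodules already moved stay moved; the final blanket .to('cpu') sweeps up everything either way.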
View File

@@ -1,5 +1,7 @@
 import math
 import os.path
+from typing import Optional
+
 import torch
 import torch.nn as nn
 from functools import partial
@@ -235,13 +237,15 @@ class SpatialRescaler(nn.Module):
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+    tokenizer: CLIPTokenizer
+    transformer: CLIPTextModel
     def __init__(
         self,
-        version='openai/clip-vit-large-patch14',
-        max_length=77,
-        tokenizer=None,
-        transformer=None,
+        version:str='openai/clip-vit-large-patch14',
+        max_length:int=77,
+        tokenizer:Optional[CLIPTokenizer]=None,
+        transformer:Optional[CLIPTextModel]=None,
     ):
         super().__init__()
         cache = os.path.join(Globals.root,'models',version)
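The new Optional[...] annotations document that an already-loaded tokenizer and text model can be injected into the constructor (for example, ones shared with a diffusers pipeline) instead of the embedder loading its own copies. A hedged usage sketch, assuming the default CLIP version string:

from transformers import CLIPTokenizer, CLIPTextModel

# Load the CLIP pieces once (they could equally come from a DiffusionPipeline) ...
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
transformer = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')

# ... then hand them to the embedder so nothing is loaded twice.
embedder = FrozenCLIPEmbedder(tokenizer=tokenizer, transformer=transformer)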
@@ -464,6 +468,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def device(self):
         return self.transformer.device
 
+    @device.setter
+    def device(self, device):
+        self.transformer.to(device=device)
+
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
     fragment_weights_key = "fragment_weights"
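This setter is the counterpart to the assignment model.cond_stage_model.device = 'cpu' in _model_to_cpu above: reading .device proxies the wrapped transformer's device, and assigning to it now relocates the transformer via .to(). A minimal stand-in showing the property pair (DeviceProxy is hypothetical; the real class wraps a CLIPTextModel):

import torch.nn as nn

class DeviceProxy(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(4, 4)  # stand-in for CLIPTextModel

    @property
    def device(self):
        # Reads report where the wrapped module currently lives.
        return next(self.transformer.parameters()).device

    @device.setter
    def device(self, device):
        # Writes move the wrapped module instead of storing a value.
        self.transformer.to(device=device)

proxy = DeviceProxy()
proxy.device = 'cpu'   # plain attribute assignment now moves the module
print(proxy.device)    # device(type='cpu')

Because nn.Module.__setattr__ falls through to the normal attribute protocol for non-module values, the property setter fires as expected on plain assignment.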