model_cache: let offload_model work with DiffusionPipeline, sorta.

This commit is contained in:
Kevin Turner 2022-11-12 10:11:50 -08:00
parent 95db6e80ee
commit 01ff1cff58
2 changed files with 20 additions and 8 deletions

View File

@ -4,6 +4,7 @@ They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
cleared and loaded from disk when next needed.
'''
import warnings
from pathlib import Path
import torch
@ -413,10 +414,13 @@ class ModelCache(object):
def _model_to_cpu(self,model):
if self.device != 'cpu':
try:
model.cond_stage_model.device = 'cpu'
model.first_stage_model.to('cpu')
model.cond_stage_model.to('cpu')
model.model.to('cpu')
except AttributeError as e:
warnings.warn(f"TODO: clean up legacy model-management: {e}")
return model.to('cpu')
else:
return model

View File

@ -1,5 +1,7 @@
import math
import os.path
from typing import Optional
import torch
import torch.nn as nn
from functools import partial
@ -235,13 +237,15 @@ class SpatialRescaler(nn.Module):
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
tokenizer: CLIPTokenizer
transformer: CLIPTextModel
def __init__(
self,
version='openai/clip-vit-large-patch14',
max_length=77,
tokenizer=None,
transformer=None,
version:str='openai/clip-vit-large-patch14',
max_length:int=77,
tokenizer:Optional[CLIPTokenizer]=None,
transformer:Optional[CLIPTextModel]=None,
):
super().__init__()
cache = os.path.join(Globals.root,'models',version)
@ -464,6 +468,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def device(self):
return self.transformer.device
@device.setter
def device(self, device):
self.transformer.to(device=device)
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
fragment_weights_key = "fragment_weights"