model paths fixed, codeformer needs attention

Lincoln Stein
2022-11-15 01:53:10 +00:00
parent 4c035ad4ae
commit 274b276133
11 changed files with 180 additions and 75 deletions


@@ -1,5 +1,5 @@
 import math
+import os.path
 import torch
 import torch.nn as nn
 from functools import partial
@@ -8,6 +8,7 @@ from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
 from ldm.invoke.devices import choose_torch_device
+from ldm.invoke.globals import Globals
 from ldm.modules.x_transformer import (
     Encoder,
@@ -98,21 +99,19 @@ class BERTTokenizer(AbstractEncoder):
     """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
     def __init__(
-        self, device=choose_torch_device(), vq_interface=True, max_length=77
+        self, device=choose_torch_device(), vq_interface=True, max_length=77
     ):
         super().__init__()
         from transformers import (
             BertTokenizerFast,
-        )  # TODO: add to requirements
+        )
-        # Modified to allow to run on non-internet connected compute nodes.
-        # Model needs to be loaded into cache from an internet-connected machine
-        # by running:
-        #   from transformers import BertTokenizerFast
-        #   BertTokenizerFast.from_pretrained("bert-base-uncased")
+        cache = os.path.join(Globals.root,'models/bert-base-uncased')
         try:
             self.tokenizer = BertTokenizerFast.from_pretrained(
-                'bert-base-uncased', local_files_only=True
+                'bert-base-uncased',
+                cache_dir=cache,
+                local_files_only=True
             )
         except OSError:
             raise SystemExit(
@@ -150,14 +149,14 @@ class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizer model and adds some transformer encoder layers"""
     def __init__(
-        self,
-        n_embed,
-        n_layer,
-        vocab_size=30522,
-        max_seq_len=77,
-        device=choose_torch_device(),
-        use_tokenizer=True,
-        embedding_dropout=0.0,
+        self,
+        n_embed,
+        n_layer,
+        vocab_size=30522,
+        max_seq_len=77,
+        device=choose_torch_device(),
+        use_tokenizer=True,
+        embedding_dropout=0.0,
     ):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
@@ -245,10 +244,14 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(
-            version, local_files_only=True
+            version,
+            cache_dir=os.path.join(Globals.root,'models',version),
+            local_files_only=True
        )
         self.transformer = CLIPTextModel.from_pretrained(
-            version, local_files_only=True
+            version,
+            cache_dir=os.path.join(Globals.root,'models',version),
+            local_files_only=True
         )
         self.device = device
         self.max_length = max_length
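
A minimal sketch of the caching pattern this diff introduces, assuming an InvokeAI environment where ldm.invoke.globals is importable: the Hugging Face model files are downloaded once into a cache directory under Globals.root/models from an internet-connected machine, after which offline nodes load them with local_files_only=True. The paths and keyword arguments mirror the diff; the standalone preload call is an assumption about how the cache gets populated.

# Sketch of the offline-cache pattern used above (assumed preload step,
# not part of this commit).
import os.path
from transformers import BertTokenizerFast
from ldm.invoke.globals import Globals

cache = os.path.join(Globals.root, 'models/bert-base-uncased')

# On an internet-connected machine: download into the InvokeAI-rooted cache.
BertTokenizerFast.from_pretrained('bert-base-uncased', cache_dir=cache)

# On the offline compute node: the same call now succeeds without network
# access, exactly as the patched BERTTokenizer.__init__ loads it.
tokenizer = BertTokenizerFast.from_pretrained(
    'bert-base-uncased',
    cache_dir=cache,
    local_files_only=True,
)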