mirror of https://github.com/invoke-ai/InvokeAI
model paths fixed, codeformer needs attention
ldm/modules/encoders/modules.py
@@ -1,5 +1,5 @@
 import math
+import os.path
 import torch
 import torch.nn as nn
 from functools import partial
@@ -8,6 +8,7 @@ from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
 from ldm.invoke.devices import choose_torch_device
+from ldm.invoke.globals import Globals

 from ldm.modules.x_transformer import (
     Encoder,
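
The new import is what makes the path changes below work: Globals.root is InvokeAI's runtime root directory, and the model caches in the following hunks are resolved relative to it instead of the default Hugging Face cache. A minimal sketch of the convention, assuming it runs inside an InvokeAI checkout (the ~/invokeai location in the comment is the conventional default, not guaranteed):

    import os
    from ldm.invoke.globals import Globals

    # The path convention this commit establishes: per-model cache
    # directories live under <Globals.root>/models/<model-name>,
    # e.g. ~/invokeai/models/bert-base-uncased on a default install.
    cache = os.path.join(Globals.root, 'models/bert-base-uncased')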
@@ -98,21 +99,19 @@ class BERTTokenizer(AbstractEncoder):
     """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

     def __init__(
         self, device=choose_torch_device(), vq_interface=True, max_length=77
     ):
         super().__init__()
         from transformers import (
             BertTokenizerFast,
-        )  # TODO: add to reuquirements
+        )

+        # Modified to allow to run on non-internet connected compute nodes.
+        # Model needs to be loaded into cache from an internet-connected machine
+        # by running:
+        #   from transformers import BertTokenizerFast
+        #   BertTokenizerFast.from_pretrained("bert-base-uncased")
+        cache = os.path.join(Globals.root,'models/bert-base-uncased')
         try:
             self.tokenizer = BertTokenizerFast.from_pretrained(
-                'bert-base-uncased', local_files_only=True
+                'bert-base-uncased',
+                cache_dir=cache,
+                local_files_only=True
             )
         except OSError:
             raise SystemExit(
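
The comment block describes a two-step offline workflow: populate the cache once from a connected machine, then let the compute node read it with local_files_only=True. A hedged sketch of both halves using the standard transformers API; the ~/invokeai root here is an assumption standing in for Globals.root:

    import os
    from transformers import BertTokenizerFast

    # Step 1 (internet-connected machine): download into the cache
    # directory that the offline node will later mount or copy.
    cache = os.path.join(os.path.expanduser('~/invokeai'),
                         'models/bert-base-uncased')
    BertTokenizerFast.from_pretrained('bert-base-uncased', cache_dir=cache)

    # Step 2 (offline compute node): the same call now resolves entirely
    # from the local cache, and raises OSError if anything is missing.
    BertTokenizerFast.from_pretrained(
        'bert-base-uncased',
        cache_dir=cache,
        local_files_only=True,
    )

That OSError is what the try/except above converts into a SystemExit.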
@@ -150,14 +149,14 @@ class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""

     def __init__(
         self,
         n_embed,
         n_layer,
         vocab_size=30522,
         max_seq_len=77,
         device=choose_torch_device(),
         use_tokenizer=True,
         embedding_dropout=0.0,
     ):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
@@ -245,10 +244,14 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(
-            version, local_files_only=True
+            version,
+            cache_dir=os.path.join(Globals.root,'models',version),
+            local_files_only=True
         )
         self.transformer = CLIPTextModel.from_pretrained(
-            version, local_files_only=True
+            version,
+            cache_dir=os.path.join(Globals.root,'models',version),
+            local_files_only=True
         )
         self.device = device
         self.max_length = max_length
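
The same pre-caching step applies to the CLIP weights used by FrozenCLIPEmbedder, whose version parameter defaults to the 'openai/clip-vit-large-patch14' checkpoint in this file. A sketch under the same ~/invokeai assumption:

    import os
    from transformers import CLIPTokenizer, CLIPTextModel

    version = 'openai/clip-vit-large-patch14'
    # Joining with version (which itself contains a slash) yields
    # <root>/models/openai/clip-vit-large-patch14, matching the diff.
    cache = os.path.join(os.path.expanduser('~/invokeai'), 'models', version)

    # Run once while online; afterwards the constructor's
    # local_files_only=True loads succeed without network access.
    CLIPTokenizer.from_pretrained(version, cache_dir=cache)
    CLIPTextModel.from_pretrained(version, cache_dir=cache)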