address all review comments; needs testing

This commit is contained in:
Lincoln Stein 2022-11-18 15:25:23 -05:00
parent 3a5a8ceba5
commit f33df25830
4 changed files with 9 additions and 9 deletions

View File

@ -221,13 +221,14 @@ class Generate:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = os.path.join(Globals.root,'models',safety_model_id)
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id,
local_files_only=True,
cache_dir=os.path.join(Globals.root,'models',safety_model_id)
cache_dir=safety_model_path,
)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id,
local_files_only=True,
cache_dir=os.path.join(Globals.root,'models',safety_model_id)
cache_dir=safety_model_path,
)
self.safety_checker.to(self.device)
except Exception:

View File

@ -60,7 +60,7 @@ def main():
# normalize the config directory relative to root
if not os.path.isabs(opt.conf):
opt.conf=os.path.normpath(os.path.join(Globals.root,opt.conf))
opt.conf = os.path.normpath(os.path.join(Globals.root,opt.conf))
# load the infile as a list of lines
if opt.infile:
@ -78,7 +78,7 @@ def main():
# creating a Generate object:
try:
gen = Generate(
conf = os.path.join(Globals.root,opt.conf),
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path,
@ -129,8 +129,6 @@ def main_loop(gen, opt):
doneAfterInFile = infile is not None
path_filter = re.compile(r'[<>:"/\\|?*]')
last_results = list()
if not os.path.isabs(opt.conf):
opt.conf = os.path.join(Globals.root,opt.conf)
model_config = OmegaConf.load(opt.conf)
# The readline completer reads history from the .dream_history file located in the

View File

@ -15,7 +15,7 @@ class GFPGAN():
) -> None:
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path=os.path.join(Globals.root,gfpgan_model_path)
gfpgan_model_path=os.path.abspath(os.path.join(Globals.root,gfpgan_model_path))
self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path)

View File

@ -243,14 +243,15 @@ class FrozenCLIPEmbedder(AbstractEncoder):
max_length=77,
):
super().__init__()
cache = os.path.join(Globals.root,'models',version)
self.tokenizer = CLIPTokenizer.from_pretrained(
version,
cache_dir=os.path.join(Globals.root,'models',version),
cache_dir=cache,
local_files_only=True
)
self.transformer = CLIPTextModel.from_pretrained(
version,
cache_dir=os.path.join(Globals.root,'models',version),
cache_dir=cache,
local_files_only=True
)
self.device = device