mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
address all review comments; needs testing
This commit is contained in:
parent
3a5a8ceba5
commit
f33df25830
@ -221,13 +221,14 @@ class Generate:
|
|||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
safety_model_path = os.path.join(Globals.root,'models',safety_model_id)
|
||||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id,
|
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
cache_dir=os.path.join(Globals.root,'models',safety_model_id)
|
cache_dir=safety_model_path,
|
||||||
)
|
)
|
||||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id,
|
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
cache_dir=os.path.join(Globals.root,'models',safety_model_id)
|
cache_dir=safety_model_path,
|
||||||
)
|
)
|
||||||
self.safety_checker.to(self.device)
|
self.safety_checker.to(self.device)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -78,7 +78,7 @@ def main():
|
|||||||
# creating a Generate object:
|
# creating a Generate object:
|
||||||
try:
|
try:
|
||||||
gen = Generate(
|
gen = Generate(
|
||||||
conf = os.path.join(Globals.root,opt.conf),
|
conf = opt.conf,
|
||||||
model = opt.model,
|
model = opt.model,
|
||||||
sampler_name = opt.sampler_name,
|
sampler_name = opt.sampler_name,
|
||||||
embedding_path = opt.embedding_path,
|
embedding_path = opt.embedding_path,
|
||||||
@ -129,8 +129,6 @@ def main_loop(gen, opt):
|
|||||||
doneAfterInFile = infile is not None
|
doneAfterInFile = infile is not None
|
||||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||||
last_results = list()
|
last_results = list()
|
||||||
if not os.path.isabs(opt.conf):
|
|
||||||
opt.conf = os.path.join(Globals.root,opt.conf)
|
|
||||||
model_config = OmegaConf.load(opt.conf)
|
model_config = OmegaConf.load(opt.conf)
|
||||||
|
|
||||||
# The readline completer reads history from the .dream_history file located in the
|
# The readline completer reads history from the .dream_history file located in the
|
||||||
|
@ -15,7 +15,7 @@ class GFPGAN():
|
|||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
if not os.path.isabs(gfpgan_model_path):
|
if not os.path.isabs(gfpgan_model_path):
|
||||||
gfpgan_model_path=os.path.join(Globals.root,gfpgan_model_path)
|
gfpgan_model_path=os.path.abspath(os.path.join(Globals.root,gfpgan_model_path))
|
||||||
self.model_path = gfpgan_model_path
|
self.model_path = gfpgan_model_path
|
||||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
|
@ -243,14 +243,15 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||||||
max_length=77,
|
max_length=77,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
cache = os.path.join(Globals.root,'models',version)
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
version,
|
version,
|
||||||
cache_dir=os.path.join(Globals.root,'models',version),
|
cache_dir=cache,
|
||||||
local_files_only=True
|
local_files_only=True
|
||||||
)
|
)
|
||||||
self.transformer = CLIPTextModel.from_pretrained(
|
self.transformer = CLIPTextModel.from_pretrained(
|
||||||
version,
|
version,
|
||||||
cache_dir=os.path.join(Globals.root,'models',version),
|
cache_dir=cache,
|
||||||
local_files_only=True
|
local_files_only=True
|
||||||
)
|
)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
Loading…
x
Reference in New Issue
Block a user