fix mps crash with safety checker

This commit is contained in:
Lincoln Stein 2022-10-30 16:54:06 -04:00
parent 330b417a7b
commit 23d54ee69e
2 changed files with 2 additions and 0 deletions

View File

@ -217,6 +217,7 @@ class Generate:
safety_model_id = "CompVis/stable-diffusion-safety-checker"
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
self.safety_checker.to(self.device)
except Exception:
print('** An error was encountered while installing the safety checker:')
print(traceback.format_exc())

View File

@ -197,6 +197,7 @@ class Generator():
checker = self.safety_checker['checker']
extractor = self.safety_checker['extractor']
features = extractor([image], return_tensors="pt")
features.to(self.model.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0