mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(config): move a few get_config calls to inside the functions where they are needed
This commit is contained in:
parent
f1450c2c24
commit
982b513af3
@ -9,8 +9,6 @@ import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
if len(np_img.shape) == 2:
|
||||
|
@ -10,8 +10,6 @@ import numpy as np
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
@ -28,7 +26,7 @@ class PatchMatch:
|
||||
def _load_patch_match(self):
|
||||
if self.tried_load:
|
||||
return
|
||||
if config.patchmatch:
|
||||
if get_config().patchmatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
|
@ -14,8 +14,6 @@ from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
config = get_config()
|
||||
|
||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
||||
|
||||
|
||||
@ -34,8 +32,8 @@ class SafetyChecker:
|
||||
return
|
||||
|
||||
try:
|
||||
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
|
||||
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
|
||||
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH)
|
||||
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH)
|
||||
logger.info("NSFW checker initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load NSFW checker: {str(e)}")
|
||||
|
@ -37,7 +37,6 @@ from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||
CLIPSEG_SIZE = 352
|
||||
config = get_config()
|
||||
|
||||
|
||||
class SegmentedGrayscale(object):
|
||||
@ -78,8 +77,8 @@ class Txt2Mask(object):
|
||||
|
||||
# BUG: we are not doing anything with the device option at this time
|
||||
self.device = device
|
||||
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
||||
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=get_config().cache_dir)
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=get_config().cache_dir)
|
||||
|
||||
@torch.no_grad()
|
||||
def segment(self, image: Image.Image, prompt: str) -> SegmentedGrayscale:
|
||||
|
Loading…
Reference in New Issue
Block a user