tidy(config): move a few get_config calls to inside the functions where they are needed

This commit is contained in:
psychedelicious 2024-03-18 11:50:02 +11:00
parent f1450c2c24
commit 982b513af3
4 changed files with 5 additions and 12 deletions

View File

@ -9,8 +9,6 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import choose_torch_device
config = get_config()
def norm_img(np_img):
if len(np_img.shape) == 2:

View File

@ -10,8 +10,6 @@ import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
config = get_config()
class PatchMatch:
"""
@ -28,7 +26,7 @@ class PatchMatch:
def _load_patch_match(self):
if self.tried_load:
return
if config.patchmatch:
if get_config().patchmatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:

View File

@ -14,8 +14,6 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.silence_warnings import SilenceWarnings
config = get_config()
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@ -34,8 +32,8 @@ class SafetyChecker:
return
try:
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH)
logger.info("NSFW checker initialized")
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")

View File

@ -37,7 +37,6 @@ from invokeai.app.services.config.config_default import get_config
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352
config = get_config()
class SegmentedGrayscale(object):
@ -78,8 +77,8 @@ class Txt2Mask(object):
# BUG: we are not doing anything with the device option at this time
self.device = device
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=get_config().cache_dir)
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=get_config().cache_dir)
@torch.no_grad()
def segment(self, image: Image.Image, prompt: str) -> SegmentedGrayscale: