mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(config): move a few get_config calls to inside the functions where they are needed
This commit is contained in:
parent
f1450c2c24
commit
982b513af3
@ -9,8 +9,6 @@ import invokeai.backend.util.logging as logger
|
|||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
config = get_config()
|
|
||||||
|
|
||||||
|
|
||||||
def norm_img(np_img):
|
def norm_img(np_img):
|
||||||
if len(np_img.shape) == 2:
|
if len(np_img.shape) == 2:
|
||||||
|
@ -10,8 +10,6 @@ import numpy as np
|
|||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
|
|
||||||
config = get_config()
|
|
||||||
|
|
||||||
|
|
||||||
class PatchMatch:
|
class PatchMatch:
|
||||||
"""
|
"""
|
||||||
@ -28,7 +26,7 @@ class PatchMatch:
|
|||||||
def _load_patch_match(self):
|
def _load_patch_match(self):
|
||||||
if self.tried_load:
|
if self.tried_load:
|
||||||
return
|
return
|
||||||
if config.patchmatch:
|
if get_config().patchmatch:
|
||||||
from patchmatch import patch_match as pm
|
from patchmatch import patch_match as pm
|
||||||
|
|
||||||
if pm.patchmatch_available:
|
if pm.patchmatch_available:
|
||||||
|
@ -14,8 +14,6 @@ from invokeai.app.services.config.config_default import get_config
|
|||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
config = get_config()
|
|
||||||
|
|
||||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
||||||
|
|
||||||
|
|
||||||
@ -34,8 +32,8 @@ class SafetyChecker:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
|
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH)
|
||||||
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
|
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH)
|
||||||
logger.info("NSFW checker initialized")
|
logger.info("NSFW checker initialized")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not load NSFW checker: {str(e)}")
|
logger.warning(f"Could not load NSFW checker: {str(e)}")
|
||||||
|
@ -37,7 +37,6 @@ from invokeai.app.services.config.config_default import get_config
|
|||||||
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||||
CLIPSEG_SIZE = 352
|
CLIPSEG_SIZE = 352
|
||||||
config = get_config()
|
|
||||||
|
|
||||||
|
|
||||||
class SegmentedGrayscale(object):
|
class SegmentedGrayscale(object):
|
||||||
@ -78,8 +77,8 @@ class Txt2Mask(object):
|
|||||||
|
|
||||||
# BUG: we are not doing anything with the device option at this time
|
# BUG: we are not doing anything with the device option at this time
|
||||||
self.device = device
|
self.device = device
|
||||||
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=get_config().cache_dir)
|
||||||
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=get_config().cache_dir)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def segment(self, image: Image.Image, prompt: str) -> SegmentedGrayscale:
|
def segment(self, image: Image.Image, prompt: str) -> SegmentedGrayscale:
|
||||||
|
Loading…
Reference in New Issue
Block a user