Merge branch 'main' into lstein-import-safetensors

This commit is contained in:
Lincoln Stein 2023-01-18 17:09:48 -05:00
commit e11f15cf78
2 changed files with 38 additions and 16 deletions

View File

@ -66,7 +66,17 @@ def global_cache_dir(subdir:Union[str,Path]='')->Path:
global_cache_dir('diffusers') global_cache_dir('diffusers')
global_cache_dir('transformers') global_cache_dir('transformers')
''' '''
if (home := os.environ.get('HF_HOME')): home: str = os.getenv('HF_HOME')
if home is None:
home = os.getenv('XDG_CACHE_HOME')
if home is not None:
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in HuggingFace Hub Client Library.
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
home += os.sep + 'huggingface'
if home is not None:
return Path(home,subdir) return Path(home,subdir)
else: else:
return Path(Globals.root,'models',subdir) return Path(Globals.root,'models',subdir)

View File

@ -764,7 +764,11 @@ class ModelManager(object):
from shutil import move, rmtree from shutil import move, rmtree
# transformer files get moved into the hub directory # transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present():
hub = global_cache_dir('hub')
else:
hub = models_dir / 'hub' hub = models_dir / 'hub'
os.makedirs(hub, exist_ok=True) os.makedirs(hub, exist_ok=True)
for model in legacy_locations: for model in legacy_locations:
source = models_dir / model source = models_dir / model
@ -777,7 +781,11 @@ class ModelManager(object):
move(source, dest) move(source, dest)
# anything else gets moved into the diffusers directory # anything else gets moved into the diffusers directory
if cls._is_huggingface_hub_directory_present():
diffusers = global_cache_dir('diffusers')
else:
diffusers = models_dir / 'diffusers' diffusers = models_dir / 'diffusers'
os.makedirs(diffusers, exist_ok=True) os.makedirs(diffusers, exist_ok=True)
for root, dirs, _ in os.walk(models_dir, topdown=False): for root, dirs, _ in os.walk(models_dir, topdown=False):
for dir in dirs: for dir in dirs:
@ -968,3 +976,7 @@ class ModelManager(object):
print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}') print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}')
return vae return vae
@staticmethod
def _is_huggingface_hub_directory_present() -> bool:
return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None