prevent crash when trying to calculate size of missing safety_checker

- Also fixed up order in which logger is created in invokeai-web
  so that handlers are installed after command-line options are
  parsed (and not before!)
This commit is contained in:
Lincoln Stein
2023-06-06 22:57:49 -04:00
parent 1f9e1eb964
commit 04f9757f8d
7 changed files with 33 additions and 30 deletions

View File

@ -16,18 +16,19 @@ context. Use like this:
"""
import contextlib
import gc
import os
import sys
import hashlib
import json
import warnings
from contextlib import suppress
from enum import Enum
from pathlib import Path
from typing import Dict, Sequence, Union, types, Optional, List, Type, Any
from typing import Dict, Union, types, Optional, List, Type, Any
import torch
import transformers
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
from diffusers import logging as diffusers_logging
@ -63,6 +64,11 @@ def calc_model_size_by_fs(
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
@ -88,7 +94,7 @@ def calc_model_size_by_fs(
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, index_file), "r") as f:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
@ -277,7 +283,7 @@ class ClassifierModelInfo(ModelInfoBase):
self.child_sizes: Dict[str, int] = dict()
try:
main_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json")
main_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="config.json")
#main_config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid classifier model! (config.json not found or invalid)")
@ -289,7 +295,7 @@ class ClassifierModelInfo(ModelInfoBase):
def _load_tokenizer(self, main_config: dict):
try:
tokenizer_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="tokenizer_config.json")
tokenizer_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="tokenizer_config.json")
#tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json"))
except:
raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)")
@ -314,13 +320,13 @@ class ClassifierModelInfo(ModelInfoBase):
raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
self.child_types[SDModelType.TextEncoder] = self._definition_to_type(["transformers", text_encoder_class_name])
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(repo_id_or_path)
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.repo_id_or_path)
def _load_feature_extractor(self, main_config: dict):
self.child_sizes[SDModelType.FeatureExtractor] = 0
try:
feature_extractor_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="preprocessor_config.json")
feature_extractor_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="preprocessor_config.json")
except:
return # feature extractor not passed with t5

View File

@ -195,13 +195,12 @@ class InvokeAILogger(object):
@classmethod
def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
handler_strs = config.log_handlers
print(f'handler_strs={handler_strs}')
handlers = list()
for handler in handler_strs:
handler_name,*args = handler.split('=',2)
args = args[0] if len(args) > 0 else None
# console is the only handler that gets a custom formatter
# console and file are the only handlers that gets a custom formatter
if handler_name=='console':
formatter = LOG_FORMATTERS[config.log_format]
ch = logging.StreamHandler()
@ -210,14 +209,16 @@ class InvokeAILogger(object):
elif handler_name=='syslog':
ch = cls._parse_syslog_args(args)
ch.setFormatter(InvokeAISyslogFormatter())
handlers.append(ch)
elif handler_name=='file':
handlers.append(cls._parse_file_args(args))
ch = cls._parse_file_args(args)
ch.setFormatter(InvokeAISyslogFormatter())
handlers.append(ch)
elif handler_name=='http':
handlers.append(cls._parse_http_args(args))
ch = cls._parse_http_args(args)
handlers.append(ch)
return handlers
@staticmethod