mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
prevent crash when trying to calculate size of missing safety_checker
- Also fixed up order in which logger is created in invokeai-web so that handlers are installed after command-line options are parsed (and not before!)
This commit is contained in:
@ -16,18 +16,19 @@ context. Use like this:
|
||||
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
import json
|
||||
import warnings
|
||||
from contextlib import suppress
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, Sequence, Union, types, Optional, List, Type, Any
|
||||
from typing import Dict, Union, types, Optional, List, Type, Any
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
|
||||
from diffusers import logging as diffusers_logging
|
||||
@ -63,6 +64,11 @@ def calc_model_size_by_fs(
|
||||
if subfolder is not None:
|
||||
model_path = os.path.join(model_path, subfolder)
|
||||
|
||||
# this can happen when, for example, the safety checker
|
||||
# is not downloaded.
|
||||
if not os.path.exists(model_path):
|
||||
return 0
|
||||
|
||||
all_files = os.listdir(model_path)
|
||||
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
||||
|
||||
@ -88,7 +94,7 @@ def calc_model_size_by_fs(
|
||||
if not file.endswith(index_postfix):
|
||||
continue
|
||||
try:
|
||||
with open(os.path.join(model_path, index_file), "r") as f:
|
||||
with open(os.path.join(model_path, file), "r") as f:
|
||||
index_data = json.loads(f.read())
|
||||
return int(index_data["metadata"]["total_size"])
|
||||
except:
|
||||
@ -277,7 +283,7 @@ class ClassifierModelInfo(ModelInfoBase):
|
||||
self.child_sizes: Dict[str, int] = dict()
|
||||
|
||||
try:
|
||||
main_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json")
|
||||
main_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="config.json")
|
||||
#main_config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||
except:
|
||||
raise Exception("Invalid classifier model! (config.json not found or invalid)")
|
||||
@ -289,7 +295,7 @@ class ClassifierModelInfo(ModelInfoBase):
|
||||
|
||||
def _load_tokenizer(self, main_config: dict):
|
||||
try:
|
||||
tokenizer_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="tokenizer_config.json")
|
||||
tokenizer_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="tokenizer_config.json")
|
||||
#tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json"))
|
||||
except:
|
||||
raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)")
|
||||
@ -314,13 +320,13 @@ class ClassifierModelInfo(ModelInfoBase):
|
||||
raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
|
||||
|
||||
self.child_types[SDModelType.TextEncoder] = self._definition_to_type(["transformers", text_encoder_class_name])
|
||||
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(repo_id_or_path)
|
||||
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.repo_id_or_path)
|
||||
|
||||
|
||||
def _load_feature_extractor(self, main_config: dict):
|
||||
self.child_sizes[SDModelType.FeatureExtractor] = 0
|
||||
try:
|
||||
feature_extractor_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="preprocessor_config.json")
|
||||
feature_extractor_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="preprocessor_config.json")
|
||||
except:
|
||||
return # feature extractor not passed with t5
|
||||
|
||||
|
@ -195,13 +195,12 @@ class InvokeAILogger(object):
|
||||
@classmethod
|
||||
def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
|
||||
handler_strs = config.log_handlers
|
||||
print(f'handler_strs={handler_strs}')
|
||||
handlers = list()
|
||||
for handler in handler_strs:
|
||||
handler_name,*args = handler.split('=',2)
|
||||
args = args[0] if len(args) > 0 else None
|
||||
|
||||
# console is the only handler that gets a custom formatter
|
||||
# console and file are the only handlers that gets a custom formatter
|
||||
if handler_name=='console':
|
||||
formatter = LOG_FORMATTERS[config.log_format]
|
||||
ch = logging.StreamHandler()
|
||||
@ -210,14 +209,16 @@ class InvokeAILogger(object):
|
||||
|
||||
elif handler_name=='syslog':
|
||||
ch = cls._parse_syslog_args(args)
|
||||
ch.setFormatter(InvokeAISyslogFormatter())
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name=='file':
|
||||
handlers.append(cls._parse_file_args(args))
|
||||
ch = cls._parse_file_args(args)
|
||||
ch.setFormatter(InvokeAISyslogFormatter())
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name=='http':
|
||||
handlers.append(cls._parse_http_args(args))
|
||||
ch = cls._parse_http_args(args)
|
||||
handlers.append(ch)
|
||||
return handlers
|
||||
|
||||
@staticmethod
|
||||
|
Reference in New Issue
Block a user