Prevent crash when trying to calculate size of missing safety_checker

- Also fixed up order in which logger is created in invokeai-web
  so that handlers are installed after command-line options are
  parsed (and not before!)
This commit is contained in:
Lincoln Stein 2023-06-06 22:57:49 -04:00
parent 1f9e1eb964
commit 04f9757f8d
7 changed files with 33 additions and 30 deletions

View File

@ -3,7 +3,7 @@ import asyncio
from inspect import signature from inspect import signature
import uvicorn import uvicorn
from invokeai.backend.util.logging import InvokeAILogger
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@ -13,13 +13,18 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema from pydantic.schema import schema
# Do this early so that other modules pick up configuration
from .services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger()
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images from .api.routers import sessions, models, images
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
# Create the app # Create the app
# TODO: create this all in a method so configuration/etc. can be passed in? # TODO: create this all in a method so configuration/etc. can be passed in?
@ -37,11 +42,6 @@ app.add_middleware(
socket_io = SocketIO(app) socket_io = SocketIO(app)
# initialize config
# this is a module global
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
# Add startup event to load dependencies # Add startup event to load dependencies
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():

View File

@ -11,7 +11,7 @@ from typing import Union, get_type_hints
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
import invokeai.backend.util.logging as logger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
@ -22,7 +22,6 @@ from .cli.commands import (BaseCommand, CliContext, ExitCli,
SortedHelpFormatter, add_graph_parsers, add_parsers) SortedHelpFormatter, add_graph_parsers, add_parsers)
from .cli.completer import set_autocompleter from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.config import get_invokeai_config
from .services.default_graphs import (create_system_graphs, from .services.default_graphs import (create_system_graphs,
default_text_to_image_graph_id) default_text_to_image_graph_id)
from .services.events import EventServiceBase from .services.events import EventServiceBase
@ -192,14 +191,11 @@ def invoke_all(context: CliContext):
raise SessionError() raise SessionError()
logger = logger.InvokeAILogger.getLogger()
def invoke_cli(): def invoke_cli():
# this gets the basic configuration # this gets the basic configuration
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
config.parse_args() config.parse_args()
logger = InvokeAILogger.getLogger()
# get the optional list of invocations to execute on the command line # get the optional list of invocations to execute on the command line
parser = config.get_parser() parser = config.get_parser()

View File

@ -513,7 +513,7 @@ class LatentsToImageInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_type=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -16,18 +16,19 @@ context. Use like this:
""" """
import contextlib
import gc import gc
import os import os
import sys import sys
import hashlib import hashlib
import json
import warnings import warnings
from contextlib import suppress from contextlib import suppress
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Dict, Sequence, Union, types, Optional, List, Type, Any from typing import Dict, Union, types, Optional, List, Type, Any
import torch import torch
import transformers
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
from diffusers import logging as diffusers_logging from diffusers import logging as diffusers_logging
@ -63,6 +64,11 @@ def calc_model_size_by_fs(
if subfolder is not None: if subfolder is not None:
model_path = os.path.join(model_path, subfolder) model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path) all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))] all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
@ -88,7 +94,7 @@ def calc_model_size_by_fs(
if not file.endswith(index_postfix): if not file.endswith(index_postfix):
continue continue
try: try:
with open(os.path.join(model_path, index_file), "r") as f: with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read()) index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"]) return int(index_data["metadata"]["total_size"])
except: except:
@ -277,7 +283,7 @@ class ClassifierModelInfo(ModelInfoBase):
self.child_sizes: Dict[str, int] = dict() self.child_sizes: Dict[str, int] = dict()
try: try:
main_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json") main_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="config.json")
#main_config = json.loads(os.path.join(self.model_path, "config.json")) #main_config = json.loads(os.path.join(self.model_path, "config.json"))
except: except:
raise Exception("Invalid classifier model! (config.json not found or invalid)") raise Exception("Invalid classifier model! (config.json not found or invalid)")
@ -289,7 +295,7 @@ class ClassifierModelInfo(ModelInfoBase):
def _load_tokenizer(self, main_config: dict): def _load_tokenizer(self, main_config: dict):
try: try:
tokenizer_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="tokenizer_config.json") tokenizer_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="tokenizer_config.json")
#tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json")) #tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json"))
except: except:
raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)") raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)")
@ -314,13 +320,13 @@ class ClassifierModelInfo(ModelInfoBase):
raise Exception("Invalid classifier model! (Failed to detect text_encoder type)") raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
self.child_types[SDModelType.TextEncoder] = self._definition_to_type(["transformers", text_encoder_class_name]) self.child_types[SDModelType.TextEncoder] = self._definition_to_type(["transformers", text_encoder_class_name])
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(repo_id_or_path) self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.repo_id_or_path)
def _load_feature_extractor(self, main_config: dict): def _load_feature_extractor(self, main_config: dict):
self.child_sizes[SDModelType.FeatureExtractor] = 0 self.child_sizes[SDModelType.FeatureExtractor] = 0
try: try:
feature_extractor_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="preprocessor_config.json") feature_extractor_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="preprocessor_config.json")
except: except:
return # feature extractor not passed with t5 return # feature extractor not passed with t5

View File

@ -195,13 +195,12 @@ class InvokeAILogger(object):
@classmethod @classmethod
def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]: def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
handler_strs = config.log_handlers handler_strs = config.log_handlers
print(f'handler_strs={handler_strs}')
handlers = list() handlers = list()
for handler in handler_strs: for handler in handler_strs:
handler_name,*args = handler.split('=',2) handler_name,*args = handler.split('=',2)
args = args[0] if len(args) > 0 else None args = args[0] if len(args) > 0 else None
# console is the only handler that gets a custom formatter # console and file are the only handlers that get a custom formatter
if handler_name=='console': if handler_name=='console':
formatter = LOG_FORMATTERS[config.log_format] formatter = LOG_FORMATTERS[config.log_format]
ch = logging.StreamHandler() ch = logging.StreamHandler()
@ -210,14 +209,16 @@ class InvokeAILogger(object):
elif handler_name=='syslog': elif handler_name=='syslog':
ch = cls._parse_syslog_args(args) ch = cls._parse_syslog_args(args)
ch.setFormatter(InvokeAISyslogFormatter())
handlers.append(ch) handlers.append(ch)
elif handler_name=='file': elif handler_name=='file':
handlers.append(cls._parse_file_args(args)) ch = cls._parse_file_args(args)
ch.setFormatter(InvokeAISyslogFormatter())
handlers.append(ch)
elif handler_name=='http': elif handler_name=='http':
handlers.append(cls._parse_http_args(args)) ch = cls._parse_http_args(args)
handlers.append(ch)
return handlers return handlers
@staticmethod @staticmethod

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-251c2c6e.js"></script> <script type="module" crossorigin src="./assets/index-88e8dffe.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

File diff suppressed because one or more lines are too long