From 86d536755d48ca8a5d7a887085950a9ed59750ea Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Mon, 25 Mar 2024 18:52:24 -0400
Subject: [PATCH] Check for cuDNN version compatibility issues on startup.

Prior to this check, the app would silently run with ~50% performance degradation caused by a cuDNN version mismatch.
---
 invokeai/app/api_app.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py
index 333b9a58c0..87eaefc020 100644
--- a/invokeai/app/api_app.py
+++ b/invokeai/app/api_app.py
@@ -1,4 +1,5 @@
 import asyncio
+import logging
 import mimetypes
 import socket
 from contextlib import asynccontextmanager
@@ -6,6 +7,7 @@ from inspect import signature
 from pathlib import Path
 from typing import Any
 
+import torch
 import uvicorn
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
@@ -226,6 +228,22 @@ app.mount(
 )  # docs favicon is in here
 
 
+def check_cudnn(logger: logging.Logger) -> None:
+    """Check for cuDNN issues that could degrade performance: log the version, or warn if querying it fails."""
+    if torch.backends.cudnn.is_available():  # nothing is logged when torch reports cuDNN as unavailable
+        try:
+            # Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
+            # time it is called. Subsequent calls will return the version number without complaining about a mismatch.
+            cudnn_version = torch.backends.cudnn.version()
+            logger.info(f"cuDNN version: {cudnn_version}")
+        except RuntimeError as e:  # reported as a warning rather than re-raised, so the caller keeps running
+            logger.warning(
+                "Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
+                "caused by an incompatible cuDNN version installed in your python environment, or on the host "
+                f"system. Full error message:\n{e}"
+            )
+
+
 def invoke_api() -> None:
     def find_port(port: int) -> int:
         """Find a port not in use starting at given port"""
@@ -252,6 +270,8 @@ def invoke_api() -> None:
     if port != app_config.port:
         logger.warn(f"Port {app_config.port} in use, using port {port}")
 
+    check_cudnn(logger)
+
     # Start our own event loop for eventing usage
     loop = asyncio.new_event_loop()
     config = uvicorn.Config(