Check for cuDNN version compatibility issues on startup. Prior to this check, the app would silently run with ~50% performance degradation caused by a cuDNN version mismatch.

This commit is contained in:
Ryan Dick 2024-03-25 18:52:24 -04:00 committed by psychedelicious
parent e46c22e41a
commit 86d536755d

View File

@ -1,4 +1,5 @@
import asyncio
import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
@ -6,6 +7,7 @@ from inspect import signature
from pathlib import Path
from typing import Any
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -226,6 +228,22 @@ app.mount(
) # docs favicon is in here
def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)
def invoke_api() -> None:
def find_port(port: int) -> int:
"""Find a port not in use starting at given port"""
@ -252,6 +270,8 @@ def invoke_api() -> None:
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")
check_cudnn(logger)
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config(