Check for cuDNN version compatibility issues on startup. Prior to this check, the app would silently run with ~50% performance degradation caused by a cuDNN version mismatch.

This commit is contained in:
Ryan Dick 2024-03-25 18:52:24 -04:00 committed by psychedelicious
parent e46c22e41a
commit 86d536755d

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import logging
import mimetypes import mimetypes
import socket import socket
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -6,6 +7,7 @@ from inspect import signature
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import torch
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -226,6 +228,22 @@ app.mount(
) # docs favicon is in here ) # docs favicon is in here
def check_cudnn(logger: logging.Logger) -> None:
    """Probe cuDNN once and report the outcome through ``logger``.

    When cuDNN is reported as available, its version is queried. A successful
    query is logged at INFO level; a ``RuntimeError`` raised by the query is
    logged at WARNING level as a likely cause of degraded performance. Nothing
    is logged when cuDNN is not available.

    Args:
        logger: Logger that receives the version or the warning message.
    """
    # Guard clause: nothing to verify when cuDNN is not available.
    if not torch.backends.cudnn.is_available():
        return

    try:
        # Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the
        # first time it is called. Subsequent calls return the version number without complaining about a
        # mismatch, so this first call is the one chance to surface the problem.
        version = torch.backends.cudnn.version()
        logger.info(f"cuDNN version: {version}")
    except RuntimeError as err:
        message = (
            "Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
            "caused by an incompatible cuDNN version installed in your python environment, or on the host "
            f"system. Full error message:\n{err}"
        )
        logger.warning(message)
def invoke_api() -> None: def invoke_api() -> None:
def find_port(port: int) -> int: def find_port(port: int) -> int:
"""Find a port not in use starting at given port""" """Find a port not in use starting at given port"""
@ -252,6 +270,8 @@ def invoke_api() -> None:
if port != app_config.port: if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}") logger.warn(f"Port {app_config.port} in use, using port {port}")
check_cudnn(logger)
# Start our own event loop for eventing usage # Start our own event loop for eventing usage
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
config = uvicorn.Config( config = uvicorn.Config(