mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Check for cuDNN version compatibility issues on startup. Prior to this check, the app would silently run with ~50% performance degradation caused by a cuDNN version mismatch.
This commit is contained in:
parent
e46c22e41a
commit
86d536755d
@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import socket
|
import socket
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@ -6,6 +7,7 @@ from inspect import signature
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@ -226,6 +228,22 @@ app.mount(
|
|||||||
) # docs favicon is in here
|
) # docs favicon is in here
|
||||||
|
|
||||||
|
|
||||||
|
def check_cudnn(logger: logging.Logger) -> None:
|
||||||
|
"""Check for cuDNN issues that could be causing degraded performance."""
|
||||||
|
if torch.backends.cudnn.is_available():
|
||||||
|
try:
|
||||||
|
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
|
||||||
|
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
|
||||||
|
cudnn_version = torch.backends.cudnn.version()
|
||||||
|
logger.info(f"cuDNN version: {cudnn_version}")
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(
|
||||||
|
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
|
||||||
|
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
|
||||||
|
f"system. Full error message:\n{e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def invoke_api() -> None:
|
def invoke_api() -> None:
|
||||||
def find_port(port: int) -> int:
|
def find_port(port: int) -> int:
|
||||||
"""Find a port not in use starting at given port"""
|
"""Find a port not in use starting at given port"""
|
||||||
@ -252,6 +270,8 @@ def invoke_api() -> None:
|
|||||||
if port != app_config.port:
|
if port != app_config.port:
|
||||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||||
|
|
||||||
|
check_cudnn(logger)
|
||||||
|
|
||||||
# Start our own event loop for eventing usage
|
# Start our own event loop for eventing usage
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
config = uvicorn.Config(
|
config = uvicorn.Config(
|
||||||
|
Loading…
Reference in New Issue
Block a user