Unit tests for HOST settings (#6698)

- CORS
- ALLOWED_HOSTS
This commit is contained in:
Oliver 2024-03-13 20:37:05 +11:00 committed by GitHub
parent 128fa95f71
commit 2fef34852c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -40,6 +40,147 @@ from .tasks import offload_task
from .validators import validate_overage
class HostTest(InvenTreeTestCase):
"""Test for host configuration."""
@override_settings(ALLOWED_HOSTS=['testserver'])
def test_allowed_hosts(self):
"""Test that the ALLOWED_HOSTS functions as expected."""
self.assertIn('testserver', settings.ALLOWED_HOSTS)
response = self.client.get('/api/', headers={'host': 'testserver'})
self.assertEqual(response.status_code, 200)
response = self.client.get('/api/', headers={'host': 'invalidserver'})
self.assertEqual(response.status_code, 400)
@override_settings(ALLOWED_HOSTS=['invalidserver.co.uk'])
def test_allowed_hosts_2(self):
"""Another test for ALLOWED_HOSTS functionality."""
response = self.client.get('/api/', headers={'host': 'invalidserver.co.uk'})
self.assertEqual(response.status_code, 200)
class CorsTest(TestCase):
"""Unit tests for CORS functionality."""
def cors_headers(self):
"""Return a list of CORS headers."""
return [
'access-control-allow-origin',
'access-control-allow-credentials',
'access-control-allow-methods',
'access-control-allow-headers',
]
def preflight(self, url, origin, method='GET'):
"""Make a CORS preflight request to the specified URL."""
headers = {'origin': origin, 'access-control-request-method': method}
return self.client.options(url, headers=headers)
def test_no_origin(self):
"""Test that CORS headers are not included for regular requests.
- We use the /api/ endpoint for this test (it does not require auth)
- By default, in debug mode *all* CORS origins are allowed
"""
# Perform an initial response without the "origin" header
response = self.client.get('/api/')
self.assertEqual(response.status_code, 200)
for header in self.cors_headers():
self.assertNotIn(header, response.headers)
# Now, perform a "preflight" request with the "origin" header
response = self.preflight('/api/', origin='http://random-external-server.com')
self.assertEqual(response.status_code, 200)
for header in self.cors_headers():
self.assertIn(header, response.headers)
self.assertEqual(response.headers['content-length'], '0')
self.assertEqual(
response.headers['access-control-allow-origin'],
'http://random-external-server.com',
)
@override_settings(
CORS_ALLOW_ALL_ORIGINS=False,
CORS_ALLOWED_ORIGINS=['http://my-external-server.com'],
CORS_ALLOWED_ORIGIN_REGEXES=[],
)
def test_auth_view(self):
"""Test that CORS requests work for the /auth/ view.
Here, we are not authorized by default,
but the CORS headers should still be included.
"""
url = '/auth/'
# First, a preflight request with a "valid" origin
response = self.preflight(url, origin='http://my-external-server.com')
self.assertEqual(response.status_code, 200)
for header in self.cors_headers():
self.assertIn(header, response.headers)
# Next, a preflight request with an "invalid" origin
response = self.preflight(url, origin='http://random-external-server.com')
self.assertEqual(response.status_code, 200)
for header in self.cors_headers():
self.assertNotIn(header, response.headers)
# Next, make a GET request (without a token)
response = self.client.get(
url, headers={'origin': 'http://my-external-server.com'}
)
# Unauthorized
self.assertEqual(response.status_code, 401)
self.assertIn('access-control-allow-origin', response.headers)
self.assertNotIn('access-control-allow-methods', response.headers)
@override_settings(
CORS_ALLOW_ALL_ORIGINS=False,
CORS_ALLOWED_ORIGINS=[],
CORS_ALLOWED_ORIGIN_REGEXES=['http://.*myserver.com'],
)
def test_cors_regex(self):
"""Test that CORS regexes work as expected."""
valid_urls = [
'http://www.myserver.com',
'http://test.myserver.com',
'http://myserver.com',
'http://www.myserver.com:8080',
]
invalid_urls = [
'http://myserver.org',
'http://www.other-server.org',
'http://google.com',
'http://myserver.co.uk:8080',
]
for url in valid_urls:
response = self.preflight('/api/', origin=url)
self.assertEqual(response.status_code, 200)
self.assertIn('access-control-allow-origin', response.headers)
for url in invalid_urls:
response = self.preflight('/api/', origin=url)
self.assertEqual(response.status_code, 200)
self.assertNotIn('access-control-allow-origin', response.headers)
class ConversionTest(TestCase):
"""Tests for conversion of physical units."""
@ -912,6 +1053,7 @@ class TestVersionNumber(TestCase):
hash = str(
subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8'
).strip()
self.assertEqual(hash, version.inventreeCommitHash())
d = (