Merge pull request #735 from SchrodingersGat/token-auth-fix

Improvements for token authentication
This commit is contained in:
Oliver 2020-04-18 12:12:54 +10:00 committed by GitHub
commit 4f3f78f55a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 60 deletions

View File

@ -4,9 +4,10 @@ from rest_framework.test import APITestCase
from rest_framework import status from rest_framework import status
from django.urls import reverse from django.urls import reverse
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from base64 import b64encode
class APITests(APITestCase): class APITests(APITestCase):
""" Tests for the InvenTree API """ """ Tests for the InvenTree API """
@ -21,24 +22,48 @@ class APITests(APITestCase):
username = 'test_user' username = 'test_user'
password = 'test_pass' password = 'test_pass'
token = None
def setUp(self): def setUp(self):
# Create a user (but do not log in!) # Create a user (but do not log in!)
User = get_user_model() User = get_user_model()
User.objects.create_user(self.username, 'user@email.com', self.password) User.objects.create_user(self.username, 'user@email.com', self.password)
def get_token(self): def basicAuth(self):
# Use basic authentication
authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii")
# Use "basic" auth by default
auth = b64encode(authstring).decode("ascii")
self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth))
def tokenAuth(self):
self.basicAuth()
token_url = reverse('api-token') token_url = reverse('api-token')
response = self.client.get(token_url, format='json', data={})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('token', response.data)
# POST to retreive a token
response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
token = response.data['token'] token = response.data['token']
self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
self.token = token self.token = token
def token_failure(self):
# Test token endpoint without basic auth
url = reverse('api-token')
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertIsNone(self.token)
def token_success(self):
self.tokenAuth()
self.assertIsNotNone(self.token)
def test_info_view(self): def test_info_view(self):
""" """
Test that we can read the 'info-view' endpoint. Test that we can read the 'info-view' endpoint.
@ -55,51 +80,18 @@ class APITests(APITestCase):
self.assertEquals('InvenTree', data['server']) self.assertEquals('InvenTree', data['server'])
def test_get_token_fail(self): def test_barcode_fail(self):
""" Ensure that an invalid user cannot get a token """ # Test barcode endpoint without auth
response = self.client.post(reverse('api-barcode-plugin'), format='json')
token_url = reverse('api-token')
response = self.client.post(token_url, format='json', data={'username': 'bad', 'password': 'also_bad'})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertFalse('token' in response.data)
def test_get_token_pass(self):
""" Ensure that a valid user can request an API token """
token_url = reverse('api-token')
# POST to retreive a token
response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue('token' in response.data)
self.assertTrue('pk' in response.data)
self.assertTrue(len(response.data['token']) > 0)
# Now, use the token to access other data
token = response.data['token']
part_url = reverse('api-part-list')
# Try to access without a token
response = self.client.get(part_url, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# Now, with the token
self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
response = self.client.get(part_url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_barcode(self): def test_barcode(self):
""" Test the barcode endpoint """ """ Test the barcode endpoint """
url = reverse('api-barcode-plugin') self.tokenAuth()
self.get_token() url = reverse('api-barcode-plugin')
data = { data = {
'barcode': { 'barcode': {

View File

@ -3,7 +3,7 @@ from django.contrib.auth.models import User
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from .serializers import UserSerializer from .serializers import UserSerializer
from rest_framework.authtoken.views import ObtainAuthToken from rest_framework.views import APIView
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework import status from rest_framework import status
@ -25,28 +25,32 @@ class UserList(generics.ListAPIView):
permission_classes = (permissions.IsAuthenticated,) permission_classes = (permissions.IsAuthenticated,)
class GetAuthToken(ObtainAuthToken): class GetAuthToken(APIView):
""" Return authentication token for an authenticated user. """ """ Return authentication token for an authenticated user. """
def post(self, request, *args, **kwargs): permission_classes = [
permissions.IsAuthenticated,
]
def get(self, request, *args, **kwargs):
return self.login(request) return self.login(request)
def delete(self, request): def delete(self, request):
return self.logout(request) return self.logout(request)
def login(self, request): def login(self, request):
serializer = self.serializer_class(data=request.data,
context={'request': request})
serializer.is_valid(raise_exception=True)
user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
return Response({ if request.user.is_authenticated:
'token': token.key, # Get the user token (or create one if it does not exist)
'pk': user.pk, token, created = Token.objects.get_or_create(user=request.user)
'username': user.username, return Response({
'email': user.email 'token': token.key,
}) })
else:
return Response({
'error': 'User not authenticated',
})
def logout(self, request): def logout(self, request):
try: try: