Merge pull request #735 from SchrodingersGat/token-auth-fix

Improvements for token authentication
This commit is contained in:
Oliver 2020-04-18 12:12:54 +10:00 committed by GitHub
commit 4f3f78f55a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 60 deletions

View File

@ -4,9 +4,10 @@ from rest_framework.test import APITestCase
from rest_framework import status
from django.urls import reverse
from django.contrib.auth import get_user_model
from base64 import b64encode
class APITests(APITestCase):
""" Tests for the InvenTree API """
@ -21,24 +22,48 @@ class APITests(APITestCase):
username = 'test_user'
password = 'test_pass'
token = None
def setUp(self):
# Create a user (but do not log in!)
User = get_user_model()
User.objects.create_user(self.username, 'user@email.com', self.password)
def get_token(self):
token_url = reverse('api-token')
def basicAuth(self):
# Use basic authentication
# POST to retreive a token
response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii")
# Use "basic" auth by default
auth = b64encode(authstring).decode("ascii")
self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth))
def tokenAuth(self):
self.basicAuth()
token_url = reverse('api-token')
response = self.client.get(token_url, format='json', data={})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('token', response.data)
token = response.data['token']
self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
self.token = token
def token_failure(self):
# Test token endpoint without basic auth
url = reverse('api-token')
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertIsNone(self.token)
def token_success(self):
self.tokenAuth()
self.assertIsNotNone(self.token)
def test_info_view(self):
"""
Test that we can read the 'info-view' endpoint.
@ -55,51 +80,18 @@ class APITests(APITestCase):
self.assertEquals('InvenTree', data['server'])
def test_get_token_fail(self):
""" Ensure that an invalid user cannot get a token """
token_url = reverse('api-token')
response = self.client.post(token_url, format='json', data={'username': 'bad', 'password': 'also_bad'})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertFalse('token' in response.data)
def test_get_token_pass(self):
""" Ensure that a valid user can request an API token """
token_url = reverse('api-token')
# POST to retreive a token
response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue('token' in response.data)
self.assertTrue('pk' in response.data)
self.assertTrue(len(response.data['token']) > 0)
# Now, use the token to access other data
token = response.data['token']
part_url = reverse('api-part-list')
# Try to access without a token
response = self.client.get(part_url, format='json')
def test_barcode_fail(self):
# Test barcode endpoint without auth
response = self.client.post(reverse('api-barcode-plugin'), format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# Now, with the token
self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
response = self.client.get(part_url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_barcode(self):
""" Test the barcode endpoint """
url = reverse('api-barcode-plugin')
self.tokenAuth()
self.get_token()
url = reverse('api-barcode-plugin')
data = {
'barcode': {

View File

@ -3,7 +3,7 @@ from django.contrib.auth.models import User
from django.core.exceptions import ObjectDoesNotExist
from .serializers import UserSerializer
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.views import APIView
from rest_framework.authtoken.models import Token
from rest_framework.response import Response
from rest_framework import status
@ -25,28 +25,32 @@ class UserList(generics.ListAPIView):
permission_classes = (permissions.IsAuthenticated,)
class GetAuthToken(ObtainAuthToken):
class GetAuthToken(APIView):
""" Return authentication token for an authenticated user. """
def post(self, request, *args, **kwargs):
permission_classes = [
permissions.IsAuthenticated,
]
def get(self, request, *args, **kwargs):
return self.login(request)
def delete(self, request):
return self.logout(request)
def login(self, request):
serializer = self.serializer_class(data=request.data,
context={'request': request})
serializer.is_valid(raise_exception=True)
user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
return Response({
'token': token.key,
'pk': user.pk,
'username': user.username,
'email': user.email
})
if request.user.is_authenticated:
# Get the user token (or create one if it does not exist)
token, created = Token.objects.get_or_create(user=request.user)
return Response({
'token': token.key,
})
else:
return Response({
'error': 'User not authenticated',
})
def logout(self, request):
try: