From 6e3f60341313784505ee7d067c7811fadf7164f1 Mon Sep 17 00:00:00 2001
From: Oliver Walters <oliver.henry.walters@gmail.com>
Date: Sat, 18 Apr 2020 12:06:50 +1000
Subject: [PATCH] Improvements for token authentication

- Do not POST username and password data
- This is already encoded in the Authorization: Basic headers
---
 InvenTree/InvenTree/test_api.py | 84 +++++++++++++++------------------
 InvenTree/users/views.py        | 32 +++++++------
 2 files changed, 56 insertions(+), 60 deletions(-)

diff --git a/InvenTree/InvenTree/test_api.py b/InvenTree/InvenTree/test_api.py
index e75a111e5d..7068593e6b 100644
--- a/InvenTree/InvenTree/test_api.py
+++ b/InvenTree/InvenTree/test_api.py
@@ -4,9 +4,10 @@ from rest_framework.test import APITestCase
 from rest_framework import status
 
 from django.urls import reverse
-
 from django.contrib.auth import get_user_model
 
+from base64 import b64encode
+
 
 class APITests(APITestCase):
     """ Tests for the InvenTree API """
@@ -21,24 +22,48 @@ class APITests(APITestCase):
     username = 'test_user'
     password = 'test_pass'
 
+    token = None
+
     def setUp(self):
 
         # Create a user (but do not log in!)
         User = get_user_model()
         User.objects.create_user(self.username, 'user@email.com', self.password)
 
-    def get_token(self):
+    def basicAuth(self):
+        # Use basic authentication
+
+        authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii")
+
+        # Use "basic" auth by default
+        auth = b64encode(authstring).decode("ascii")
+        self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth))
+
+    def tokenAuth(self):
+
+        self.basicAuth()
         token_url = reverse('api-token')
+        response = self.client.get(token_url, format='json', data={})
+
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertIn('token', response.data)
 
-        # POST to retreive a token
-        response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
-        
         token = response.data['token']
-
-        self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
-
         self.token = token
 
+    def token_failure(self):
+        # Test token endpoint without basic auth
+        url = reverse('api-token')
+        response = self.client.get(url, format='json')
+
+        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+        self.assertIsNone(self.token)
+
+    def token_success(self):
+
+        self.tokenAuth()
+        self.assertIsNotNone(self.token)
+
     def test_info_view(self):
         """
         Test that we can read the 'info-view' endpoint.
@@ -55,51 +80,18 @@ class APITests(APITestCase):
 
         self.assertEquals('InvenTree', data['server'])
 
-    def test_get_token_fail(self):
-        """ Ensure that an invalid user cannot get a token """
-
-        token_url = reverse('api-token')
-
-        response = self.client.post(token_url, format='json', data={'username': 'bad', 'password': 'also_bad'})
-
-        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertFalse('token' in response.data)
-
-    def test_get_token_pass(self):
-        """ Ensure that a valid user can request an API token """
-
-        token_url = reverse('api-token')
-
-        # POST to retreive a token
-        response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
-
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertTrue('token' in response.data)
-        self.assertTrue('pk' in response.data)
-        self.assertTrue(len(response.data['token']) > 0)
-
-        # Now, use the token to access other data
-        token = response.data['token']
-
-        part_url = reverse('api-part-list')
-
-        # Try to access without a token
-        response = self.client.get(part_url, format='json')
+    def test_barcode_fail(self):
+        # Test barcode endpoint without auth
+        response = self.client.post(reverse('api-barcode-plugin'), format='json')
 
         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-        
-        # Now, with the token
-        self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
-        response = self.client.get(part_url, format='json')
-
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
 
     def test_barcode(self):
         """ Test the barcode endpoint """
 
-        url = reverse('api-barcode-plugin')
+        self.tokenAuth()
 
-        self.get_token()
+        url = reverse('api-barcode-plugin')
 
         data = {
             'barcode': {
diff --git a/InvenTree/users/views.py b/InvenTree/users/views.py
index 01d9834c69..97e5f48355 100644
--- a/InvenTree/users/views.py
+++ b/InvenTree/users/views.py
@@ -3,7 +3,7 @@ from django.contrib.auth.models import User
 from django.core.exceptions import ObjectDoesNotExist
 from .serializers import UserSerializer
 
-from rest_framework.authtoken.views import ObtainAuthToken
+from rest_framework.views import APIView
 from rest_framework.authtoken.models import Token
 from rest_framework.response import Response
 from rest_framework import status
@@ -25,28 +25,32 @@ class UserList(generics.ListAPIView):
     permission_classes = (permissions.IsAuthenticated,)
 
 
-class GetAuthToken(ObtainAuthToken):
+class GetAuthToken(APIView):
     """ Return authentication token for an authenticated user. """
 
-    def post(self, request, *args, **kwargs):
+    permission_classes = [
+        permissions.IsAuthenticated,
+    ]
+
+    def get(self, request, *args, **kwargs):
         return self.login(request)
 
     def delete(self, request):
         return self.logout(request)
 
     def login(self, request):
-        serializer = self.serializer_class(data=request.data,
-                                           context={'request': request})
-        serializer.is_valid(raise_exception=True)
-        user = serializer.validated_data['user']
-        token, created = Token.objects.get_or_create(user=user)
 
-        return Response({
-            'token': token.key,
-            'pk': user.pk,
-            'username': user.username,
-            'email': user.email
-        })
+        if request.user.is_authenticated:
+            # Get the user token (or create one if it does not exist)
+            token, created = Token.objects.get_or_create(user=request.user)
+            return Response({
+                'token': token.key,
+            })
+
+        else:
+            return Response({
+                'error': 'User not authenticated',
+            })
 
     def logout(self, request):
         try: