Skip to content

Commit

Permalink
Common folder (#96)
Browse files Browse the repository at this point in the history
* Revamped JWTManager

* adapted User_management code to new implementation of the common JWT manager

* Added the common folder to UM docker

* Added the common folder to user management in github actions

---------

Co-authored-by: aLeuleu <[email protected]>
  • Loading branch information
V-Fries and a-levra authored Jan 26, 2024
1 parent bbe3adc commit 5bf9c4c
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 91 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/user_management.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ jobs:
run: |
isort . --check-only
- name: Copy common
working-directory: user_management/src/
run: |
cp -r ../../common .
- name: Run Django migrations
working-directory: user_management/src/
env:
Expand Down
Empty file added common/src/__init__.py
Empty file.
65 changes: 65 additions & 0 deletions common/src/jwt_managers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from datetime import datetime, timedelta, timezone

import jwt

import common.src.settings as settings


class JWTManager:
def __init__(self,
private_key: str | None,
public_key: str | None,
algorithm: str,
expiration_time_minutes: int | None):
self.private_key = private_key
self.public_key = public_key
self.algorithm = algorithm
self.expiration_time_minutes = expiration_time_minutes

def generate_jwt(self, payload_arg: dict) -> (bool, str | None, list[str] | None):
""" returns: Success, jwt, [error messages] """

now = datetime.now(timezone.utc)
expiration_time_minutes = now + timedelta(minutes=self.expiration_time_minutes)

payload = {'exp': expiration_time_minutes}
for key, value in payload_arg.items():
payload[key] = value

try:
token = jwt.encode(payload, self.private_key, algorithm=self.algorithm)
except Exception as e:
return False, None, [str(e)]
return True, token, None

def decode_jwt(self, encoded_jwt: str) -> (bool, dict | None, list[str] | None):
""" returns: Success, payload, error message """

try:
decoded_payload = jwt.decode(encoded_jwt, self.public_key, algorithms=[self.algorithm])
if decoded_payload.get('exp') is None:
return False, None, ["No expiration date found"]
return True, decoded_payload, None
except Exception as e:
return False, None, [str(e)]


class UserAccessJWTDecoder:
JWT_MANAGER = JWTManager(None,
settings.ACCESS_PUBLIC_KEY,
settings.ACCESS_ALGORITHM,
None)

@staticmethod
def authenticate(encoded_jwt: str) -> (bool, dict, list[str] | None):
""" returns: Success, payload, error message """

success, decoded_payload, error_decode = UserAccessJWTDecoder.JWT_MANAGER.decode_jwt(encoded_jwt)
if not success:
return False, None, error_decode

user_id = decoded_payload.get('user_id')
if user_id is None or user_id == '':
return False, None, ['No user_id in payload']

return True, decoded_payload, None
10 changes: 10 additions & 0 deletions common/src/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
ACCESS_PUBLIC_KEY = """-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAl88VVar5X6lAlHjj4o4r
r3WoAQloSNbxjgyUd6dU3z3a8JbLibihyl/LjrfAJXCT39FzBbjcWHw7dnDkBeU0
xX8pPNESkfJI7wxzkc1WcPk1KMwvy1dTaoCub7fZxNl2oOObdzTGpic8co7VOUqa
5cJks3MTL/8ipxaf4HVJ4luvcySvPflL1woWO3QfTomL/B/Xnu9fmj2ynn8DptfY
wJEe4eFA/jx+TP3coPBgs/XYG3stdyislm574U+5QvfRi1uii8jkFgpIxwUnxYbx
mZW+X8IdGmaUnucNeF1pLZjEIcr7MkzP3zm1auQww71DObGTPaLLJNjTPdP3rWYJ
mQIDAQAB
-----END PUBLIC KEY-----"""
ACCESS_ALGORITHM = 'RS256'
8 changes: 8 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,11 @@ volumes:
type: none
o: bind
device: pong_server/src/

# Common code
common_code:
driver: local
driver_opts:
type: none
o: bind
device: common
2 changes: 1 addition & 1 deletion user_management/doc/User_management.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ all fields are mandatory
> errors can be :
> - Refresh token not found
> - Signature verification failed
> - Empty payload
> - No expiration date found
> - Signature has expired
> - No user_id in payload
> - User does not exist
Expand Down
7 changes: 5 additions & 2 deletions user_management/docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@ services:
build:
context: .
dockerfile: Dockerfile
command: sh -c "python3 src/manage.py makemigrations &&
command: sh -c "cp -r /app/user_management /app/src &&
cp -r /app/common /app/src/common &&
python3 src/manage.py makemigrations &&
python3 src/manage.py migrate &&
gunicorn --chdir src/ user_management.wsgi:application -w 4 -b 0.0.0.0:8000"
volumes:
- user_management_code:/app/src/
- user_management_code:/app/user_management
- common_code:/app/common
expose:
- "8000"
depends_on:
Expand Down
22 changes: 11 additions & 11 deletions user_management/src/user/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from user.models import User
from user_management import settings
from user_management.JWTManager import JWTManager
from user_management.JWTManager import UserAccessJWTManager


class TestsSignup(TestCase):
Expand Down Expand Up @@ -256,10 +256,10 @@ def test_refresh_jwt(self):
}
url = reverse('signup')
print('Creating user...')
refresh_token = self.client.post(url, json.dumps(data_preparation), content_type='application/json')
result = self.client.post(url, json.dumps(data_preparation), content_type='application/json')
print('User created')
data = {
'refresh_token': refresh_token.json()['refresh_token']
'refresh_token': result.json()['refresh_token']
}
url = reverse('refresh-access-jwt')
print('Testing valid refresh token')
Expand All @@ -268,7 +268,7 @@ def test_refresh_jwt(self):
self.assertTrue('access_token' in result.json())
print('Testing invalids refresh tokens ... :')
# 1 Refresh token not found
valid_access_token = JWTManager('access').generate_token(1)[1]
valid_access_token = UserAccessJWTManager.generate_jwt(1)[1]

# 2 Invalid token
valid_payload = {
Expand All @@ -277,38 +277,38 @@ def test_refresh_jwt(self):
'token_type': 'refresh'
}
bad_signature_token = jwt.encode(valid_payload,
open(f'{settings.BASE_DIR}/user/test_resources/invalid_key.pub').read(),
'RS256')
"INVALID_KEY",
'HS256')
# 3 Empty payload
payload = {}
empty_payload = jwt.encode(payload, settings.REFRESH_PRIVATE_KEY, 'RS256')
empty_payload = jwt.encode(payload, settings.REFRESH_KEY, 'HS256')

# 4 Token expired
payload_expired = {
'user_id': 1,
'exp': datetime.utcnow(),
'token_type': 'refresh'
}
expired_token = jwt.encode(payload_expired, settings.REFRESH_PRIVATE_KEY, 'RS256')
expired_token = jwt.encode(payload_expired, settings.REFRESH_KEY, 'HS256')

# 5 No user_id in payload
payload_no_user_id = {
'exp': datetime.utcnow() + timedelta(minutes=100),
'token_type': 'refresh'
}
token_no_user_id = jwt.encode(payload_no_user_id, settings.REFRESH_PRIVATE_KEY, 'RS256')
token_no_user_id = jwt.encode(payload_no_user_id, settings.REFRESH_KEY, 'HS256')

# 6 User does not exist
payload_user_not_exist = {
'user_id': 999,
'exp': datetime.utcnow() + timedelta(minutes=100),
'token_type': 'refresh'
}
token_user_not_exist = jwt.encode(payload_user_not_exist, settings.REFRESH_PRIVATE_KEY, 'RS256')
token_user_not_exist = jwt.encode(payload_user_not_exist, settings.REFRESH_KEY, 'HS256')

errors = [('Refresh token not found', {'access_token': valid_access_token}),
('Signature verification failed', {'refresh_token': bad_signature_token}),
('Empty payload', {'refresh_token': empty_payload}),
('No expiration date found', {'refresh_token': empty_payload}),
('Signature has expired', {'refresh_token': expired_token}),
('No user_id in payload', {'refresh_token': token_no_user_id}),
('User does not exist', {'refresh_token': token_user_not_exist})
Expand Down
4 changes: 2 additions & 2 deletions user_management/src/user/views/OAuth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from user.models import PendingOAuth, User
from user_management import settings
from user_management.JWTManager import JWTManager
from user_management.JWTManager import UserRefreshJWTManager
from user_management.utils import (download_image_from_url,
generate_random_string)

Expand Down Expand Up @@ -112,7 +112,7 @@ def get(self, request, auth_service):
if not user:
return JsonResponse(data={'errors': ['Failed to create or get user']}, status=400)

success, refresh_token, errors = JWTManager('refresh').generate_token(user.id)
success, refresh_token, errors = UserRefreshJWTManager.generate_jwt(user.id)
if not success:
return JsonResponse(data={'errors': errors}, status=400)

Expand Down
11 changes: 6 additions & 5 deletions user_management/src/user/views/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from django.views.decorators.csrf import csrf_exempt

from user.models import User
from user_management.JWTManager import JWTManager
from user_management.JWTManager import (UserAccessJWTManager,
UserRefreshJWTManager)


@method_decorator(csrf_exempt, name='dispatch')
Expand All @@ -27,7 +28,7 @@ def post(self, request):
user = User.objects.create(username=json_request['username'],
email=json_request['email'],
password=json_request['password'])
success, refresh_token, errors = JWTManager('refresh').generate_token(user.id)
success, refresh_token, errors = UserRefreshJWTManager.generate_jwt(user.id)
if success is False:
return JsonResponse(data={'errors': errors}, status=400)
return JsonResponse(data={'refresh_token': refresh_token}, status=201)
Expand Down Expand Up @@ -126,7 +127,7 @@ def post(self, request):
if validation_errors:
return JsonResponse(data={'errors': validation_errors}, status=400)
user = User.objects.filter(username=json_request['username']).first()
success, refresh_token, errors = JWTManager('refresh').generate_token(user.id)
success, refresh_token, errors = UserRefreshJWTManager.generate_jwt(user.id)
if success is False:
return JsonResponse(data={'errors': errors}, status=400)
return JsonResponse(data={'refresh_token': refresh_token}, status=200)
Expand Down Expand Up @@ -204,10 +205,10 @@ def post(request):
refresh_token = json_request.get('refresh_token')
if refresh_token is None:
return JsonResponse(data={'errors': ['Refresh token not found']}, status=400)
success, errors, user_id = JWTManager('refresh').is_authentic_and_valid_request(refresh_token)
success, user_id, errors = UserRefreshJWTManager.authenticate(refresh_token)
if success is False:
return JsonResponse(data={'errors': errors}, status=400)
success, access_token, errors = JWTManager('access').generate_token(user_id)
success, access_token, errors = UserAccessJWTManager.generate_jwt(user_id)
if success is False:
return JsonResponse(data={'errors': errors}, status=400)
return JsonResponse(data={'access_token': access_token}, status=200)
Expand Down
111 changes: 55 additions & 56 deletions user_management/src/user_management/JWTManager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from datetime import datetime, timedelta

import jwt
import common.src.settings as common_settings
from common.src.jwt_managers import JWTManager, UserAccessJWTDecoder
from django.conf import settings

from user.models import User
Expand All @@ -15,61 +14,61 @@ def user_exist(user_id: int) -> bool:
return User.objects.filter(id=user_id).exists()


class JWTManager:
def __init__(self, token_type: str, payload: dict = None):
if token_type == 'access':
self.private_key = settings.ACCESS_KEY
self.public_key = settings.ACCESS_KEY
self.algorithm = 'HS256'
self.expire_minutes_reference = settings.ACCESS_EXPIRATION_MINUTES
elif token_type == 'refresh':
self.private_key = settings.REFRESH_PRIVATE_KEY
self.public_key = settings.REFRESH_PUBLIC_KEY
self.algorithm = 'RS256'
self.expire_minutes_reference = settings.REFRESH_EXPIRATION_MINUTES
else:
raise Exception('Invalid token type')
self.token_type = token_type
self.payload = payload

def decode_jwt(self, encoded_jwt: str) -> (bool, dict, str):
try:
decoded_payload = jwt.decode(encoded_jwt, self.public_key, algorithms=[self.algorithm])
return True, decoded_payload, None
except Exception as e:
return False, None, str(e)

def generate_token(self, user_id: int) -> (bool, str, str):
class UserRefreshJWTManager:
JWT_MANAGER = JWTManager(settings.REFRESH_KEY,
settings.REFRESH_KEY,
'HS256',
settings.REFRESH_EXPIRATION_MINUTES)

@staticmethod
def generate_jwt(user_id: int) -> (bool, str | None, list[str] | None):
""" returns: Success, jwt, [error messages] """

if not user_exist(user_id):
return False, None, 'User does not exist'
try:
now = datetime.utcnow()
expiration_time = now + timedelta(minutes=self.expire_minutes_reference)
payload = {
'user_id': user_id,
'exp': expiration_time,
'token_type': self.token_type
}
token = jwt.encode(payload, self.private_key, algorithm=self.algorithm)
except Exception as e:
return False, None, str(e)
return True, token, None

def is_authentic_and_valid_request(self, encoded_jwt: str) -> (bool, list, int):
success, decoded_payload, error_decode = self.decode_jwt(encoded_jwt)
errors = []
return False, None, ['User does not exist']
return UserRefreshJWTManager.JWT_MANAGER.generate_jwt({'user_id': user_id, 'token_type': 'refresh'}) # Common

@staticmethod
def authenticate(encoded_jwt: str) -> (bool, int | None, list[str] | None):
""" returns: Success, user_id, [error messages] """

success, payload, error_list = UserRefreshJWTManager.JWT_MANAGER.decode_jwt(encoded_jwt) # Common
if not success:
errors.append(error_decode)
elif decoded_payload is None or decoded_payload == {}:
errors.append('Empty payload')
if errors:
return False, errors, None
return False, None, error_list

user_id = decoded_payload.get('user_id')
user_id = payload.get('user_id')
if user_id is None or user_id == '':
errors.append('No user_id in payload')
return False, None, ['No user_id in payload']
elif not user_exist(user_id):
errors.append('User does not exist')
if errors:
return False, errors, None
return True, None, user_id
return False, None, ['User does not exist']
return True, user_id, None


class UserAccessJWTManager:
# Never provide the public key as we must use common.UserAccessJWTDecoder to decode
JWT_MANAGER = JWTManager(settings.ACCESS_PRIVATE_KEY,
None,
common_settings.ACCESS_ALGORITHM,
settings.ACCESS_EXPIRATION_MINUTES)

@staticmethod
def generate_jwt(user_id: int) -> (bool, str | None, list[str] | None):
""" returns: Success, jwt, [error messages] """

if not user_exist(user_id):
return False, None, ['User does not exist']
return UserAccessJWTManager.JWT_MANAGER.generate_jwt({'user_id': user_id, 'token_type': 'access'}) # Common

@staticmethod
def authenticate(encoded_jwt: str) -> (bool, str | None, list[str]):
""" returns: Success, user_id, [error messages] """

success, payload, error_decode = UserAccessJWTDecoder.authenticate(encoded_jwt) # Common
if not success:
return False, None, error_decode

user_id = payload['user_id']
if not user_exist(user_id):
return False, None, ['User does not exist']

return True, user_id, None
Loading

0 comments on commit 5bf9c4c

Please sign in to comment.