Fix login, add logout, add refresh for tokens and tests for all

This commit is contained in:
2025-03-05 18:29:57 +02:00
parent 6630fb577b
commit d6211d12b8
8 changed files with 253 additions and 43 deletions
+3 -1
View File
@@ -1,3 +1,4 @@
from fastapi.security import OAuth2PasswordBearer
from pydantic_settings import BaseSettings, SettingsConfigDict # type: ignore from pydantic_settings import BaseSettings, SettingsConfigDict # type: ignore
from passlib.context import CryptContext # type: ignore from passlib.context import CryptContext # type: ignore
import pytz import pytz
@@ -6,6 +7,7 @@ class Settings(BaseSettings):
PROJECT_NAME: str = "StoneEdge Asset Management System" PROJECT_NAME: str = "StoneEdge Asset Management System"
PROJECT_VERSION: str = "0.0.1" PROJECT_VERSION: str = "0.0.1"
PROJECT_SUMMARY: str = "Product API for StoneEdge." PROJECT_SUMMARY: str = "Product API for StoneEdge."
PROJECT_PUBLIC_URL: str = ""
SECRET_KEY: str | None = None SECRET_KEY: str | None = None
PSQL_USERNAME: str = "user" PSQL_USERNAME: str = "user"
PSQL_PASSWORD: str = "password" PSQL_PASSWORD: str = "password"
@@ -16,8 +18,8 @@ class Settings(BaseSettings):
ACCESS_TOKEN_EXPIRE_MIN: int = 30 ACCESS_TOKEN_EXPIRE_MIN: int = 30
REFRESH_TOKEN_EXPIRE_MIN: int = 60 REFRESH_TOKEN_EXPIRE_MIN: int = 60
BACKEND_CORS_ORIGINS: list = ["*"] BACKEND_CORS_ORIGINS: list = ["*"]
DEFAULT_TIMEZONE: str = pytz.UTC._tzname
CRYPT: CryptContext = CryptContext(schemes=["bcrypt"], deprecated="auto") CRYPT: CryptContext = CryptContext(schemes=["bcrypt"], deprecated="auto")
OAUTH2_SCHEME: OAuth2PasswordBearer = OAuth2PasswordBearer(tokenUrl="token")
model_config = SettingsConfigDict(env_file=".env") model_config = SettingsConfigDict(env_file=".env")
+4 -3
View File
@@ -1,3 +1,4 @@
import pytz
from tortoise.models import Model from tortoise.models import Model
from tortoise import fields from tortoise import fields
import uuid import uuid
@@ -21,8 +22,8 @@ class Token(Model, CMDMixin):
refresh_token: str = fields.TextField(null=True) refresh_token: str = fields.TextField(null=True)
disabled: bool = fields.BooleanField(default=False) disabled: bool = fields.BooleanField(default=False)
def delete(self) -> None: async def delete(self) -> None:
self.disabled = True self.disabled = True
self.disabled_at = datetime.now(tz=settings.DEFAULT_TIMEZONE) self.disabled_at = datetime.now(tz=pytz.UTC)
self.save() await self.save()
+58 -26
View File
@@ -1,21 +1,22 @@
from datetime import timedelta from datetime import datetime
from typing import Annotated from typing import Annotated
from fastapi.responses import JSONResponse import uuid
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from modules.auth.utils import create_token import pytz
from modules.users.utils import get_current_active_user
from modules.auth.utils import create_jwt_tokens, get_tokens_from_logged_in_user
from modules.auth.models import Token from modules.auth.models import Token
from modules.users.models import User from modules.users.models import User
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException, status
from tortoise.expressions import Q
from config import settings from config import settings
router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") email_error: str = "E-Mail Address or password is incorrect"
token_error: str = "Refresh token not found or something went wrong."
error: str = "E-Mail Address or password is incorrect"
crypt = settings.CRYPT crypt = settings.CRYPT
@@ -25,34 +26,65 @@ async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
user: User | None = await User.filter(email=form.username).first() user: User | None = await User.filter(email=form.username).first()
if user is None: if user is None:
raise HTTPException(status_code=401, detail=error) raise HTTPException(status_code=401, detail=email_error)
if user.check_against_password(form.password) is False: if user.check_against_password(form.password) is False:
raise HTTPException(status_code=401, detail=error) raise HTTPException(status_code=401, detail=email_error)
if user.disabled is True: if user.disabled is True:
raise HTTPException(status_code=401, detail=error) raise HTTPException(status_code=401, detail=email_error)
auth_token = create_token( tokens = await create_jwt_tokens(user)
user_id=user.id, offset=timedelta(settings.ACCESS_TOKEN_EXPIRE_MIN)
)
refresh_token = create_token( return {"jwt": tokens}
user_id=user.id, offset=timedelta(settings.REFRESH_TOKEN_EXPIRE_MIN)
)
token = await Token.create(
user=user,
access_token=auth_token,
refresh_token=refresh_token,
)
return {"jwt": token} @router.get("/logout", status_code=204)
async def logout(user: Annotated[User, Depends(get_current_active_user)]):
get_all_tokens = await Token.filter(Q(user__id=user.id))
if get_all_tokens is None:
raise HTTPException(
status_code=status.HTTP_204_NO_CONTENT, detail="An error occurred."
)
for token in get_all_tokens:
await token.delete()
return
@router.post("/refresh") @router.post("/refresh")
async def refresh_login(): async def refresh_login(
pass refresh_token: Annotated[Token | None, Depends(get_tokens_from_logged_in_user)]
):
if refresh_token is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=token_error,
)
# Disable tokens if used after expiration.
if (
refresh_token.created_at >= datetime.now(tz=pytz.utc)
and refresh_token.disabled is False
):
refresh_token.delete()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=token_error,
)
if refresh_token.disabled is True:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=token_error,
)
await refresh_token.delete()
tokens = await create_jwt_tokens(
user=await User.filter(Q(id=refresh_token.user_id)).first()
)
return {"jwt": tokens}
@router.post("/register") @router.post("/register")
+55 -4
View File
@@ -1,8 +1,16 @@
from datetime import timedelta from datetime import timedelta
from typing import Annotated
import uuid, time import uuid, time
from tortoise.expressions import Q
from fastapi import Depends, HTTPException, status
from modules.users.models import User
from modules.auth.models import Token
from config import settings from config import settings
from joserfc import jwt # type: ignore from joserfc import jwt # type: ignore
from joserfc.jwk import OctKey # type: ignore from joserfc.jwk import OctKey # type: ignore
from config import settings
crypt = settings.CRYPT crypt = settings.CRYPT
@@ -17,7 +25,7 @@ def create_token(user_id: uuid, offset: timedelta) -> str:
return jwt.encode( return jwt.encode(
{"alg": "HS256", "typ": "JWT"}, {"alg": "HS256", "typ": "JWT"},
{ {
"iss": "", "iss": f"{settings.PROJECT_PUBLIC_URL}",
"sub": f"id:{user}", "sub": f"id:{user}",
"nbf": curr_time, "nbf": curr_time,
"iat": curr_time, "iat": curr_time,
@@ -27,5 +35,48 @@ def create_token(user_id: uuid, offset: timedelta) -> str:
) )
def decode_token(token: str): async def create_jwt_tokens(user: User) -> Token:
pass """
Create a Token class with the following entities:
1) A user that is attached to the Token
2) A fresh Auth Token
3) A fresh Refresh Token.
This is then returned in the form of an Token class.
"""
auth_token = create_token(
user_id=user.id, offset=timedelta(settings.ACCESS_TOKEN_EXPIRE_MIN)
)
refresh_token = create_token(
user_id=user.id, offset=timedelta(settings.REFRESH_TOKEN_EXPIRE_MIN)
)
return await Token.create(
user=user,
access_token=auth_token,
refresh_token=refresh_token,
)
async def get_tokens_from_logged_in_user(
token: Annotated[str, Depends(settings.OAUTH2_SCHEME)]
) -> User | None:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="An issue occurred with the token.",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload: jwt.Token = jwt.decode(
token, OctKey.import_key(settings.SECRET_KEY), algorithms=["HS256"]
)
id: str | None = payload.claims.get("sub", None)
if id is None:
raise credentials_exception
user_id = id.split(":")[1]
except:
raise credentials_exception
return await Token.filter(Q(refresh_token=token) & Q(user__id=user_id)).first()
@@ -2,6 +2,7 @@ from datetime import datetime
from enum import Enum from enum import Enum
from typing import Type from typing import Type
import uuid import uuid
import pytz
from tortoise.exceptions import ConfigurationError from tortoise.exceptions import ConfigurationError
from tortoise.models import Model from tortoise.models import Model
from tortoise import fields from tortoise import fields
@@ -78,9 +79,9 @@ class Organization(Model, CMDMixin):
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.id} - {self.name}" return f"{self.id} - {self.name}"
def delete(self) -> None: async def delete(self) -> None:
self.disabled = True self.disabled = True
self.disabled_at = datetime.now(tz=settings.DEFAULT_TIMEZONE) self.disabled_at = datetime.now(tz=pytz.UTC)
self.save() await self.save()
@@ -1,6 +1,7 @@
from datetime import datetime from datetime import datetime
import uuid import uuid
from pydantic import EmailStr from pydantic import EmailStr
import pytz
from tortoise.models import Model from tortoise.models import Model
from tortoise import fields from tortoise import fields
@@ -57,10 +58,10 @@ class User(Model, CMDMixin):
return False return False
self.set_password(new_password) self.set_password(new_password)
def delete(self) -> None: async def delete(self) -> None:
self.disabled = True self.disabled = True
self.disabled_at = datetime.now(tz=settings.DEFAULT_TIMEZONE) self.disabled_at = datetime.now(tz=pytz.UTC)
self.save() await self.save()
@@ -102,7 +103,7 @@ class Membership(Model, CMDMixin):
acl: ACL = fields.ForeignKeyField("models.ACL") acl: ACL = fields.ForeignKeyField("models.ACL")
disabled: bool = fields.BooleanField(default=False) disabled: bool = fields.BooleanField(default=False)
def delete(self) -> None: async def delete(self) -> None:
self.disabled = True self.disabled = True
self.disabled_at = datetime.now(tz=settings.DEFAULT_TIMEZONE) self.disabled_at = datetime.now(tz=pytz.UTC)
self.save() await self.save()
@@ -0,0 +1,36 @@
from typing import Annotated
from joserfc import jwt # type: ignore
from joserfc.jwk import OctKey # type: ignore
from tortoise.expressions import Q
from fastapi import Depends, HTTPException, status
# from modules.users.schemas import UserModel
from modules.users.models import User
from config import settings
async def get_user_from_token(token: Annotated[str, Depends(settings.OAUTH2_SCHEME)]) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="An issue occurred with the token.",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload: jwt.Token = jwt.decode(token, OctKey.import_key(settings.SECRET_KEY), algorithms=["HS256"])
id: str | None = payload.claims.get("sub", None)
if id is None:
raise credentials_exception
user_id = id.split(":")[1]
except:
raise credentials_exception
return await User.filter(Q(id=user_id)).get_or_none()
async def get_current_active_user(
user: Annotated[User, Depends(get_user_from_token)],
):
if user.disabled:
raise HTTPException(status_code=400, detail="User is not found or active")
return user
@@ -65,3 +65,89 @@ class TestAuthentication(object):
"token_type": "Bearer", "token_type": "Bearer",
} }
} }
@pytest.mark.asyncio
async def test_logging_out_destroys_tokens(
self, client: AsyncClient, use_admin_account
):
_, _, admin, _ = use_admin_account
response = await client.post(
"http://localhost/api/v1/auth/",
data={
"username": "admin@localhost.com",
"password": "adminpassword",
"grant_type": "password",
},
)
assert response.status_code == 200
assert response.json() == {
"jwt": {
"created_at": ANY,
"user_id": str(admin.id),
"id": ANY,
"modified_at": ANY,
"disabled_at": None,
"refresh_token": ANY,
"disabled": False,
"access_token": ANY,
"token_type": "Bearer",
}
}
access_token = response.json()["jwt"]["access_token"]
logout = await client.get(
"http://localhost/api/v1/auth/logout",
headers={"Authorization": f"Bearer {access_token}"},
)
assert logout.status_code == 204
@pytest.mark.asyncio
async def test_create_new_tokens_upon_refresh(
self, client: AsyncClient, use_admin_account
):
_, _, admin, _ = use_admin_account
token = await client.post(
"http://localhost/api/v1/auth/",
data={
"username": "admin@localhost.com",
"password": "adminpassword",
"grant_type": "password",
},
)
assert token.status_code == 200
assert token.json() == {
"jwt": {
"created_at": ANY,
"user_id": str(admin.id),
"id": ANY,
"modified_at": ANY,
"disabled_at": None,
"refresh_token": ANY,
"disabled": False,
"access_token": ANY,
"token_type": "Bearer",
}
}
refresh_token = token.json()["jwt"]["refresh_token"]
response2 = await client.post(
"http://localhost/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {refresh_token}"},
)
assert response2.status_code == 200
assert response2.json() == {
"jwt": {
"created_at": ANY,
"user_id": str(admin.id),
"id": ANY,
"modified_at": ANY,
"disabled_at": None,
"refresh_token": ANY,
"disabled": False,
"access_token": ANY,
"token_type": "Bearer",
}
}