Fix fixtures and fix flaky tests
This commit is contained in:
@@ -14,7 +14,22 @@
|
||||
abort @external 401
|
||||
}
|
||||
|
||||
*.{$MAIN_DOMAIN} {
|
||||
stoneedge.{$MAIN_DOMAIN} {
|
||||
handle / {
|
||||
reverse_proxy stoneedge:8000
|
||||
}
|
||||
}
|
||||
|
||||
stoneedge-staging.{$MAIN_DOMAIN} {
|
||||
handle / {
|
||||
import protect
|
||||
abort
|
||||
|
||||
reverse_proxy stoneedge-staging:8000
|
||||
}
|
||||
}
|
||||
|
||||
{$MAIN_DOMAIN} {
|
||||
encode zstd gzip
|
||||
|
||||
handle /ping {
|
||||
@@ -24,16 +39,7 @@
|
||||
}
|
||||
|
||||
handle / {
|
||||
abort
|
||||
}
|
||||
|
||||
handle @stoneedge {
|
||||
reverse_proxy stoneedge:8000
|
||||
}
|
||||
|
||||
handle @stoneedge-staging {
|
||||
import protect
|
||||
abort
|
||||
abort 404
|
||||
}
|
||||
|
||||
handle {
|
||||
|
||||
@@ -7,7 +7,7 @@ class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "StoneEdge Asset Management System"
|
||||
PROJECT_VERSION: str = "0.0.1"
|
||||
PROJECT_SUMMARY: str = "Product API for StoneEdge."
|
||||
PROJECT_PUBLIC_URL: str = ""
|
||||
PROJECT_PUBLIC_URL: str = "localhost"
|
||||
SECRET_KEY: str | None = None
|
||||
PSQL_USERNAME: str = "user"
|
||||
PSQL_PASSWORD: str = "password"
|
||||
@@ -15,8 +15,8 @@ class Settings(BaseSettings):
|
||||
PSQL_PORT: int = 5432
|
||||
PSQL_DB_NAME: str = "stoneedge"
|
||||
PSQL_TEST_DB_NAME: str = "stoneedge_testing"
|
||||
ACCESS_TOKEN_EXPIRE_MIN: int = 30
|
||||
REFRESH_TOKEN_EXPIRE_MIN: int = 60
|
||||
ACCESS_TOKEN_EXPIRE_MIN: int = 10
|
||||
REFRESH_TOKEN_EXPIRE_MIN: int = 20
|
||||
BACKEND_CORS_ORIGINS: list = ["*"]
|
||||
CRYPT: CryptContext = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
OAUTH2_SCHEME: OAuth2PasswordBearer = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
@@ -11,6 +11,8 @@ from modules.auth.router import router as auth_router
|
||||
from modules.users.router import router as users_router
|
||||
from modules.organizations.router import router as organizations_router
|
||||
|
||||
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
@@ -27,6 +29,9 @@ app = FastAPI(
|
||||
default_response_class=msgspec_jsonresponse,
|
||||
)
|
||||
|
||||
app.add_middleware(HTTPSRedirectMiddleware)
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=[settings.PROJECT_PUBLIC_URL,])
|
||||
|
||||
# Set all CORS enabled origins
|
||||
if settings.BACKEND_CORS_ORIGINS:
|
||||
app.add_middleware(
|
||||
|
||||
@@ -15,8 +15,8 @@ class Token(Model, CMDMixin):
|
||||
Creates the access tokens for the User
|
||||
"""
|
||||
|
||||
id: uuid = fields.UUIDField(primary_key=True)
|
||||
user: uuid = fields.ForeignKeyField("models.User")
|
||||
id: uuid.UUID = fields.UUIDField(primary_key=True)
|
||||
user: uuid.UUID = fields.ForeignKeyField("models.User")
|
||||
token_type: str = fields.CharField(max_length=128, default="Bearer")
|
||||
access_token: str = fields.TextField(null=True)
|
||||
refresh_token: str = fields.TextField(null=True)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Annotated
|
||||
import uuid
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import EmailStr
|
||||
import pytz
|
||||
from modules.users.utils import get_current_active_user
|
||||
from modules.auth.utils import create_jwt_tokens, get_tokens_from_logged_in_user
|
||||
@@ -15,7 +16,7 @@ from config import settings
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
email_error: str = "E-Mail Address or password is incorrect"
|
||||
account_error: str = "E-Mail Address or password is incorrect"
|
||||
token_error: str = "Refresh token not found or something went wrong."
|
||||
|
||||
crypt = settings.CRYPT
|
||||
@@ -23,16 +24,21 @@ crypt = settings.CRYPT
|
||||
|
||||
@router.post("/")
|
||||
async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
|
||||
"""
|
||||
Login
|
||||
|
||||
Logs the user into our API, creates tokens and passes them back to User.
|
||||
"""
|
||||
user: User | None = await User.filter(email=form.username).first()
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail=email_error)
|
||||
raise HTTPException(status_code=401, detail=account_error)
|
||||
|
||||
if user.check_against_password(form.password) is False:
|
||||
raise HTTPException(status_code=401, detail=email_error)
|
||||
raise HTTPException(status_code=401, detail=account_error)
|
||||
|
||||
if user.disabled is True:
|
||||
raise HTTPException(status_code=401, detail=email_error)
|
||||
raise HTTPException(status_code=401, detail=account_error)
|
||||
|
||||
tokens = await create_jwt_tokens(user)
|
||||
|
||||
@@ -41,6 +47,11 @@ async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
|
||||
|
||||
@router.get("/logout", status_code=204)
|
||||
async def logout(user: Annotated[User, Depends(get_current_active_user)]):
|
||||
"""
|
||||
Logout
|
||||
|
||||
Logout destroys all tokens for User that are currently active.
|
||||
"""
|
||||
get_all_tokens = await Token.filter(Q(user__id=user.id))
|
||||
if get_all_tokens is None:
|
||||
raise HTTPException(
|
||||
@@ -55,6 +66,12 @@ async def logout(user: Annotated[User, Depends(get_current_active_user)]):
|
||||
async def refresh_login(
|
||||
refresh_token: Annotated[Token | None, Depends(get_tokens_from_logged_in_user)]
|
||||
):
|
||||
"""
|
||||
Refresh
|
||||
|
||||
After ging this route a token that is active and not disabled, we disable ALL other tokens and pass along new tokens.
|
||||
Tokens are alive for about 10 minutes. Refresh tokens are alive for 20 minutes.
|
||||
"""
|
||||
if refresh_token is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -78,8 +95,12 @@ async def refresh_login(
|
||||
detail=token_error,
|
||||
)
|
||||
|
||||
await refresh_token.delete()
|
||||
|
||||
get_all_tokens = await Token.filter(Q(user__id=refresh_token.user_id))
|
||||
|
||||
for token in get_all_tokens:
|
||||
if token.id != refresh_token.id:
|
||||
await token.delete()
|
||||
|
||||
tokens = await create_jwt_tokens(
|
||||
user=await User.filter(Q(id=refresh_token.user_id)).first()
|
||||
)
|
||||
@@ -88,5 +109,9 @@ async def refresh_login(
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register():
|
||||
async def register(email: EmailStr, name: str, surname: str, password: str, validate_password: str):
|
||||
pass
|
||||
|
||||
@router.post("/2fa")
|
||||
async def twofa():
|
||||
pass
|
||||
@@ -2,4 +2,5 @@ from tortoise.contrib.pydantic import pydantic_model_creator
|
||||
|
||||
from modules.auth.models import Token
|
||||
|
||||
TokenModel = pydantic_model_creator(Token)
|
||||
token_model = pydantic_model_creator(Token)
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from tortoise.models import Model
|
||||
from tortoise import fields
|
||||
|
||||
from mixins.CMDMixin import CMDMixin
|
||||
from config import settings
|
||||
|
||||
class EnumField(fields.CharField):
|
||||
"""
|
||||
@@ -43,7 +42,6 @@ class OrganizationType(Enum):
|
||||
2. What size is it?
|
||||
|
||||
All choices should be representative of the org.
|
||||
There are no seat costs.
|
||||
"""
|
||||
|
||||
HOME: str = "home" # Home use (Any size)
|
||||
@@ -62,10 +60,10 @@ class Organization(Model, CMDMixin):
|
||||
and makes sure that we can add users.
|
||||
"""
|
||||
|
||||
id: uuid = fields.UUIDField(primary_key=True)
|
||||
id: uuid.UUID = fields.UUIDField(primary_key=True)
|
||||
name: str = fields.CharField(max_length=128)
|
||||
type: str = EnumField(OrganizationType)
|
||||
users: uuid = fields.ManyToManyField(
|
||||
users: uuid.UUID = fields.ManyToManyField(
|
||||
"models.User",
|
||||
related_name="members",
|
||||
through="Membership",
|
||||
|
||||
@@ -18,7 +18,7 @@ class User(Model, CMDMixin):
|
||||
This holds all of our users
|
||||
"""
|
||||
|
||||
id: uuid = fields.UUIDField(primary_key=True)
|
||||
id: uuid.UUID = fields.UUIDField(primary_key=True)
|
||||
email: EmailStr = fields.CharField(max_length=128)
|
||||
username: str = fields.TextField(max_length=128)
|
||||
name: str = fields.TextField(max_length=128)
|
||||
@@ -72,7 +72,7 @@ class ACL(Model):
|
||||
Access control lists, every invited user gets an ACL and this decides whether you grant / deny access to certain parts of our system.
|
||||
"""
|
||||
|
||||
id: uuid = fields.UUIDField(primary_key=True)
|
||||
id: uuid.UUID = fields.UUIDField(primary_key=True)
|
||||
READ: bool = fields.BooleanField(default=False)
|
||||
WRITE: bool = fields.BooleanField(default=False)
|
||||
REPORT: bool = fields.BooleanField(default=False)
|
||||
@@ -97,7 +97,7 @@ class Membership(Model, CMDMixin):
|
||||
Creates a connection between an user and a company together with an ACL.
|
||||
"""
|
||||
|
||||
id: uuid = fields.UUIDField(primary_key=True)
|
||||
id: uuid.UUID = fields.UUIDField(primary_key=True)
|
||||
organization: Organization = fields.ForeignKeyField("models.Organization")
|
||||
user: User = fields.ForeignKeyField("models.User")
|
||||
acl: ACL = fields.ForeignKeyField("models.ACL")
|
||||
|
||||
@@ -5,19 +5,22 @@ from joserfc.jwk import OctKey # type: ignore
|
||||
from tortoise.expressions import Q
|
||||
from fastapi import Depends, HTTPException, status
|
||||
|
||||
# from modules.users.schemas import UserModel
|
||||
from modules.users.models import User
|
||||
from config import settings
|
||||
|
||||
|
||||
async def get_user_from_token(token: Annotated[str, Depends(settings.OAUTH2_SCHEME)]) -> User:
|
||||
async def get_user_from_token(
|
||||
token: Annotated[str, Depends(settings.OAUTH2_SCHEME)]
|
||||
) -> User | None:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="An issue occurred with the token.",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload: jwt.Token = jwt.decode(token, OctKey.import_key(settings.SECRET_KEY), algorithms=["HS256"])
|
||||
payload: jwt.Token = jwt.decode(
|
||||
token, OctKey.import_key(settings.SECRET_KEY), algorithms=["HS256"]
|
||||
)
|
||||
id: str | None = payload.claims.get("sub", None)
|
||||
if id is None:
|
||||
raise credentials_exception
|
||||
@@ -25,12 +28,20 @@ async def get_user_from_token(token: Annotated[str, Depends(settings.OAUTH2_SCHE
|
||||
except:
|
||||
raise credentials_exception
|
||||
|
||||
return await User.filter(Q(id=user_id)).get_or_none()
|
||||
return await User.filter(Q(id=user_id)).first()
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
user: Annotated[User, Depends(get_user_from_token)],
|
||||
):
|
||||
user: Annotated[User | None, Depends(get_user_from_token)],
|
||||
) -> User:
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User is not found or active",
|
||||
)
|
||||
if user.disabled:
|
||||
raise HTTPException(status_code=400, detail="User is not found or active")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User is not found or active",
|
||||
)
|
||||
return user
|
||||
|
||||
@@ -40,7 +40,7 @@ def event_loop():
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def client_manager(app, base_url="http://localhost", **kw) -> ClientManagerType:
|
||||
async def client_manager(app, base_url="https://localhost", **kw) -> ClientManagerType:
|
||||
app.state.testing = True
|
||||
async with LifespanManager(app):
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
|
||||
+42
-28
@@ -1,6 +1,6 @@
|
||||
from modules.organizations.models import Organization, OrganizationType
|
||||
from modules.users.models import ACL, Membership, User
|
||||
import pytest # type: ignore
|
||||
import pytest # type=ignore
|
||||
from config import settings
|
||||
|
||||
crypt = settings.CRYPT
|
||||
@@ -10,30 +10,38 @@ crypt = settings.CRYPT
|
||||
async def use_user_account():
|
||||
org, _ = await Organization.get_or_create(
|
||||
id="6ad4c94e-0522-4912-8d16-02d451f4c92d",
|
||||
name="User's Organization",
|
||||
type=OrganizationType.HOME,
|
||||
defaults={
|
||||
"name": "User's Organization",
|
||||
"type": OrganizationType.HOME,
|
||||
},
|
||||
)
|
||||
acl, _ = await ACL.get_or_create(
|
||||
id="a4e927a3-36e5-4761-badb-0a44ade6616f",
|
||||
READ=True,
|
||||
WRITE=True,
|
||||
REPORT=True,
|
||||
MANAGE=False,
|
||||
ADMIN=False,
|
||||
defaults={
|
||||
"READ": True,
|
||||
"WRITE": True,
|
||||
"REPORT": True,
|
||||
"MANAGE": False,
|
||||
"ADMIN": False,
|
||||
},
|
||||
)
|
||||
user, _ = await User.get_or_create(
|
||||
id="24235427-9662-4ba3-a9c5-00000000000b",
|
||||
email="user@localhost.com",
|
||||
username="user",
|
||||
name="awesome",
|
||||
surname="user",
|
||||
password=crypt.hash("userpassword"),
|
||||
defaults={
|
||||
"email": "user@localhost.com",
|
||||
"username": "user",
|
||||
"name": "awesome",
|
||||
"surname": "user",
|
||||
"password": crypt.hash("userpassword"),
|
||||
},
|
||||
)
|
||||
membership, _ = await Membership.get_or_create(
|
||||
id="833b9511-b2da-4760-8fa4-1a5c7059911e",
|
||||
organization=org,
|
||||
user=user,
|
||||
acl=acl,
|
||||
defaults={
|
||||
"organization": org,
|
||||
"user": user,
|
||||
"acl": acl,
|
||||
},
|
||||
)
|
||||
return org, acl, user, membership
|
||||
|
||||
@@ -42,31 +50,37 @@ async def use_user_account():
|
||||
async def use_admin_account():
|
||||
org, _ = await Organization.get_or_create(
|
||||
id="de001f44-1bb8-4667-9f9d-2d62d6ad7270",
|
||||
name="Admin's Organization",
|
||||
type=OrganizationType.EXTRA_LARGE_ORGANIZATION,
|
||||
defaults={
|
||||
"name": "Admin's Organization",
|
||||
"type": OrganizationType.EXTRA_LARGE_ORGANIZATION,
|
||||
},
|
||||
)
|
||||
acl, _ = await ACL.get_or_create(
|
||||
id="83c1bfe6-c2ed-4ba1-be03-0e5c1960ec31",
|
||||
READ=True,
|
||||
WRITE=True,
|
||||
REPORT=True,
|
||||
MANAGE=True,
|
||||
ADMIN=True,
|
||||
defaults={
|
||||
"READ": True,
|
||||
"WRITE": True,
|
||||
"REPORT": True,
|
||||
"MANAGE": True,
|
||||
"ADMIN": True,
|
||||
},
|
||||
)
|
||||
user, _ = await User.get_or_create(
|
||||
id="24235427-9662-4ba3-a9c5-00000000000a",
|
||||
defaults={
|
||||
"id": "24235427-9662-4ba3-a9c5-00000000000a",
|
||||
"email": "admin@localhost.com",
|
||||
"username": "admin",
|
||||
"name": "awesome",
|
||||
"surname": "admin",
|
||||
"password": crypt.hash("adminpassword"),
|
||||
}
|
||||
},
|
||||
)
|
||||
membership, _ = await Membership.get_or_create(
|
||||
id="393473ee-c218-4bcf-82cd-cb676c4d8a33",
|
||||
organization=org,
|
||||
user=user,
|
||||
acl=acl,
|
||||
defaults={
|
||||
"organization": org,
|
||||
"user": user,
|
||||
"acl": acl,
|
||||
},
|
||||
)
|
||||
return org, acl, user, membership
|
||||
|
||||
@@ -12,7 +12,7 @@ class TestAuthentication(object):
|
||||
self, client: AsyncClient
|
||||
):
|
||||
response = await client.post(
|
||||
"http://localhost/api/v1/auth/",
|
||||
"https://localhost/api/v1/auth/",
|
||||
data={
|
||||
"username": "non-existing@localhost.com",
|
||||
"password": "password",
|
||||
@@ -28,7 +28,7 @@ class TestAuthentication(object):
|
||||
):
|
||||
_, _, _, _ = use_admin_account
|
||||
response = await client.post(
|
||||
"http://localhost/api/v1/auth/",
|
||||
"https://localhost/api/v1/auth/",
|
||||
data={
|
||||
"username": "admin@localhost.com",
|
||||
"password": "password",
|
||||
@@ -44,7 +44,7 @@ class TestAuthentication(object):
|
||||
):
|
||||
_, _, admin, _ = use_admin_account
|
||||
response = await client.post(
|
||||
"http://localhost/api/v1/auth/",
|
||||
"https://localhost/api/v1/auth/",
|
||||
data={
|
||||
"username": "admin@localhost.com",
|
||||
"password": "adminpassword",
|
||||
@@ -68,14 +68,14 @@ class TestAuthentication(object):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logging_out_destroys_tokens(
|
||||
self, client: AsyncClient, use_admin_account
|
||||
self, client: AsyncClient, use_user_account
|
||||
):
|
||||
_, _, admin, _ = use_admin_account
|
||||
_, _, user, _ = use_user_account
|
||||
response = await client.post(
|
||||
"http://localhost/api/v1/auth/",
|
||||
"https://localhost/api/v1/auth/",
|
||||
data={
|
||||
"username": "admin@localhost.com",
|
||||
"password": "adminpassword",
|
||||
"username": "user@localhost.com",
|
||||
"password": "userpassword",
|
||||
"grant_type": "password",
|
||||
},
|
||||
)
|
||||
@@ -83,7 +83,7 @@ class TestAuthentication(object):
|
||||
assert response.json() == {
|
||||
"jwt": {
|
||||
"created_at": ANY,
|
||||
"user_id": str(admin.id),
|
||||
"user_id": str(user.id),
|
||||
"id": ANY,
|
||||
"modified_at": ANY,
|
||||
"disabled_at": None,
|
||||
@@ -95,20 +95,31 @@ class TestAuthentication(object):
|
||||
}
|
||||
|
||||
access_token = response.json()["jwt"]["access_token"]
|
||||
refresh_token = response.json()["jwt"]["refresh_token"]
|
||||
|
||||
logout = await client.get(
|
||||
"http://localhost/api/v1/auth/logout",
|
||||
"https://localhost/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
assert logout.status_code == 204
|
||||
|
||||
refresh_request = await client.post(
|
||||
"https://localhost/api/v1/auth/refresh",
|
||||
headers={"Authorization": f"Bearer {refresh_token}"},
|
||||
)
|
||||
|
||||
assert refresh_request.status_code == 401
|
||||
assert refresh_request.json() == {
|
||||
"detail": "Refresh token not found or something went wrong."
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_tokens_upon_refresh(
|
||||
self, client: AsyncClient, use_admin_account
|
||||
):
|
||||
_, _, admin, _ = use_admin_account
|
||||
token = await client.post(
|
||||
"http://localhost/api/v1/auth/",
|
||||
"https://localhost/api/v1/auth/",
|
||||
data={
|
||||
"username": "admin@localhost.com",
|
||||
"password": "adminpassword",
|
||||
@@ -133,7 +144,7 @@ class TestAuthentication(object):
|
||||
refresh_token = token.json()["jwt"]["refresh_token"]
|
||||
|
||||
response2 = await client.post(
|
||||
"http://localhost/api/v1/auth/refresh",
|
||||
"https://localhost/api/v1/auth/refresh",
|
||||
headers={"Authorization": f"Bearer {refresh_token}"},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import pytest
|
||||
import pytest # type: ignore
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class TestRootRoute(object):
|
||||
async def test_read_docs_on_main_route(self, client: AsyncClient):
|
||||
response = await client.get("http://localhost/api/v1/")
|
||||
response = await client.get("https://localhost/api/v1/")
|
||||
assert response.status_code == 307
|
||||
|
||||
async def test_get_pong(self, client: AsyncClient):
|
||||
response = await client.get("http://localhost/api/v1/ping")
|
||||
response = await client.get("https://localhost/api/v1/ping")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"ping": "pong!"}
|
||||
|
||||
Reference in New Issue
Block a user