Setup testing, fix database connection, add manage.py
This commit is contained in:
@@ -23,7 +23,7 @@ TORTOISE_ORM = {
|
|||||||
|
|
||||||
|
|
||||||
async def init_db():
|
async def init_db():
|
||||||
await Tortoise.init(db_url=settings.PSQL_CONNECT_STR, modules=modules)
|
await Tortoise.init(config=TORTOISE_ORM)
|
||||||
|
|
||||||
|
|
||||||
async def migrate_db():
|
async def migrate_db():
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.responses import JSONResponse
|
from tortoise import run_async
|
||||||
from starlette.responses import RedirectResponse
|
|
||||||
from tortoise import Tortoise
|
|
||||||
from config import settings
|
from config import settings
|
||||||
|
from database import migrate_db
|
||||||
|
|
||||||
|
from router import router as root_router
|
||||||
from modules.assets.router import router as asset_router
|
from modules.assets.router import router as asset_router
|
||||||
from modules.auth.router import router as auth_router
|
from modules.auth.router import router as auth_router
|
||||||
from modules.users.router import router as users_router
|
from modules.users.router import router as users_router
|
||||||
from modules.organizations.router import router as organizations_router
|
from modules.organizations.router import router as organizations_router
|
||||||
from tortoise.contrib.fastapi import register_tortoise
|
|
||||||
from database import modules
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=settings.PROJECT_NAME,
|
title=settings.PROJECT_NAME,
|
||||||
@@ -16,27 +15,10 @@ app = FastAPI(
|
|||||||
summary=settings.PROJECT_SUMMARY,
|
summary=settings.PROJECT_SUMMARY,
|
||||||
)
|
)
|
||||||
|
|
||||||
Tortoise.init_models(modules, "models")
|
run_async(migrate_db())
|
||||||
|
|
||||||
register_tortoise(
|
|
||||||
app,
|
|
||||||
db_url=settings.PSQL_CONNECT_STR,
|
|
||||||
modules=modules,
|
|
||||||
generate_schemas=True,
|
|
||||||
add_exception_handlers=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
app.include_router(root_router)
|
||||||
app.include_router(auth_router)
|
app.include_router(auth_router)
|
||||||
app.include_router(users_router)
|
app.include_router(users_router)
|
||||||
app.include_router(organizations_router)
|
app.include_router(organizations_router)
|
||||||
app.include_router(asset_router)
|
app.include_router(asset_router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def main():
|
|
||||||
return RedirectResponse(url="/docs")
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/ping")
|
|
||||||
async def ping() -> JSONResponse:
|
|
||||||
return JSONResponse("PONG")
|
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from ptpython.repl import embed # type: ignore
|
||||||
|
|
||||||
|
from database import *
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def setup():
|
||||||
|
try:
|
||||||
|
await embed(globals=globals(), return_asyncio_coroutine=True, patch_stdout=True)
|
||||||
|
except EOFError:
|
||||||
|
loop.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
asyncio.ensure_future(setup())
|
||||||
|
loop.run_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
from tortoise import BaseDBAsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
async def upgrade(db: BaseDBAsyncClient) -> str:
|
||||||
|
return """
|
||||||
|
ALTER TABLE "asset" ADD "disabled_at" TIMESTAMPTZ;
|
||||||
|
ALTER TABLE "asset" ADD "modified_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP;
|
||||||
|
ALTER TABLE "asset" ADD "created_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP;"""
|
||||||
|
|
||||||
|
|
||||||
|
async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||||
|
return """
|
||||||
|
ALTER TABLE "asset" DROP COLUMN "disabled_at";
|
||||||
|
ALTER TABLE "asset" DROP COLUMN "modified_at";
|
||||||
|
ALTER TABLE "asset" DROP COLUMN "created_at";"""
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
|
||||||
|
class CMDMixin():
|
||||||
class CMDMixin:
|
|
||||||
"""
|
"""
|
||||||
Created, modified and delete mixin, these are required for every class.
|
Created, modified and delete mixin, these are required for every class.
|
||||||
"""
|
"""
|
||||||
@@ -9,3 +8,4 @@ class CMDMixin:
|
|||||||
created_at = fields.DatetimeField(null=True, auto_now_add=True)
|
created_at = fields.DatetimeField(null=True, auto_now_add=True)
|
||||||
modified_at = fields.DatetimeField(null=True, auto_now=True)
|
modified_at = fields.DatetimeField(null=True, auto_now=True)
|
||||||
disabled_at = fields.DatetimeField(null=True)
|
disabled_at = fields.DatetimeField(null=True)
|
||||||
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from tortoise.models import Model
|
from tortoise.models import Model
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
from mixins.CMDMixin import CMDMixin
|
||||||
|
|
||||||
class Asset(Model):
|
class Asset(Model, CMDMixin):
|
||||||
id = fields.UUIDField(primary_key=True)
|
id = fields.UUIDField(primary_key=True)
|
||||||
name = fields.CharField(max_length=128)
|
name = fields.CharField(max_length=128)
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ from uuid import UUID
|
|||||||
|
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
|
|
||||||
from modules.assets.models import Asset
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/assets"
|
prefix="/assets"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from tortoise import fields
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from models import CMDMixin
|
from mixins.CMDMixin import CMDMixin
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,15 @@ from fastapi.responses import JSONResponse
|
|||||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from utils import create_token
|
from modules.auth.utils import create_token
|
||||||
from models import Token
|
from modules.auth.models import Token
|
||||||
from modules.users.models import User
|
from modules.users.models import User
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from config import settings
|
from config import settings
|
||||||
from tortoise.expressions import Q
|
from tortoise.expressions import Q
|
||||||
from schemas import TokenModel
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth")
|
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
@@ -22,17 +21,17 @@ error: str = "E-Mail Address or password is incorrect"
|
|||||||
crypt = settings.CRYPT
|
crypt = settings.CRYPT
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=TokenModel)
|
@router.post("/", status_code=200)
|
||||||
async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
|
async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
|
||||||
user: User = await User.filter(
|
user: User | None = await User.filter(
|
||||||
Q(email=form.username)
|
Q(email=form.username)
|
||||||
).get_or_none()
|
).get_or_none()
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
HTTPException(status_code=401, detail=error)
|
return HTTPException(status_code=401, detail=error)
|
||||||
|
|
||||||
if user.check_against_password(form.password) is False:
|
if user.check_against_password(form.password) is False:
|
||||||
HTTPException(status_code=401, detail=error)
|
return HTTPException(status_code=401, detail=error)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
await Token.create(
|
await Token.create(
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import uuid, time
|
import uuid, time
|
||||||
from config import settings
|
from config import settings
|
||||||
from joserfc import jwt # type: ignore
|
from joserfc import jwt # type: ignore
|
||||||
from joserfc.jwt import OctKey # type: ignore
|
|
||||||
|
|
||||||
crypt = settings.CRYPT
|
crypt = settings.CRYPT
|
||||||
|
|
||||||
@@ -21,7 +20,7 @@ def create_token(user_id: uuid, offset: float) -> str:
|
|||||||
"iat": curr_time,
|
"iat": curr_time,
|
||||||
"exp": int(curr_time + offset),
|
"exp": int(curr_time + offset),
|
||||||
},
|
},
|
||||||
OctKey.import_key(settings.SECRET_KEY),
|
settings.SECRET_KEY,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from tortoise.exceptions import ConfigurationError
|
|||||||
from tortoise.models import Model
|
from tortoise.models import Model
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
|
||||||
from models import CMDMixin
|
from mixins.CMDMixin import CMDMixin
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
class EnumField(fields.CharField):
|
class EnumField(fields.CharField):
|
||||||
@@ -45,11 +45,11 @@ class OrganizationType(Enum):
|
|||||||
There are no seat costs.
|
There are no seat costs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
HOME: int = 1 # Home use (Any size)
|
HOME: str = "home" # Home use (Any size)
|
||||||
SMALL_ORGANIZATION: int = 2 # 1-100
|
SMALL_ORGANIZATION: str = "s_org" # 1-100
|
||||||
MEDIUM_ORGANIZATION: int = 3 # 100 - 500
|
MEDIUM_ORGANIZATION: str = "m_org" # 100 - 500
|
||||||
LARGE_ORGANIZATION: int = 4 # 500 - 1000
|
LARGE_ORGANIZATION: str = "l_org" # 500 - 1000
|
||||||
EXTRA_LARGE_ORGANIZATION: int = 5 # 1000 - 5000+
|
EXTRA_LARGE_ORGANIZATION: str = "xl_org" # 1000 - 5000+
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/organizations")
|
router = APIRouter(prefix="/api/v1/organizations")
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
def all_organizations():
|
def all_organizations():
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from tortoise.models import Model
|
|||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
|
||||||
from modules.organizations.models import Organization
|
from modules.organizations.models import Organization
|
||||||
from models import CMDMixin
|
from mixins.CMDMixin import CMDMixin
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
crypt = settings.CRYPT
|
crypt = settings.CRYPT
|
||||||
|
|||||||
@@ -1,12 +1,19 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/users")
|
router = APIRouter(prefix="/api/v1/users", tags=["users"])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
def get_all_users():
|
def get_all_users():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/")
|
||||||
|
def create_user():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me")
|
@router.get("/me")
|
||||||
def get_user():
|
def get_user():
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,4 +1,17 @@
|
|||||||
|
[tool.black]
|
||||||
|
exclude = '''/
|
||||||
|
# Default values for Black.
|
||||||
|
\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|
|
||||||
|
/'''
|
||||||
|
line-length = 88
|
||||||
|
|
||||||
[tool.aerich]
|
[tool.aerich]
|
||||||
tortoise_orm = "database.TORTOISE_ORM"
|
tortoise_orm = "database.TORTOISE_ORM"
|
||||||
location = "./migrations"
|
location = "./migrations"
|
||||||
src_folder = "./."
|
src_folder = "./."
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_default_fixture_loop_scope = "session"
|
||||||
|
testpaths = [
|
||||||
|
"tests/",
|
||||||
|
]
|
||||||
|
|||||||
@@ -7,3 +7,11 @@ black>=24.10.0
|
|||||||
joserfc>=1.0.1
|
joserfc>=1.0.1
|
||||||
passlib>=1.7.4
|
passlib>=1.7.4
|
||||||
pytz>=2024.2
|
pytz>=2024.2
|
||||||
|
ptpython>=0.25
|
||||||
|
|
||||||
|
# Test Suite
|
||||||
|
httpx>=0.28.1
|
||||||
|
pytest>=8.3.4
|
||||||
|
mock>=5.1.0
|
||||||
|
pytest-mock>=3.14.0
|
||||||
|
anyio>=4.8.0
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1")
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
async def main():
|
||||||
|
return RedirectResponse(url="/docs")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/ping")
|
||||||
|
async def ping() -> JSONResponse:
|
||||||
|
return JSONResponse("PONG")
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from tortoise import Tortoise
|
||||||
|
from database import modules
|
||||||
|
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
DB_URL = "sqlite://:memory:"
|
||||||
|
|
||||||
|
|
||||||
|
async def init_db(db_url, create_db: bool = True, schemas: bool = True) -> None:
|
||||||
|
"""Initial database connection"""
|
||||||
|
await Tortoise.init(
|
||||||
|
db_url=db_url, modules={"models": modules}, _create_db=create_db
|
||||||
|
)
|
||||||
|
if create_db:
|
||||||
|
print(f"Database created! {db_url = }")
|
||||||
|
if schemas:
|
||||||
|
await Tortoise.generate_schemas()
|
||||||
|
print("Success to generate schemas")
|
||||||
|
|
||||||
|
|
||||||
|
async def init(db_url: str = DB_URL):
|
||||||
|
await init_db(db_url, True, True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def anyio_backend():
|
||||||
|
return "anyio"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def client():
|
||||||
|
async with AsyncClient(app=app, base_url="http://test") as client:
|
||||||
|
print("Client is ready")
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
async def initialize_tests():
|
||||||
|
await init()
|
||||||
|
yield
|
||||||
|
await Tortoise._drop_databases()
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
from tests.fixtures.conftest import init_db
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from modules.organizations.models import Organization
|
||||||
|
from modules.users.models import ACL, Membership, User
|
||||||
|
from main import app
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
crypt = settings.CRYPT
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_function():
|
||||||
|
init_db()
|
||||||
|
org = await Organization.create(name="Admin's Organization", type="home")
|
||||||
|
user = await User.create(
|
||||||
|
email="admin@localhost.com",
|
||||||
|
username="admin",
|
||||||
|
name="admin",
|
||||||
|
surname="admin",
|
||||||
|
)
|
||||||
|
user.set_password("password")
|
||||||
|
user.save()
|
||||||
|
acl = await ACL.create(READ=True, WRITE=True, REPORT=True, MANAGE=True, ADMIN=True)
|
||||||
|
await Membership.create(organization=org, user=user, acl=acl)
|
||||||
|
|
||||||
|
print(org, user, acl)
|
||||||
|
|
||||||
|
|
||||||
|
# def teardown_function():
|
||||||
|
# Organization.all().delete()
|
||||||
|
# User.all().delete()
|
||||||
|
# ACL.all().delete()
|
||||||
|
# Membership.all().delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_main():
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth",
|
||||||
|
data={
|
||||||
|
"username": "admin@localhost.com",
|
||||||
|
"password": "password",
|
||||||
|
"grant_type": "password",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.json() == {}
|
||||||
|
assert response.status_code == 200
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
def setup_function():
|
||||||
|
print("setting up")
|
||||||
|
|
||||||
|
def test_read_main():
|
||||||
|
response = client.get("/api/v1/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_get_pong():
|
||||||
|
response = client.get("/api/v1/ping")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.text == '"PONG"'
|
||||||
Reference in New Issue
Block a user