Skip to content

Commit

Permalink
fix: node_user_usages problem
Browse files Browse the repository at this point in the history
  • Loading branch information
M03ED committed Sep 4, 2024
1 parent 079296b commit 68dab4c
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 51 deletions.
28 changes: 13 additions & 15 deletions app/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,13 @@ def get_user_usages(db: Session, dbuser: User, start: datetime, end: datetime) -
Returns:
List[UserUsageResponse]: List of user usage responses.
"""
usages = {}

usages[0] = UserUsageResponse( # Main Core
usages = {0: UserUsageResponse( # Main Core
node_id=None,
node_name="Master",
used_traffic=0
)
)}

for node in db.query(Node).all():
usages[node.id] = UserUsageResponse(
node_id=node.id,
Expand Down Expand Up @@ -438,8 +438,8 @@ def update_user(db: Session, dbuser: User, modify: UserModify) -> User:
if modify.inbounds:
for proxy_type, tags in modify.excluded_inbounds.items():
dbproxy = db.query(Proxy) \
.where(Proxy.user == dbuser, Proxy.type == proxy_type) \
.first() or added_proxies.get(proxy_type)
.where(Proxy.user == dbuser, Proxy.type == proxy_type) \
.first() or added_proxies.get(proxy_type)
if dbproxy:
dbproxy.excluded_inbounds = [get_or_create_inbound(db, tag) for tag in tags]

Expand Down Expand Up @@ -627,7 +627,7 @@ def autodelete_expired_users(db: Session,

def get_all_users_usages(
db: Session, admin: Admin, start: datetime, end: datetime
) -> List[UserUsageResponse]:
) -> List[UserUsageResponse]:
"""
Retrieves usage data for all users associated with an admin within a specified time range.
Expand All @@ -644,16 +644,13 @@ def get_all_users_usages(
List[UserUsageResponse]: A list of UserUsageResponse objects, each representing
the usage data for a specific node or the main core.
"""
usages = {}

usages[0] = UserUsageResponse( # Main Core
usages = {0: UserUsageResponse( # Main Core
node_id=None,
node_name="Master",
used_traffic=0
)
)}

for node in db.query(Node).all():

usages[node.id] = UserUsageResponse(
node_id=node.id,
node_name=node.name,
Expand Down Expand Up @@ -784,6 +781,7 @@ def get_admin(db: Session, username: str) -> Admin:
"""
return db.query(Admin).filter(Admin.username == username).first()


def create_admin(db: Session, admin: AdminCreate) -> Admin:
"""
Creates a new admin in the database.
Expand Down Expand Up @@ -1042,6 +1040,7 @@ def get_user_templates(

return dbuser_templates.all()


def get_node(db: Session, name: str) -> Optional[Node]:
"""
Retrieves a node by its name.
Expand Down Expand Up @@ -1110,14 +1109,13 @@ def get_nodes_usage(db: Session, start: datetime, end: datetime) -> List[NodeUsa
Returns:
List[NodeUsageResponse]: A list of NodeUsageResponse objects containing usage data.
"""
usages = {}

usages[0] = NodeUsageResponse( # Main Core
usages = {0: NodeUsageResponse( # Main Core
node_id=None,
node_name="Master",
uplink=0,
downlink=0
)
)}

for node in db.query(Node).all():
usages[node.id] = NodeUsageResponse(
node_id=node.id,
Expand Down
15 changes: 12 additions & 3 deletions app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import datetime, timezone
from app.utils.jwt import get_subscription_payload


def validate_admin(db: Session, username: str, password: str) -> Optional[AdminValidationResult]:
"""Validate admin credentials with environment variables or database."""
if SUDOERS.get(username) == password:
Expand All @@ -26,33 +27,39 @@ def get_admin_by_username(username: str, db: Session = Depends(get_db)):
raise HTTPException(status_code=404, detail="Admin not found")
return dbadmin


def get_dbnode(node_id: int, db: Session = Depends(get_db)):
"""Fetch a node by its ID from the database, raising a 404 error if not found."""
dbnode = crud.get_node_by_id(db, node_id)
if not dbnode:
raise HTTPException(status_code=404, detail="Node not found")
return dbnode


def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[str, datetime]]) -> bool:
"""Validate if start and end dates are correct and if end is after start."""
try:
if start:
start_date = start if isinstance(start, datetime) else datetime.fromisoformat(start)
else:
start_date = None
if end:
end_date = end if isinstance(end, datetime) else datetime.fromisoformat(end)
if start and end_date < start_date:
if start_date and end_date < start_date:
return False
return True
except ValueError:
return False



def get_user_template(template_id: int, db: Session = Depends(get_db)):
"""Fetch a User Template by its ID, raise 404 if not found."""
dbuser_template = crud.get_user_template(db, template_id)
if not dbuser_template:
raise HTTPException(status_code=404, detail="User Template not found")
return dbuser_template


def get_validated_sub(
token: str,
db: Session = Depends(get_db)
Expand All @@ -70,6 +77,7 @@ def get_validated_sub(

return dbuser


def get_validated_user(
username: str,
admin: Admin = Depends(Admin.get_current),
Expand All @@ -84,6 +92,7 @@ def get_validated_user(

return dbuser


def get_expired_users_list(db: Session, admin: Admin, expired_after: Optional[datetime] = None, expired_before: Optional[datetime] = None):

expired_before = expired_before or datetime.now(timezone.utc)
Expand All @@ -99,4 +108,4 @@ def get_expired_users_list(db: Session, admin: Admin, expired_after: Optional[da
return [
u for u in dbusers
if u.expire and expired_after.timestamp() <= u.expire <= expired_before.timestamp()
]
]
27 changes: 17 additions & 10 deletions app/routers/node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import time
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import List

from sqlalchemy.exc import IntegrityError
Expand Down Expand Up @@ -45,7 +45,7 @@ def add_node(
new_node: NodeCreate,
bg: BackgroundTasks,
db: Session = Depends(get_db),
admin: Admin = Depends(Admin.check_sudo_admin)
_: Admin = Depends(Admin.check_sudo_admin)
):
"""Add a new node to the database and optionally add it as a host."""
try:
Expand All @@ -64,7 +64,7 @@ def add_node(
@router.get("/node/{node_id}", response_model=NodeResponse)
def get_node(
dbnode: NodeResponse = Depends(get_dbnode),
admin: Admin = Depends(Admin.check_sudo_admin)
_: Admin = Depends(Admin.check_sudo_admin)
):
"""Retrieve details of a specific node by its ID."""
return dbnode
Expand Down Expand Up @@ -141,7 +141,7 @@ async def node_logs(node_id: int, websocket: WebSocket, db: Session = Depends(ge
@router.get("/nodes", response_model=List[NodeResponse])
def get_nodes(
db: Session = Depends(get_db),
admin: Admin = Depends(Admin.check_sudo_admin)
_: Admin = Depends(Admin.check_sudo_admin)
):
"""Retrieve a list of all nodes. Accessible only to sudo admins."""
return crud.get_nodes(db)
Expand All @@ -153,7 +153,7 @@ def modify_node(
bg: BackgroundTasks,
dbnode: NodeResponse = Depends(get_node),
db: Session = Depends(get_db),
admin: Admin = Depends(Admin.check_sudo_admin)
_: Admin = Depends(Admin.check_sudo_admin)
):
"""Update a node's details. Only accessible to sudo admins."""
updated_node = crud.update_node(db, dbnode, modified_node)
Expand All @@ -169,7 +169,7 @@ def modify_node(
def reconnect_node(
bg: BackgroundTasks,
dbnode: NodeResponse = Depends(get_node),
admin: Admin = Depends(Admin.check_sudo_admin)
_: Admin = Depends(Admin.check_sudo_admin)
):
"""Trigger a reconnection for the specified node. Only accessible to sudo admins."""
bg.add_task(xray.operations.connect_node,node_id=dbnode.id)
Expand All @@ -195,14 +195,21 @@ def get_usage(
db: Session = Depends(get_db),
start: datetime = Query(None, example="2024-01-01T00:00:00"),
end: datetime = Query(None, example="2024-01-31T23:59:59"),
admin: Admin = Depends(Admin.check_sudo_admin)
_: Admin = Depends(Admin.check_sudo_admin)
):
"""Retrieve usage statistics for nodes within a specified date range."""
if not validate_dates(start, end):
raise HTTPException(status_code=400, detail="Invalid date range or format")

start_date = start or datetime.utcnow() - timedelta(days=30)
end_date = end or datetime.utcnow()
usages = crud.get_nodes_usage(db, start_date, end_date)
if not start:
start = datetime.now(timezone.utc) - timedelta(days=30)
else:
start = datetime.fromisoformat(start).astimezone(timezone.utc)
if not end:
end = datetime.now(timezone.utc)
else:
end = datetime.fromisoformat(end).astimezone(timezone.utc)

usages = crud.get_nodes_usage(db, start, end)

return {"usages": usages}
18 changes: 12 additions & 6 deletions app/routers/subscription.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from distutils.version import LooseVersion

from fastapi import Depends, Header, HTTPException, Path, Request, Response, APIRouter, Query
Expand Down Expand Up @@ -139,18 +139,24 @@ def user_subscription_info(
@router.get("/{token}/usage")
def user_get_usage(
dbuser: UserResponse = Depends(get_validated_sub),
start: datetime = Query(None, example="2024-01-01T00:00:00"),
end: datetime = Query(None, example="2024-01-31T23:59:59"),
start: str = "",
end: str = "",
db: Session = Depends(get_db)
):
"""Fetches the usage statistics for the user within a specified date range."""
if not validate_dates(start, end):
raise HTTPException(status_code=400, detail="Invalid date range or format")

start_date = start or datetime.utcnow() - timedelta(days=30)
end_date = end or datetime.utcnow()
if not start:
start = datetime.now(timezone.utc) - timedelta(days=30)
else:
start = datetime.fromisoformat(start).astimezone(timezone.utc)
if not end:
end = datetime.now(timezone.utc)
else:
end = datetime.fromisoformat(end).astimezone(timezone.utc)

usages = crud.get_user_usages(db, dbuser, start_date, end_date)
usages = crud.get_user_usages(db, dbuser, start, end)

return {"usages": usages, "username": dbuser.username}

Expand Down
46 changes: 29 additions & 17 deletions app/routers/user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import List, Union, Optional

from sqlalchemy.exc import IntegrityError
Expand All @@ -14,6 +14,7 @@

router = APIRouter(tags=['User'], prefix='/api')


@router.post("/user", response_model=UserResponse)
def add_user(
new_user: UserCreate,
Expand Down Expand Up @@ -155,7 +156,6 @@ def remove_user(
return {"detail": "User successfully deleted"}



@router.post("/user/{username}/reset", response_model=UserResponse)
def reset_user_data_usage(
bg: BackgroundTasks,
Expand Down Expand Up @@ -258,26 +258,32 @@ def reset_users_data_usage(
@router.get("/user/{username}/usage", response_model=UserUsagesResponse)
def get_user_usage(
dbuser: UserResponse = Depends(get_validated_user),
start: str = None,
end: str = None,
db: Session = Depends(get_db),
admin: Admin = Depends(Admin.get_current)
start: str = "",
end: str = "",
db: Session = Depends(get_db)
):
"""Get users usage"""
if not validate_dates(start, end):
raise HTTPException(status_code=400, detail="Invalid date range or format")

start_date = start or datetime.utcnow() - timedelta(days=30)
end_date = end or datetime.utcnow()
if not start:
start = datetime.now(timezone.utc) - timedelta(days=30)
else:
start = datetime.fromisoformat(start).astimezone(timezone.utc)
if not end:
end = datetime.now(timezone.utc)
else:
end = datetime.fromisoformat(end).astimezone(timezone.utc)

usages = crud.get_user_usages(db, dbuser, start_date, end_date)
usages = crud.get_user_usages(db, dbuser, start, end)

return {"usages": usages, "username": dbuser.username}


@router.get("/users/usage", response_model=UsersUsagesResponse)
def get_users_usage(
start: datetime = Query(None, example="2024-01-01T00:00:00"),
end: datetime = Query(None, example="2024-01-31T23:59:59"),
start: str = "",
end: str = "",
db: Session = Depends(get_db),
owner: Union[List[str], None] = Query(None, alias="admin"),
admin: Admin = Depends(Admin.get_current)
Expand All @@ -286,13 +292,19 @@ def get_users_usage(
if not validate_dates(start, end):
raise HTTPException(status_code=400, detail="Invalid date range or format")

start_date = start or datetime.utcnow() - timedelta(days=30)
end_date = end or datetime.utcnow()
if not start:
start = datetime.now(timezone.utc) - timedelta(days=30)
else:
start = datetime.fromisoformat(start).astimezone(timezone.utc)
if not end:
end = datetime.now(timezone.utc)
else:
end = datetime.fromisoformat(end).astimezone(timezone.utc)

usages = crud.get_all_users_usages(
db=db,
start=start_date,
end=end_date,
start=start,
end=end,
admin=owner if admin.is_sudo else [admin.username]
)

Expand Down Expand Up @@ -336,7 +348,7 @@ def get_expired_users(
- If both are omitted, returns all expired users
"""

if not validate_dates(expired_after, expired_before, allow_both_none=True):
if not validate_dates(expired_after, expired_before):
raise HTTPException(status_code=400, detail="Invalid date range or format")

expired_users = get_expired_users_list(db, admin, expired_after, expired_before)
Expand Down Expand Up @@ -373,4 +385,4 @@ def delete_expired_users(
logger.info(f"User \"{removed_user}\" deleted")
bg.add_task(report.user_deleted, username=removed_user, user_admin=next((u.admin for u in expired_users if u.username == removed_user), None), by=admin)

return removed_users
return removed_users

0 comments on commit 68dab4c

Please sign in to comment.