Skip to content

Commit

Permalink
feat(backend): routes processing work
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavel7004 committed Nov 21, 2024
1 parent 2b5bac2 commit a6d34f2
Showing 1 changed file with 100 additions and 54 deletions.
154 changes: 100 additions & 54 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,6 @@
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))


class Route(BaseModel):
id: str
points: List[str]
author: str
popularity_score: float


class Lake(BaseModel):
id: str
name: str
Expand Down Expand Up @@ -110,15 +103,15 @@ def create_user(user_request: NewUserRequest) -> str:


@app.get("/users/{user_id}", response_model=User)
def get_user(user_id: str) -> User:
def get_user(user_id: UUID4) -> User:
query = """
MATCH (u:User {id: $user_id})
RETURN u
"""

with driver.session() as session:
data = session \
.run(query, {"user_id": user_id}) \
.run(query, {"user_id": user_id.hex}) \
.single()['u']

if not data:
Expand Down Expand Up @@ -207,15 +200,15 @@ def create_point(req: CreatePointRequest) -> str:


@app.get("/points/{point_id}", response_model=Point)
def get_point(point_id: str) -> Point:
def get_point(point_id: UUID4) -> Point:
query = """
MATCH (p:Point {id: $point_id})
RETURN p
"""

with driver.session() as session:
data = session \
.run(query, {"point_id": point_id}) \
.run(query, {"point_id": point_id.hex}) \
.single()['p']

if not data:
Expand Down Expand Up @@ -266,36 +259,114 @@ def list_points() -> List[Point]:
return points


# Route Endpoints
class Route(BaseModel):
id: UUID4
point_ids: List[UUID4]
author_id: UUID4
popularity_score: float


class CreateRouteRequest(BaseModel):
point_ids: List[UUID4]
author: UUID4
popularity_score: float


@app.post("/routes/new", response_model=UUID4)
def create_route(req: CreateRouteRequest) -> UUID4:
route = Route(
id=uuid.uuid4(),
point_ids=req.point_ids,
author_id=req.author,
popularity_score=req.popularity_score,
)

query = """
MATCH (a:User {id: $author_id})
WITH a
WHERE a IS NOT NULL
UNWIND $point_ids AS pid
MATCH (p:Point {id: pid})
WITH a, collect(p) AS points
WHERE size(points) = size($point_ids)
CREATE (r:Route {id: $id, popularity_score: $popularity_score})
CREATE (r)-[:CREATED_BY]->(a)
FOREACH (point IN points | CREATE (r)-[:HAS_POINT]->(point))
RETURN r
"""

route_prepared = route.dict()
route_prepared['id'] = route.id.hex
route_prepared['author_id'] = route.author_id.hex
route_prepared['point_ids'] = \
[point_id.hex for point_id in route.point_ids]

data = None
with driver.session() as session:
data = session.run(query, route_prepared).single()

if data is None:
raise HTTPException(status_code=500, detail="Internal error")

return UUID4(data['r']['id'])


@app.get("/routes/{route_id}", response_model=Route)
def get_route(route_id: str):
def get_route(route_id: UUID4) -> Route:
query = """
MATCH (r:Route {id: $route_id})
OPTIONAL MATCH (r)-[:HAS_POINT]->(p:Point)
RETURN r, collect(p.id) as points
OPTIONAL MATCH (r)-[:CREATED_BY]->(u:User)
RETURN r, collect(p.id) as points, u.id as author_id
"""
result = run_query(query, {"route_id": route_id})
if result:
route = result[0]['r']
route['points'] = result[0]['points']
return Route(**route)
return {}

with driver.session() as session:
data = session.run(query, {"route_id": route_id.hex}).single()

if not data or data.get('r') is None:
raise HTTPException(status_code=404, detail="Not found")

route = data['r']
points = data['points']

return Route(
id=UUID4(route['id']),
author_id=UUID4(data['author_id']),
point_ids=[UUID4(point_id) for point_id in points['point_ids']],
popularity_score=float(route['popularity_score']),
)


@app.get("/routes", response_model=List[Route])
def list_routes():
def list_routes() -> List[Route]:
query = """
MATCH (r:Route)
OPTIONAL MATCH (r)-[:HAS_POINT]->(p:Point)
WITH r, collect(p.id) as points
RETURN r, points
OPTIONAL MATCH (r)-[:CREATED_BY]->(u:User)
WITH r, collect(p.id) as points, u.id as author_id
RETURN r, points, author_id
"""
result = run_query(query)
routes = []
for record in result:
route = record['r']
route['points'] = record['points']
routes.append(Route(**route))

routes: List[Route] = []
with driver.session() as session:
for record in session.run(query):
route = record['r']
points = record['points']
author = record['author_id']

routes.append(
Route(
id=UUID4(route['id']),
author_id=UUID4(author),
point_ids=[UUID4(point) for point in points],
popularity_score=float(route['popularity_score']),
))

if len(routes) == 0:
raise HTTPException(status_code=404, detail="No routes found")

return routes


Expand Down Expand Up @@ -361,31 +432,6 @@ def list_support_requests():
return support_requests


@app.post("/routes/new", response_model=Route)
def create_route(route: Route):
query = """
CREATE (r:Route {
id: $id,
author: $author,
popularity_score: $popularity_score
})
WITH r
UNWIND $points as point_id
MATCH (p:Point {id: point_id})
CREATE (r)-[:HAS_POINT]->(p)
RETURN r
"""
parameters = route.dict()
parameters['points'] = route.points
result = run_query(query, parameters)
if result:
route_data = result[0]['r']
route_data['points'] = route.points
return Route(**route_data)
else:
raise HTTPException(status_code=500, detail="Failed to create route")


@app.post("/lakes/new", response_model=Lake)
def create_lake(lake: Lake):
query = """
Expand Down

0 comments on commit a6d34f2

Please sign in to comment.