Skip to content

Commit

Permalink
Merge pull request #4 from StreetLamb/skills
Browse files Browse the repository at this point in the history
Enable members to use skills
  • Loading branch information
StreetLamb authored Apr 26, 2024
2 parents ddcd2a5 + 3ffb8ea commit 315eb45
Show file tree
Hide file tree
Showing 30 changed files with 668 additions and 252 deletions.
3 changes: 2 additions & 1 deletion backend/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter

from app.api.routes import items, login, members, teams, users, utils
from app.api.routes import items, login, members, skills, teams, users, utils

api_router = APIRouter()
api_router.include_router(login.router, tags=["login"])
Expand All @@ -11,3 +11,4 @@
api_router.include_router(
members.router, prefix="/teams/{team_id}/members", tags=["members"]
)
api_router.include_router(skills.router, prefix="/skills", tags=["skills"])
34 changes: 27 additions & 7 deletions backend/app/api/routes/members.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import func, select
from sqlmodel import col, func, select

from app.api.deps import CurrentUser, SessionDep
from app.models import (
Expand All @@ -11,18 +11,32 @@
MembersOut,
MemberUpdate,
Message,
Skill,
Team,
)

router = APIRouter()


def validate_unique_name_in_team(
session: SessionDep, team_id: int, id: int, member_in: MemberCreate | MemberUpdate
def check_duplicate_names_on_create(
session: SessionDep, team_id: int, member_in: MemberCreate
):
"""Check if (name, team_id) is unique"""
statement = select(Member).where(
Member.name == member_in.name,
Member.belongs_to == team_id,
)
member_unique = session.exec(statement).first()
if member_unique:
raise HTTPException(
status_code=400, detail="Member with this name already exists"
)


def check_duplicate_names_on_update(
session: SessionDep, team_id: int, member_in: MemberUpdate, id: int
):
"""Check if (name, team_id) is unique"""
if member_in.name is None:
return
statement = select(Member).where(
Member.name == member_in.name,
Member.belongs_to == team_id,
Expand Down Expand Up @@ -112,7 +126,7 @@ def create_member(
current_user: CurrentUser,
team_id: int,
member_in: MemberCreate,
_: bool = Depends(validate_unique_name_in_team),
_: bool = Depends(check_duplicate_names_on_create),
) -> Any:
"""
Create new member.
Expand All @@ -136,7 +150,7 @@ def update_member(
team_id: int,
id: int,
member_in: MemberUpdate,
_: bool = Depends(validate_unique_name_in_team),
_: bool = Depends(check_duplicate_names_on_update),
) -> Any:
"""
Update a member.
Expand All @@ -163,6 +177,12 @@ def update_member(
if not member:
raise HTTPException(status_code=404, detail="Member not found")

# update member's skills if required
if member_in.skills is not None:
skill_ids = [skill.id for skill in member_in.skills]
skills = session.exec(select(Skill).where(col(Skill.id).in_(skill_ids))).all()
member.skills = skills

update_dict = member_in.model_dump(exclude_unset=True)
member.sqlmodel_update(update_dict)
session.add(member)
Expand Down
38 changes: 38 additions & 0 deletions backend/app/api/routes/skills.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Any

from fastapi import APIRouter, HTTPException
from sqlmodel import func, select

from app.api.deps import SessionDep
from app.models import (
Skill,
SkillOut,
SkillsOut,
)

router = APIRouter()


@router.get("/", response_model=SkillsOut)
def read_skills(session: SessionDep, skip: int = 0, limit: int = 100) -> Any:
"""
Retrieve skills.
"""
count_statement = select(func.count()).select_from(Skill)
count = session.exec(count_statement).one()

statement = select(Skill).offset(skip).limit(limit)
skills = session.exec(statement).all()

return SkillsOut(data=skills, count=count)


@router.get("/{id}", response_model=SkillOut)
def read_skill(session: SessionDep, id: int) -> Any:
"""
Get skill by ID.
"""
skill = session.get(Skill, id)
if not skill:
raise HTTPException(status_code=404, detail="Skill not found")
return skill
27 changes: 21 additions & 6 deletions backend/app/api/routes/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,24 @@
router = APIRouter()


async def validate_unique_name(session: SessionDep, team_in: TeamCreate | TeamUpdate):
async def check_duplicate_name_on_create(session: SessionDep, team_in: TeamCreate):
"""Validate that team name is unique"""
if team_in.name is None:
return
statement = select(Team).where(Team.name == team_in.name)
team = session.exec(statement).first()
if team:
raise HTTPException(status_code=400, detail="Team name already exists")


async def check_duplicate_name_on_update(
session: SessionDep, team_in: TeamUpdate, id: int
):
"""Validate that team name is unique"""
statement = select(Team).where(Team.name == team_in.name, Team.id != id)
team = session.exec(statement).first()
if team:
raise HTTPException(status_code=400, detail="Team name already exists")


@router.get("/", response_model=TeamsOut)
def read_teams(
session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100
Expand Down Expand Up @@ -121,7 +129,7 @@ def create_team(
session: SessionDep,
current_user: CurrentUser,
team_in: TeamCreate,
_: bool = Depends(validate_unique_name),
_: bool = Depends(check_duplicate_name_on_create),
) -> Any:
"""
Create new team and it's team leader
Expand Down Expand Up @@ -155,7 +163,7 @@ def update_team(
current_user: CurrentUser,
id: int,
team_in: TeamUpdate,
_: bool = Depends(validate_unique_name),
_: bool = Depends(check_duplicate_name_on_update),
) -> Any:
"""
Update a team.
Expand Down Expand Up @@ -195,12 +203,19 @@ async def stream(
"""
Stream a response to a user's input.
"""
# Get team and join members and skills
team = session.get(Team, id)
if not team:
raise HTTPException(status_code=404, detail="Team not found")
if not current_user.is_superuser and (team.owner_id != current_user.id):
raise HTTPException(status_code=400, detail="Not enough permissions")

# Populate the skills for each member
members = team.members
for member in members:
member.skills = member.skills

return StreamingResponse(
generator(team, team.members, team_chat.messages),
generator(team, members, team_chat.messages),
media_type="text/event-stream",
)
19 changes: 18 additions & 1 deletion backend/app/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from app import crud
from app.core.config import settings
from app.models import User, UserCreate
from app.core.graph.skills import all_skills
from app.models import Skill, User, UserCreate

engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))

Expand Down Expand Up @@ -32,3 +33,19 @@ def init_db(session: Session) -> None:
is_superuser=True,
)
user = crud.create_user(session=session, user_create=user_in)

existing_skills = session.exec(select(Skill)).all()
existing_skills_dict = {skill.name: skill for skill in existing_skills}

for skill_name, skill_info in all_skills.items():
if skill_name in existing_skills_dict:
existing_skill = existing_skills_dict[skill_name]
if existing_skill.description != skill_info.description:
# Update the existing skill's description
existing_skill.description = skill_info.description
session.add(existing_skill) # Mark the modified object for saving
else:
new_skill = Skill(name=skill_name, description=skill_info.description)
session.add(new_skill) # Prepare new skill for addition to the database

session.commit()
3 changes: 2 additions & 1 deletion backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def convert_team_to_dict(team: Team, members: list[MemberModel]):
"name": member_name,
"backstory": member.backstory or "",
"role": member.role,
"tools": [],
"tools": [skill.name for skill in member.skills],
}

for nei_id in out_counts[member_id]:
Expand Down Expand Up @@ -173,6 +173,7 @@ async def generator(team: Team, members: list[Member], messages: list[ChatMessag
for message in messages
]

# TODO: Figure out how to use async_stream to stream responses from subgraphs
async for output in root.astream(
{
"messages": messages,
Expand Down
6 changes: 3 additions & 3 deletions backend/app/core/graph/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

from app.core.graph.tools import all_tools
from app.core.graph.skills import all_skills


class Person(BaseModel):
Expand Down Expand Up @@ -95,10 +95,10 @@ def create_agent(
self, llm: ChatOpenAI, prompt: ChatPromptTemplate, tools: list[str]
):
"""Create the agent executor"""
tools = [all_tools[tool] for tool in tools]
tools = [all_skills[tool].tool for tool in tools]
# Tools cannot be empty, add a placeholder
if len(tools) < 1:
tools = [all_tools["nothing"]]
tools = [all_skills["nothing"].tool]
agent = create_openai_functions_agent(llm, tools, prompt)
executor = AgentExecutor(agent=agent, tools=tools)
return executor
Expand Down
29 changes: 29 additions & 0 deletions backend/app/core/graph/skills.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from collections.abc import Callable

from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.tools import tool
from pydantic import BaseModel


class SkillInfo(BaseModel):
description: str
tool: Callable


@tool
def nothing(query: str) -> str:
"""Placeholder Tool. Does nothing"""
return ""


all_skills: dict[str, SkillInfo] = {
"nothing": SkillInfo(description="Does nothing", tool=nothing),
"search": SkillInfo(
description="Searches the web using Duck Duck Go", tool=DuckDuckGoSearchRun()
),
"wikipedia": SkillInfo(
description="Searches Wikipedia",
tool=WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
),
}
11 changes: 0 additions & 11 deletions backend/app/core/graph/tools.py

This file was deleted.

21 changes: 12 additions & 9 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class MemberUpdate(MemberBase):
belongs_to: int | None = None
position_x: float | None = None
position_y: float | None = None
skills: list["Skill"] | None = None


class Member(MemberBase, table=True):
Expand All @@ -223,6 +224,7 @@ class MemberOut(MemberBase):
id: int
belongs_to: int
owner_of: int | None
skills: list["Skill"]


class MembersOut(SQLModel):
Expand All @@ -238,17 +240,18 @@ class SkillBase(SQLModel):
description: str | None = None


class SkillCreate(SkillBase):
name: str


class SkillUpdate(SkillBase):
name: str | None = None
description: str | None = None


class Skill(SkillBase, table=True):
id: int | None = Field(default=None, primary_key=True)
members: list["Member"] = Relationship(
back_populates="skills", link_model=MemberSkillsLink
)


class SkillsOut(SQLModel):
data: list[Skill]
count: int


class SkillOut(SkillBase):
id: int
description: str | None
Loading

0 comments on commit 315eb45

Please sign in to comment.