Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable members to use skills #4

Merged
merged 5 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter

from app.api.routes import items, login, members, teams, users, utils
from app.api.routes import items, login, members, skills, teams, users, utils

api_router = APIRouter()
api_router.include_router(login.router, tags=["login"])
Expand All @@ -11,3 +11,4 @@
api_router.include_router(
members.router, prefix="/teams/{team_id}/members", tags=["members"]
)
api_router.include_router(skills.router, prefix="/skills", tags=["skills"])
34 changes: 27 additions & 7 deletions backend/app/api/routes/members.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import func, select
from sqlmodel import col, func, select

from app.api.deps import CurrentUser, SessionDep
from app.models import (
Expand All @@ -11,18 +11,32 @@
MembersOut,
MemberUpdate,
Message,
Skill,
Team,
)

router = APIRouter()


def validate_unique_name_in_team(
session: SessionDep, team_id: int, id: int, member_in: MemberCreate | MemberUpdate
def check_duplicate_names_on_create(
session: SessionDep, team_id: int, member_in: MemberCreate
):
"""Check if (name, team_id) is unique"""
statement = select(Member).where(
Member.name == member_in.name,
Member.belongs_to == team_id,
)
member_unique = session.exec(statement).first()
if member_unique:
raise HTTPException(
status_code=400, detail="Member with this name already exists"
)


def check_duplicate_names_on_update(
session: SessionDep, team_id: int, member_in: MemberUpdate, id: int
):
"""Check if (name, team_id) is unique"""
if member_in.name is None:
return
statement = select(Member).where(
Member.name == member_in.name,
Member.belongs_to == team_id,
Expand Down Expand Up @@ -112,7 +126,7 @@ def create_member(
current_user: CurrentUser,
team_id: int,
member_in: MemberCreate,
_: bool = Depends(validate_unique_name_in_team),
_: bool = Depends(check_duplicate_names_on_create),
) -> Any:
"""
Create new member.
Expand All @@ -136,7 +150,7 @@ def update_member(
team_id: int,
id: int,
member_in: MemberUpdate,
_: bool = Depends(validate_unique_name_in_team),
_: bool = Depends(check_duplicate_names_on_update),
) -> Any:
"""
Update a member.
Expand All @@ -163,6 +177,12 @@ def update_member(
if not member:
raise HTTPException(status_code=404, detail="Member not found")

# update member's skills if required
if member_in.skills is not None:
skill_ids = [skill.id for skill in member_in.skills]
skills = session.exec(select(Skill).where(col(Skill.id).in_(skill_ids))).all()
member.skills = skills

update_dict = member_in.model_dump(exclude_unset=True)
member.sqlmodel_update(update_dict)
session.add(member)
Expand Down
38 changes: 38 additions & 0 deletions backend/app/api/routes/skills.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Any

from fastapi import APIRouter, HTTPException
from sqlmodel import func, select

from app.api.deps import SessionDep
from app.models import (
Skill,
SkillOut,
SkillsOut,
)

router = APIRouter()


@router.get("/", response_model=SkillsOut)
def read_skills(session: SessionDep, skip: int = 0, limit: int = 100) -> Any:
"""
Retrieve skills.
"""
count_statement = select(func.count()).select_from(Skill)
count = session.exec(count_statement).one()

statement = select(Skill).offset(skip).limit(limit)
skills = session.exec(statement).all()

return SkillsOut(data=skills, count=count)


@router.get("/{id}", response_model=SkillOut)
def read_skill(session: SessionDep, id: int) -> Any:
"""
Get skill by ID.
"""
skill = session.get(Skill, id)
if not skill:
raise HTTPException(status_code=404, detail="Skill not found")
return skill
27 changes: 21 additions & 6 deletions backend/app/api/routes/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,24 @@
router = APIRouter()


async def validate_unique_name(session: SessionDep, team_in: TeamCreate | TeamUpdate):
async def check_duplicate_name_on_create(session: SessionDep, team_in: TeamCreate):
"""Validate that team name is unique"""
if team_in.name is None:
return
statement = select(Team).where(Team.name == team_in.name)
team = session.exec(statement).first()
if team:
raise HTTPException(status_code=400, detail="Team name already exists")


async def check_duplicate_name_on_update(
session: SessionDep, team_in: TeamUpdate, id: int
):
"""Validate that team name is unique"""
statement = select(Team).where(Team.name == team_in.name, Team.id != id)
team = session.exec(statement).first()
if team:
raise HTTPException(status_code=400, detail="Team name already exists")


@router.get("/", response_model=TeamsOut)
def read_teams(
session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100
Expand Down Expand Up @@ -121,7 +129,7 @@ def create_team(
session: SessionDep,
current_user: CurrentUser,
team_in: TeamCreate,
_: bool = Depends(validate_unique_name),
_: bool = Depends(check_duplicate_name_on_create),
) -> Any:
"""
Create new team and it's team leader
Expand Down Expand Up @@ -155,7 +163,7 @@ def update_team(
current_user: CurrentUser,
id: int,
team_in: TeamUpdate,
_: bool = Depends(validate_unique_name),
_: bool = Depends(check_duplicate_name_on_update),
) -> Any:
"""
Update a team.
Expand Down Expand Up @@ -195,12 +203,19 @@ async def stream(
"""
Stream a response to a user's input.
"""
# Get team and join members and skills
team = session.get(Team, id)
if not team:
raise HTTPException(status_code=404, detail="Team not found")
if not current_user.is_superuser and (team.owner_id != current_user.id):
raise HTTPException(status_code=400, detail="Not enough permissions")

# Populate the skills for each member
members = team.members
for member in members:
member.skills = member.skills

return StreamingResponse(
generator(team, team.members, team_chat.messages),
generator(team, members, team_chat.messages),
media_type="text/event-stream",
)
19 changes: 18 additions & 1 deletion backend/app/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from app import crud
from app.core.config import settings
from app.models import User, UserCreate
from app.core.graph.skills import all_skills
from app.models import Skill, User, UserCreate

engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))

Expand Down Expand Up @@ -32,3 +33,19 @@ def init_db(session: Session) -> None:
is_superuser=True,
)
user = crud.create_user(session=session, user_create=user_in)

existing_skills = session.exec(select(Skill)).all()
existing_skills_dict = {skill.name: skill for skill in existing_skills}

for skill_name, skill_info in all_skills.items():
if skill_name in existing_skills_dict:
existing_skill = existing_skills_dict[skill_name]
if existing_skill.description != skill_info.description:
# Update the existing skill's description
existing_skill.description = skill_info.description
session.add(existing_skill) # Mark the modified object for saving
else:
new_skill = Skill(name=skill_name, description=skill_info.description)
session.add(new_skill) # Prepare new skill for addition to the database

session.commit()
3 changes: 2 additions & 1 deletion backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def convert_team_to_dict(team: Team, members: list[MemberModel]):
"name": member_name,
"backstory": member.backstory or "",
"role": member.role,
"tools": [],
"tools": [skill.name for skill in member.skills],
}

for nei_id in out_counts[member_id]:
Expand Down Expand Up @@ -173,6 +173,7 @@ async def generator(team: Team, members: list[Member], messages: list[ChatMessag
for message in messages
]

# TODO: Figure out how to use async_stream to stream responses from subgraphs
async for output in root.astream(
{
"messages": messages,
Expand Down
6 changes: 3 additions & 3 deletions backend/app/core/graph/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

from app.core.graph.tools import all_tools
from app.core.graph.skills import all_skills


class Person(BaseModel):
Expand Down Expand Up @@ -95,10 +95,10 @@ def create_agent(
self, llm: ChatOpenAI, prompt: ChatPromptTemplate, tools: list[str]
):
"""Create the agent executor"""
tools = [all_tools[tool] for tool in tools]
tools = [all_skills[tool].tool for tool in tools]
# Tools cannot be empty, add a placeholder
if len(tools) < 1:
tools = [all_tools["nothing"]]
tools = [all_skills["nothing"].tool]
agent = create_openai_functions_agent(llm, tools, prompt)
executor = AgentExecutor(agent=agent, tools=tools)
return executor
Expand Down
29 changes: 29 additions & 0 deletions backend/app/core/graph/skills.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from collections.abc import Callable

from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.tools import tool
from pydantic import BaseModel


class SkillInfo(BaseModel):
description: str
tool: Callable


@tool
def nothing(query: str) -> str:
"""Placeholder Tool. Does nothing"""
return ""


all_skills: dict[str, SkillInfo] = {
"nothing": SkillInfo(description="Does nothing", tool=nothing),
"search": SkillInfo(
description="Searches the web using Duck Duck Go", tool=DuckDuckGoSearchRun()
),
"wikipedia": SkillInfo(
description="Searches Wikipedia",
tool=WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
),
}
11 changes: 0 additions & 11 deletions backend/app/core/graph/tools.py

This file was deleted.

21 changes: 12 additions & 9 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class MemberUpdate(MemberBase):
belongs_to: int | None = None
position_x: float | None = None
position_y: float | None = None
skills: list["Skill"] | None = None


class Member(MemberBase, table=True):
Expand All @@ -223,6 +224,7 @@ class MemberOut(MemberBase):
id: int
belongs_to: int
owner_of: int | None
skills: list["Skill"]


class MembersOut(SQLModel):
Expand All @@ -238,17 +240,18 @@ class SkillBase(SQLModel):
description: str | None = None


class SkillCreate(SkillBase):
name: str


class SkillUpdate(SkillBase):
name: str | None = None
description: str | None = None


class Skill(SkillBase, table=True):
id: int | None = Field(default=None, primary_key=True)
members: list["Member"] = Relationship(
back_populates="skills", link_model=MemberSkillsLink
)


class SkillsOut(SQLModel):
data: list[Skill]
count: int


class SkillOut(SkillBase):
id: int
description: str | None
Loading
Loading