Skip to content

Commit

Permalink
Merge branch 'main' into duckduckgo
Browse files Browse the repository at this point in the history
  • Loading branch information
maheshmurag authored Nov 20, 2024
2 parents 4f25b1e + 56d2d1b commit a9c0baf
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 78 deletions.
2 changes: 1 addition & 1 deletion src/git/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ classifiers = [
dependencies = [
"click>=8.1.7",
"gitpython>=3.1.43",
"mcp-python~=0.6.0",
"mcp>=0.6.0",
"pydantic>=2.0.0",
]

Expand Down
75 changes: 46 additions & 29 deletions src/git/src/mcp_git/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
import anyio.lowlevel
from pathlib import Path
from git.types import Sequence
from mcp_python.server import Server
from mcp_python.server.stdio import stdio_server
from mcp_python.types import Tool
from mcp_python.server.types import EmbeddedResource, ImageContent
from mcp.server import Server
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool, EmbeddedResource, ImageContent, ListRootsResult
from enum import StrEnum
import git
from git.objects import Blob, Tree
from mcp_python import ServerSession

from pydantic import BaseModel, Field
from typing import List, Optional
Expand Down Expand Up @@ -167,7 +166,7 @@ async def serve(repository: Path | None) -> None:
return

# Create server
server = Server("git-mcp")
server = Server("mcp-git")

@server.list_tools()
async def list_tools() -> list[Tool]:
Expand Down Expand Up @@ -244,7 +243,7 @@ async def by_roots() -> Sequence[str]:
"server.request_context.session must be a ServerSession"
)

roots_result = await server.request_context.session.list_roots()
roots_result: ListRootsResult = await server.request_context.session.list_roots()
logger.debug(f"Roots result: {roots_result}")
repo_paths = []
for root in roots_result.roots:
Expand All @@ -267,66 +266,84 @@ def by_commandline() -> Sequence[str]:
@server.call_tool()
async def call_tool(
name: str, arguments: dict
) -> Sequence[str | ImageContent | EmbeddedResource]:
) -> list[TextContent | ImageContent | EmbeddedResource]:
if name == GitTools.LIST_REPOS:
return await list_repos()
result = await list_repos()
return [TextContent(type="text", text=str(r)) for r in result]

repo_path = Path(arguments["repo_path"])
repo = git.Repo(repo_path)

match name:
case GitTools.READ_FILE:
return [
git_read_file(
repo, arguments["file_path"], arguments.get("ref", "HEAD")
TextContent(
type="text",
text=git_read_file(
repo, arguments["file_path"], arguments.get("ref", "HEAD")
)
)
]

case GitTools.LIST_FILES:
return [
str(f)
TextContent(type="text", text=str(f))
for f in git_list_files(
repo, arguments.get("path", ""), arguments.get("ref", "HEAD")
)
]

case GitTools.FILE_HISTORY:
return git_file_history(
repo, arguments["file_path"], arguments.get("max_entries", 10)
)
return [
TextContent(type="text", text=entry)
for entry in git_file_history(
repo, arguments["file_path"], arguments.get("max_entries", 10)
)
]

case GitTools.COMMIT:
result = git_commit(repo, arguments["message"], arguments.get("files"))
return [result]
return [TextContent(type="text", text=result)]

case GitTools.SEARCH_CODE:
return git_search_code(
repo,
arguments["query"],
arguments.get("file_pattern", "*"),
arguments.get("ref", "HEAD"),
)
return [
TextContent(type="text", text=result)
for result in git_search_code(
repo,
arguments["query"],
arguments.get("file_pattern", "*"),
arguments.get("ref", "HEAD"),
)
]

case GitTools.GET_DIFF:
return [
git_get_diff(
repo,
arguments["ref1"],
arguments["ref2"],
arguments.get("file_path"),
TextContent(
type="text",
text=git_get_diff(
repo,
arguments["ref1"],
arguments["ref2"],
arguments.get("file_path"),
)
)
]

case GitTools.GET_REPO_STRUCTURE:
return [git_get_repo_structure(repo, arguments.get("ref", "HEAD"))]
return [
TextContent(
type="text",
text=git_get_repo_structure(repo, arguments.get("ref", "HEAD"))
)
]

case _:
raise ValueError(f"Unknown tool: {name}")

# Run the server
options = server.create_initialization_options()
async with stdio_server() as (read_stream, write_stream):
await server.run(read_stream, write_stream, options)
await server.run(read_stream, write_stream, options, raise_exceptions=True)


@click.command()
Expand Down
Loading

0 comments on commit a9c0baf

Please sign in to comment.