Skip to content

Commit

Permalink
Remove duplicated code
Browse files Browse the repository at this point in the history
  • Loading branch information
douglaslassance committed Oct 10, 2024
1 parent f8ae761 commit 52877b1
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 108 deletions.
37 changes: 7 additions & 30 deletions gitalong/batch.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,14 @@
import os
import asyncio

from typing import Coroutine, Optional, List
from typing import List, Coroutine

import git
import git.exc

from git import Repo

from .enums import CommitSpread
from .repository import Repository
from .exceptions import RepositoryNotSetup
from .functions import pulled_within, set_read_only


def get_repository_safe(filename: str) -> Optional[Repository]:
    """Look up the repository owning a path without raising.

    Args:
        filename (str): A path that belongs to the repository, including the
            repository root itself.

    Returns:
        Optional[Repository]: The repository, or None when constructing it raises
            InvalidGitRepositoryError or RepositoryNotSetup.
    """
    try:
        repository = Repository(repository=filename, use_cached_instances=True)
    except (git.exc.InvalidGitRepositoryError, RepositoryNotSetup):
        # Both failures mean there is no usable repository for this path.
        repository = None
    return repository
from .functions import set_read_only


async def get_files_last_commits(
Expand All @@ -47,7 +26,7 @@ async def get_files_last_commits(
for filename in filenames:
last_commit = {}

repository = get_repository_safe(filename)
repository = Repository.from_filename(filename)
if not repository:
last_commits.append(last_commit)
continue
Expand Down Expand Up @@ -92,19 +71,17 @@ async def get_files_last_commits(

if not last_commit:
pull_threshold = repository.config.get("pull_threshold", 60)
repo = Repo(repository.working_dir)
if not pulled_within(repo, pull_threshold):
if not repository.pulled_within(pull_threshold):
try:
git.Repo(repository.working_dir).remotes.origin.fetch(prune=prune)
repository.remote.fetch(prune=prune)
except git.exc.GitCommandError:
pass

args = ["--all", "--remotes", '--pretty=format:"%H"', "--", filename]
output = repo.git.log(*args)
output = repository.git.log(*args)
file_commits = output.replace('"', "").split("\n") if output else []
sha = file_commits[0] if file_commits else ""
last_commit = repo.commit(sha) or {}
last_commit = repository.get_commit(sha) or {}

last_commits.append(last_commit)

Expand Down Expand Up @@ -257,7 +234,7 @@ async def claim_files(
for filename in filenames:
last_commits = await get_files_last_commits([filename], prune=prune)
last_commit = last_commits[0]
repository = get_repository_safe(filename)
repository = Repository.from_filename(filename)
config = repository.config if repository else {}
modify_permissions = config.get("modify_permissions")
spread = repository.get_commit_spread(last_commit) if repository else 0
Expand Down
46 changes: 14 additions & 32 deletions gitalong/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,8 @@
from .__info__ import __version__

from .enums import CommitSpread
from .exceptions import RepositoryNotSetup
from .repository import Repository
from .batch import get_files_last_commits, claim_files, get_repository_safe


def get_repository(filename: str) -> Repository:
    """Return the repository owning a path, or abort the CLI command.

    Args:
        filename (str): A path that belongs to the repository, including the
            repository root itself.

    Returns:
        Repository: The repository.

    Raises:
        click.Abort: Raised after echoing a "fatal: ..." message when constructing
            the repository raises InvalidGitRepositoryError or RepositoryNotSetup.
            The original exception is kept as the cause.
    """
    try:
        # Initializing Gitalong for each file allows handling files from multiple
        # repositories. This is especially important to support submodules.
        return Repository(repository=filename, use_cached_instances=True)
    except git.exc.InvalidGitRepositoryError as error:
        click.echo("fatal: not a git repository")
        # Chain the cause so the underlying failure stays visible when debugging.
        raise click.Abort() from error
    except RepositoryNotSetup as error:
        click.echo("fatal: not a gitalong repository")
        raise click.Abort() from error
from .batch import get_files_last_commits, claim_files


def get_status_string(filename: str, commit: dict, spread: int) -> str:
Expand Down Expand Up @@ -80,8 +59,8 @@ def version(): # pylint: disable=missing-function-docstring
)
@click.pass_context
def config(ctx, prop): # pylint: disable=missing-function-docstring
repository = get_repository(ctx.obj.get("REPOSITORY", ""))
repository_config = repository.config
repository = Repository.from_filename(ctx.obj.get("REPOSITORY", ""))
repository_config = repository.config if repository else {}
prop = prop.replace("-", "_")
if prop in repository_config:
value = repository_config[prop]
Expand All @@ -103,19 +82,22 @@ def config(ctx, prop): # pylint: disable=missing-function-docstring
@click.pass_context
def update(ctx, repository):
"""TODO: Improve error handling."""
repository = get_repository(ctx.obj.get("REPOSITORY", ""))
root = repository.working_dir if repository else ""
repository = Repository.from_filename(ctx.obj.get("REPOSITORY", ""))
if not repository:
return
working_dir = repository.working_dir
repository.update_tracked_commits()
locally_changed = {}
permission_changes = []
if repository.config.get("modify_permissions"):
# TODO: This is a very expensive operation and needs to be optimized.
# TODO: This is an expensive operation and needs to be optimized.
# Also probably should not be done here at the CLI level.
for filename in repository.files:
if os.path.isfile(repository.get_absolute_path(filename)):
if root not in locally_changed:
locally_changed[root] = repository.locally_changed_files
if working_dir not in locally_changed:
locally_changed[working_dir] = repository.locally_changed_files
perm_change = repository.update_file_permissions(
filename, locally_changed[root]
filename, locally_changed[working_dir]
)
if perm_change:
permission_changes.append(f"{' '.join(perm_change)}")
Expand Down Expand Up @@ -157,7 +139,7 @@ def run_status(ctx, filename): # pylint: disable=missing-function-docstring
file_status = []
commits = asyncio.run(get_files_last_commits(filename))
for _filename, commit in zip(filename, commits):
repository = get_repository_safe(ctx.obj.get("REPOSITORY", _filename))
repository = Repository.from_filename(ctx.obj.get("REPOSITORY", _filename))
absolute_filename = (
repository.get_absolute_path(_filename) if repository else _filename
)
Expand All @@ -182,7 +164,7 @@ def claim(ctx, filename): # pylint: disable=missing-function-docstring
statuses = []
blocking_commits = asyncio.run(claim_files(filename))
for _filename, commit in zip(filename, blocking_commits):
repository = get_repository_safe(ctx.obj.get("REPOSITORY", _filename))
repository = Repository.from_filename(ctx.obj.get("REPOSITORY", _filename))
absolute_filename = (
repository.get_absolute_path(_filename) if repository else _filename
)
Expand Down
1 change: 1 addition & 0 deletions gitalong/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from git.repo import Repo


MOVE_STRING_REGEX = re.compile("{(.*)}")


Expand Down
100 changes: 56 additions & 44 deletions gitalong/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import shutil
import socket
import asyncio

from typing import Optional, List

Expand All @@ -15,6 +16,9 @@

from git.repo import Repo

# Deliberately import the module to avoid circular imports.
from . import batch

from .store import Store
from .enums import CommitSpread
from .stores.git_store import GitStore
Expand All @@ -24,7 +28,7 @@
get_real_path,
is_binary_file,
set_read_only,
get_filenames_from_move_string,
pulled_within,
)


Expand Down Expand Up @@ -157,6 +161,22 @@ def setup(
gitalong.install_hooks()
return gitalong

@classmethod
def from_filename(cls, filename: str) -> Optional["Repository"]:
    """Build the repository owning a path without raising.

    Args:
        filename (str): A path that belongs to the repository, including the
            repository root itself.

    Returns:
        Optional[Repository]: The repository, or None when constructing it raises
            InvalidGitRepositoryError or RepositoryNotSetup.
    """
    try:
        instance = cls(repository=filename, use_cached_instances=True)
    except (git.exc.InvalidGitRepositoryError, RepositoryNotSetup):
        # Both failures mean there is no usable repository for this path.
        instance = None
    return instance

@staticmethod
def _write_config_file(config: dict, path: str):
with open(path, "w", encoding="utf8") as config_file:
Expand Down Expand Up @@ -410,17 +430,18 @@ def _accumulate_local_only_commits(self, start: git.Commit, local_commits: list)
start (git.objects.Commit):
The commit that we start peeling from last commit.
"""
# TODO: Maybe there is a way to get this information using pure Python.
if self._managed_repository.git.branch("--remotes", "--contains", start.hexsha):
return
# TODO: This is an expensive thing to call individually.
commit_dict = self.get_commit_dict(start)
commit_dict.update(self.context_dict)
# TODO: This is an expensive thing to call individually.
commit_dict["branches"] = {"local": self.get_commit_branches(start.hexsha)}
# TODO: Maybe we should compare the SHA here.
if commit_dict not in local_commits:
local_commits.append(commit_dict)
# TODO: These calls to batch functions are expensive for a single file.
commits = asyncio.run(batch.get_commits_dicts([start]))
commit = commits[0] if commits else {}
commit.update(self.context_dict)
branches_list = asyncio.run(batch.get_commits_branches([commit]))
branches = branches_list[0] if branches_list else []
commit["branches"] = {"local": branches}
# Maybe we should compare the SHA here.
if commit not in local_commits:
local_commits.append(commit)
for parent in start.parents:
self._accumulate_local_only_commits(parent, local_commits)

Expand Down Expand Up @@ -469,7 +490,6 @@ def _uncommitted_changes(self) -> list:
Returns:
list: A list of unique relative filenames that feature uncommitted changes.
"""
# TODO: Maybe there is a way to get this information using pure Python.
git_cmd = self._managed_repository.git
output = git_cmd.ls_files("--exclude-standard", "--others")
untracked_changes = output.split("\n") if output else []
Expand Down Expand Up @@ -505,8 +525,7 @@ def files(self) -> list:
"""
git_cmd = self._managed_repository.git
try:
# TODO: HEAD might not be safe here since user could checkout an earlier
# commit.
# TODO: HEAD might not be safe. The user could checkout an earlier commit.
filenames = git_cmd.ls_tree(full_tree=True, name_only=True, r="HEAD")
return filenames.split("\n")
except git.exc.GitCommandError:
Expand Down Expand Up @@ -604,45 +623,38 @@ def _get_updated_tracked_commits(self, claims: Optional[List[str]] = None) -> li
tracked_commits.append(commit)
return tracked_commits

def get_commit_dict(self, commit: git.Commit) -> dict:
def pulled_within(self, seconds: float) -> bool:
"""
Args:
commit (git.objects.Commit): The commit to get as a dict.
seconds (float): Time in seconds since the last pull.
Returns:
dict: A simplified JSON serializable dict that represents the commit.
bool: Whether the repository pulled within the time provided.
"""
changes = []
for change in list(commit.stats.files.keys()):
changes += get_filenames_from_move_string(str(change))
return {
"sha": commit.hexsha,
"remote": self._remote.url,
"changes": changes,
"date": str(commit.committed_datetime),
"author": commit.author.name,
}
return pulled_within(self._managed_repository, seconds)

def get_commit_branches(self, sha: str, remote: bool = False) -> list:
def log(self, message: str):
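"""Runs `git log` on the managed repository with the provided argument.
NOTE(review): the command output is discarded here - confirm this is intended.
Args:
message (str): The argument passed to `git log`.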
"""
self._managed_repository.git.log(message)

def get_commit(self, sha: str) -> git.Commit:
"""
Args:
sha (str): The sha of the commit to check for.
remote (bool, optional): Whether we should return local or remote branches.
sha (str): The SHA of the commit to get.
Returns:
list: A list of branch names that this commit is living on.
git.Commit: The commit object for the provided SHA.
"""
args = ["--remote" if remote else []]
args += ["--contains", sha]
try:
branches = self._managed_repository.git.branch(*args)
# If the commit is not on any branch we get a git.exc.GitCommandError.
except git.exc.GitCommandError:
return []
branches = branches.replace("*", "")
branches = branches.replace(" ", "")
branches = branches.split("\n") if branches else []
branch_names = set()
for branch in branches:
branch_names.add(branch.split("/")[-1])
return list(branch_names)
return self._managed_repository.commit(sha)

@property
def git(self) -> git.Git:
    """Expose the Git command line interface of the managed repository.

    Returns:
        git.cmd.Git: The `git` attribute of the managed repository.
    """
    managed_repository = self._managed_repository
    return managed_repository.git
4 changes: 2 additions & 2 deletions gitalong/stores/git_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def commits(self) -> typing.List[dict]:
remote = store_repository.remote()
pull_threshold = self._managed_repository.config.get("pull_threshold", 60)
if not pulled_within(store_repository, pull_threshold) and remote.refs:
# TODO: If we could check that a pull is already happening then we could
# avoid this try except and save time.
# TODO: We could check whether a pull is already happening, thus avoiding this
# try/except and saving time.
try:
remote.pull(
ff=True,
Expand Down

0 comments on commit 52877b1

Please sign in to comment.