Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handling of branch names with slashes #131

Merged
merged 7 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions src/gitingest/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from config import TMP_BASE_PATH
from gitingest.exceptions import InvalidPatternError
from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS
from gitingest.repository_clone import _check_repo_exists
from gitingest.repository_clone import _check_repo_exists, fetch_remote_branch_list

HEX_DIGITS: set[str] = set(string.hexdigits)

Expand Down Expand Up @@ -168,18 +168,47 @@ async def _parse_repo_source(source: str) -> dict[str, Any]:
parsed["type"] = possible_type

# Commit or branch
commit_or_branch = remaining_parts.pop(0)
commit_or_branch = remaining_parts[0]
if _is_valid_git_commit_hash(commit_or_branch):
parsed["commit"] = commit_or_branch
else:
parsed["branch"] = commit_or_branch
parsed["subpath"] += "/".join(remaining_parts[1:])

gowthamkishore3799 marked this conversation as resolved.
Show resolved Hide resolved
# Subpath if anything left
if remaining_parts:
else:
parsed["branch"] = await _configure_branch_and_subpath(remaining_parts, url)
parsed["subpath"] += "/".join(remaining_parts)

return parsed

async def _configure_branch_and_subpath(remaining_parts: list[str],url: str) -> str | None:
"""
Find the branch name from the remaining parts of the URL path.
Parameters
----------
remaining_parts : list[str]
List of path parts extracted from the URL.
url : str
The repository URL to determine branches.

Returns
-------
str (branch name) or None

gowthamkishore3799 marked this conversation as resolved.
Show resolved Hide resolved
"""
try:
# Fetch the list of branches from the remote repository
branches: list[str] = await fetch_remote_branch_list(url)
except Exception as e:
gowthamkishore3799 marked this conversation as resolved.
Show resolved Hide resolved
print(f"Warning: Failed to fetch branch list: {str(e)}")
gowthamkishore3799 marked this conversation as resolved.
Show resolved Hide resolved
return remaining_parts.pop(0) if remaining_parts else None
gowthamkishore3799 marked this conversation as resolved.
Show resolved Hide resolved

branch = []

while remaining_parts:
branch.append(remaining_parts.pop(0))
branch_name = "/".join(branch)
if branch_name in branches:
return branch_name

return None

def _is_valid_git_commit_hash(commit: str) -> bool:
"""
Expand Down
30 changes: 28 additions & 2 deletions src/gitingest/repository_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from gitingest.utils import async_timeout

CLONE_TIMEOUT: int = 20
TIMEOUT: int = 20


@dataclass
Expand Down Expand Up @@ -34,7 +34,7 @@ class CloneConfig:
branch: str | None = None


@async_timeout(CLONE_TIMEOUT)
@async_timeout(TIMEOUT)
async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]:
"""
Clone a repository to a local path based on the provided configuration.
Expand Down Expand Up @@ -141,6 +141,32 @@ async def _check_repo_exists(url: str) -> bool:
raise RuntimeError(f"Unexpected status code: {status_code}")


@async_timeout(TIMEOUT)
async def fetch_remote_branch_list(url: str) -> list[str]:
"""
Get the list of branches from the remote repo.

Parameters
----------
url : str
The URL of the repository.

Returns
-------
list[str]
list of the branches in the remote repository
gowthamkishore3799 marked this conversation as resolved.
Show resolved Hide resolved
"""
fetch_branches_command = ["git", "ls-remote", "--heads", url]
stdout, stderr = await _run_git_command(*fetch_branches_command)
gowthamkishore3799 marked this conversation as resolved.
Show resolved Hide resolved
stdout_decoded = stdout.decode()

return [
line.split('refs/heads/', 1)[1]
for line in stdout_decoded.splitlines()
if line.strip() and 'refs/heads/' in line
]


async def _run_git_command(*args: str) -> tuple[bytes, bytes]:
"""
Execute a Git command asynchronously and captures its output.
Expand Down
72 changes: 59 additions & 13 deletions tests/query_parser/test_query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from pathlib import Path

import pytest
from unittest.mock import patch, AsyncMock
from gitingest.repository_clone import _check_repo_exists, fetch_remote_branch_list
gowthamkishore3799 marked this conversation as resolved.
Show resolved Hide resolved

from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS
from gitingest.query_parser import _parse_patterns, _parse_repo_source, parse_query
Expand Down Expand Up @@ -96,18 +98,21 @@ async def test_parse_query_invalid_pattern() -> None:
with pytest.raises(ValueError, match="Pattern.*contains invalid characters"):
await parse_query(url, max_file_size=50, from_web=True, include_patterns="*.py;rm -rf")


async def test_parse_url_with_subpaths() -> None:
"""
Test `_parse_repo_source` with a URL containing a branch and subpath.
Verifies that user name, repository name, branch, and subpath are correctly extracted.
"""
url = "https://github.com/user/repo/tree/main/subdir/file"
result = await _parse_repo_source(url)
assert result["user_name"] == "user"
assert result["repo_name"] == "repo"
assert result["branch"] == "main"
assert result["subpath"] == "/subdir/file"
with patch('gitingest.repository_clone._run_git_command', new_callable=AsyncMock) as mock_run_git_command:
mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch('gitingest.repository_clone.fetch_remote_branch_list', new_callable=AsyncMock) as mock_fetch_branches:
mock_fetch_branches.return_value = ["main", "dev", "feature-branch"]
result = await _parse_repo_source(url)
assert result["user_name"] == "user"
assert result["repo_name"] == "repo"
assert result["branch"] == "main"
assert result["subpath"] == "/subdir/file"


async def test_parse_url_invalid_repo_structure() -> None:
Expand Down Expand Up @@ -222,15 +227,18 @@ async def test_parse_url_branch_and_commit_distinction() -> None:
url_branch = "https://github.com/user/repo/tree/main"
url_commit = "https://github.com/user/repo/tree/abcd1234abcd1234abcd1234abcd1234abcd1234"

result_branch = await _parse_repo_source(url_branch)
result_commit = await _parse_repo_source(url_commit)

assert result_branch["branch"] == "main"
assert result_branch["commit"] is None
with patch('gitingest.repository_clone._run_git_command', new_callable=AsyncMock) as mock_run_git_command:
mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch('gitingest.repository_clone.fetch_remote_branch_list', new_callable=AsyncMock) as mock_fetch_branches:
mock_fetch_branches.return_value = ["main", "dev", "feature-branch"]

assert result_commit["branch"] is None
assert result_commit["commit"] == "abcd1234abcd1234abcd1234abcd1234abcd1234"
result_branch = await _parse_repo_source(url_branch)
result_commit = await _parse_repo_source(url_commit)
assert result_branch["branch"] == "main"
assert result_branch["commit"] is None

assert result_commit["branch"] is None
assert result_commit["commit"] == "abcd1234abcd1234abcd1234abcd1234abcd1234"

async def test_parse_query_uuid_uniqueness() -> None:
"""
Expand Down Expand Up @@ -274,3 +282,41 @@ async def test_parse_query_with_branch() -> None:
assert result["branch"] == "2.2.x"
assert result["commit"] is None
assert result["type"] == "blob"

@pytest.mark.asyncio
@pytest.mark.parametrize("url, expected_branch, expected_subpath", [
("https://github.com/user/repo/tree/main/src", "main", "/src"),
("https://github.com/user/repo/tree/fix1", "fix1", "/"),
("https://github.com/user/repo/tree/nonexistent-branch/src", "nonexistent-branch", "/src"),
])
async def test_parse_repo_source_with_failed_git_command(url, expected_branch, expected_subpath):
"""
Test `_parse_repo_source` when git command fails.
Verifies that the function returns the first path component as the branch.
"""
with patch('gitingest.repository_clone.fetch_remote_branch_list', new_callable=AsyncMock) as mock_fetch_branches:
mock_fetch_branches.side_effect = Exception("Failed to fetch branch list")

result = await _parse_repo_source(url)

assert result["branch"] == expected_branch
assert result["subpath"] == expected_subpath

@pytest.mark.asyncio
@pytest.mark.parametrize("url, expected_branch, expected_subpath", [
("https://github.com/user/repo/tree/feature/fix1/src", "feature/fix1", "/src"),
("https://github.com/user/repo/tree/main/src", "main", "/src"),
("https://github.com/user/repo", None, "/"), # No
("https://github.com/user/repo/tree/nonexistent-branch/src", None, "/"), # Non-existent branch
("https://github.com/user/repo/tree/fix", "fix", "/"),
])
async def test_parse_repo_source_with_various_url_patterns(url, expected_branch, expected_subpath):
with patch('gitingest.repository_clone._run_git_command', new_callable=AsyncMock) as mock_run_git_command, \
patch('gitingest.repository_clone.fetch_remote_branch_list', new_callable=AsyncMock) as mock_fetch_branches:

mock_run_git_command.return_value = (b"refs/heads/feature/fix1\nrefs/heads/main\nrefs/heads/feature-branch\nrefs/heads/fix\n", b"")
mock_fetch_branches.return_value = ["feature/fix1", "main", "feature-branch"]

result = await _parse_repo_source(url)
assert result["branch"] == expected_branch
assert result["subpath"] == expected_subpath