diff --git a/src/gitingest/query_parser.py b/src/gitingest/query_parser.py index 05925b3..809070d 100644 --- a/src/gitingest/query_parser.py +++ b/src/gitingest/query_parser.py @@ -4,6 +4,7 @@ import re import string import uuid +import warnings from pathlib import Path from typing import Any from urllib.parse import unquote, urlparse @@ -11,7 +12,7 @@ from config import TMP_BASE_PATH from gitingest.exceptions import InvalidPatternError from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS -from gitingest.repository_clone import _check_repo_exists +from gitingest.repository_clone import _check_repo_exists, fetch_remote_branch_list HEX_DIGITS: set[str] = set(string.hexdigits) @@ -169,19 +170,48 @@ async def _parse_repo_source(source: str) -> dict[str, Any]: parsed["type"] = possible_type # Commit or branch - commit_or_branch = remaining_parts.pop(0) + commit_or_branch = remaining_parts[0] if _is_valid_git_commit_hash(commit_or_branch): parsed["commit"] = commit_or_branch + parsed["subpath"] += "/".join(remaining_parts[1:]) else: - parsed["branch"] = commit_or_branch - - # Subpath if anything left - if remaining_parts: + parsed["branch"] = await _configure_branch_and_subpath(remaining_parts, url) parsed["subpath"] += "/".join(remaining_parts) - return parsed +async def _configure_branch_and_subpath(remaining_parts: list[str], url: str) -> str | None: + """ + Configure the branch and subpath based on the remaining parts of the URL. + Parameters + ---------- + remaining_parts : list[str] + The remaining parts of the URL path. + url : str + The URL of the repository. + Returns + ------- + str | None + The branch name if found, otherwise None. + + """ + try: + # Fetch the list of branches from the remote repository + branches: list[str] = await fetch_remote_branch_list(url) + except RuntimeError as e: + warnings.warn(f"Warning: Failed to fetch branch list: {str(e)}") + return remaining_parts.pop(0) + + branch = [] + while remaining_parts: + branch.append(remaining_parts.pop(0)) + branch_name = "/".join(branch) + if branch_name in branches: + return branch_name + + return None + + def _is_valid_git_commit_hash(commit: str) -> bool: """ Validate if the provided string is a valid Git commit hash. diff --git a/src/gitingest/repository_clone.py b/src/gitingest/repository_clone.py index d251a6f..4adfcd9 100644 --- a/src/gitingest/repository_clone.py +++ b/src/gitingest/repository_clone.py @@ -5,7 +5,7 @@ from gitingest.utils import async_timeout -CLONE_TIMEOUT: int = 20 +TIMEOUT: int = 20 @dataclass @@ -34,7 +34,7 @@ class CloneConfig: branch: str | None = None -@async_timeout(CLONE_TIMEOUT) +@async_timeout(TIMEOUT) async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]: """ Clone a repository to a local path based on the provided configuration. @@ -141,6 +141,30 @@ async def _check_repo_exists(url: str) -> bool: raise RuntimeError(f"Unexpected status code: {status_code}") +@async_timeout(TIMEOUT) +async def fetch_remote_branch_list(url: str) -> list[str]: + """ + Fetch the list of branches from a remote Git repository. + Parameters + ---------- + url : str + The URL of the Git repository to fetch branches from. + Returns + ------- + list[str] + A list of branch names available in the remote repository. + """ + fetch_branches_command = ["git", "ls-remote", "--heads", url] + stdout, _ = await _run_git_command(*fetch_branches_command) + stdout_decoded = stdout.decode() + + return [ + line.split("refs/heads/", 1)[1] + for line in stdout_decoded.splitlines() + if line.strip() and "refs/heads/" in line + ] + + async def _run_git_command(*args: str) -> tuple[bytes, bytes]: """ Execute a Git command asynchronously and captures its output. diff --git a/tests/query_parser/test_query_parser.py b/tests/query_parser/test_query_parser.py index ab9480d..b2187a4 100644 --- a/tests/query_parser/test_query_parser.py +++ b/tests/query_parser/test_query_parser.py @@ -1,6 +1,7 @@ """ Tests for the query_parser module. """ from pathlib import Path +from unittest.mock import AsyncMock, patch import pytest @@ -109,11 +110,17 @@ async def test_parse_url_with_subpaths() -> None: Verifies that user name, repository name, branch, and subpath are correctly extracted. """ url = "https://github.com/user/repo/tree/main/subdir/file" - result = await _parse_repo_source(url) - assert result["user_name"] == "user" - assert result["repo_name"] == "repo" - assert result["branch"] == "main" - assert result["subpath"] == "/subdir/file" + with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_run_git_command: + mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"") + with patch( + "gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock + ) as mock_fetch_branches: + mock_fetch_branches.return_value = ["main", "dev", "feature-branch"] + result = await _parse_repo_source(url) + assert result["user_name"] == "user" + assert result["repo_name"] == "repo" + assert result["branch"] == "main" + assert result["subpath"] == "/subdir/file" async def test_parse_url_invalid_repo_structure() -> None: @@ -228,14 +235,20 @@ async def test_parse_url_branch_and_commit_distinction() -> None: url_branch = "https://github.com/user/repo/tree/main" url_commit = "https://github.com/user/repo/tree/abcd1234abcd1234abcd1234abcd1234abcd1234" - result_branch = await _parse_repo_source(url_branch) - result_commit = await _parse_repo_source(url_commit) + with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_run_git_command: + mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"") + with patch( + "gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock + ) as mock_fetch_branches: + mock_fetch_branches.return_value = ["main", "dev", "feature-branch"] - assert result_branch["branch"] == "main" - assert result_branch["commit"] is None + result_branch = await _parse_repo_source(url_branch) + result_commit = await _parse_repo_source(url_commit) + assert result_branch["branch"] == "main" + assert result_branch["commit"] is None - assert result_commit["branch"] is None - assert result_commit["commit"] == "abcd1234abcd1234abcd1234abcd1234abcd1234" + assert result_commit["branch"] is None + assert result_commit["commit"] == "abcd1234abcd1234abcd1234abcd1234abcd1234" async def test_parse_query_uuid_uniqueness() -> None: @@ -280,3 +293,55 @@ async def test_parse_query_with_branch() -> None: assert result["branch"] == "2.2.x" assert result["commit"] is None assert result["type"] == "blob" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url, expected_branch, expected_subpath", + [ + ("https://github.com/user/repo/tree/main/src", "main", "/src"), + ("https://github.com/user/repo/tree/fix1", "fix1", "/"), + ("https://github.com/user/repo/tree/nonexistent-branch/src", "nonexistent-branch", "/src"), + ], +) +async def test_parse_repo_source_with_failed_git_command(url, expected_branch, expected_subpath): + """ + Test `_parse_repo_source` when git command fails. + Verifies that the function returns the first path component as the branch. + """ + with patch("gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches: + mock_fetch_branches.side_effect = Exception("Failed to fetch branch list") + + result = await _parse_repo_source(url) + + assert result["branch"] == expected_branch + assert result["subpath"] == expected_subpath + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url, expected_branch, expected_subpath", + [ + ("https://github.com/user/repo/tree/feature/fix1/src", "feature/fix1", "/src"), + ("https://github.com/user/repo/tree/main/src", "main", "/src"), + ("https://github.com/user/repo", None, "/"), # No + ("https://github.com/user/repo/tree/nonexistent-branch/src", None, "/"), # Non-existent branch + ("https://github.com/user/repo/tree/fix", "fix", "/"), + ("https://github.com/user/repo/blob/fix/page.html", "fix", "/page.html"), + ], +) +async def test_parse_repo_source_with_various_url_patterns(url, expected_branch, expected_subpath): + with ( + patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_run_git_command, + patch("gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches, + ): + + mock_run_git_command.return_value = ( + b"refs/heads/feature/fix1\nrefs/heads/main\nrefs/heads/feature-branch\nrefs/heads/fix\n", + b"", + ) + mock_fetch_branches.return_value = ["feature/fix1", "main", "feature-branch"] + + result = await _parse_repo_source(url) + assert result["branch"] == expected_branch + assert result["subpath"] == expected_subpath