From d721b00d03f83f038e9c23219bf58ccad147d515 Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Fri, 17 Jan 2025 16:53:32 +0100 Subject: [PATCH] Replace dict-based query with ParsedQuery dataclass (#133) - Introduce ParsedQuery dataclass to store query parameters and metadata - Update ingestion and parser modules to use ParsedQuery instead of dict[str, Any] - Convert ignore_patterns and include_patterns to sets - Clean references to max size and pattern handling - Update tests to reflect new dataclass usage --- src/config.py | 5 + src/gitingest/cli.py | 8 +- src/gitingest/ignore_patterns.py | 22 +- src/gitingest/query_ingestion.py | 183 +++++++---------- src/gitingest/query_parser.py | 199 +++++++++++-------- src/gitingest/repository_ingest.py | 30 +-- src/main.py | 6 +- src/query_processor.py | 25 ++- tests/conftest.py | 31 +-- tests/query_parser/test_git_host_agnostic.py | 21 +- tests/query_parser/test_query_parser.py | 153 +++++++------- tests/test_query_ingestion.py | 61 +++--- 12 files changed, 369 insertions(+), 375 deletions(-) diff --git a/src/config.py b/src/config.py index 68565c8..7365ab8 100644 --- a/src/config.py +++ b/src/config.py @@ -2,6 +2,11 @@ from pathlib import Path +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB +MAX_DIRECTORY_DEPTH = 20 # Maximum depth of directory traversal +MAX_FILES = 10_000 # Maximum number of files to process +MAX_TOTAL_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB + MAX_DISPLAY_SIZE: int = 300_000 TMP_BASE_PATH = Path("/tmp/gitingest") DELETE_REPO_AFTER: int = 60 * 60 # In seconds diff --git a/src/gitingest/cli.py b/src/gitingest/cli.py index 371263a..ef7761b 100644 --- a/src/gitingest/cli.py +++ b/src/gitingest/cli.py @@ -4,7 +4,7 @@ import click -from gitingest.query_ingestion import MAX_FILE_SIZE +from config import MAX_FILE_SIZE from gitingest.repository_ingest import ingest @@ -49,8 +49,8 @@ async def main( """ try: # Combine default and custom ignore patterns - exclude_patterns = list(exclude_pattern) - include_patterns = list(set(include_pattern)) + exclude_patterns = set(exclude_pattern) + include_patterns = set(include_pattern) if not output: output = "digest.txt" @@ -61,7 +61,7 @@ async def main( click.echo(summary) except Exception as e: - click.echo(f"Error: {str(e)}", err=True) + click.echo(f"Error: {e}", err=True) raise click.Abort() diff --git a/src/gitingest/ignore_patterns.py b/src/gitingest/ignore_patterns.py index a1a902d..90ef210 100644 --- a/src/gitingest/ignore_patterns.py +++ b/src/gitingest/ignore_patterns.py @@ -1,6 +1,6 @@ """ Default ignore patterns for Gitingest. """ -DEFAULT_IGNORE_PATTERNS: list[str] = [ +DEFAULT_IGNORE_PATTERNS: set[str] = { # Python "*.pyc", "*.pyo", @@ -29,18 +29,17 @@ "*.war", "*.ear", "*.nar", - "target/", ".gradle/", "build/", ".settings/", - ".project", ".classpath", "gradle-app.setting", "*.gradle", + # IDEs and editors / Java + ".project", # C/C++ "*.o", "*.obj", - "*.so", "*.dll", "*.dylib", "*.exe", @@ -68,14 +67,13 @@ ".ruby-gemset", ".rvmrc", # Rust - "target/", "Cargo.lock", "**/*.rs.bk", + # Java / Rust + "target/", # Go - "bin/", "pkg/", # .NET/C# - "bin/", "obj/", "*.suo", "*.user", @@ -83,6 +81,8 @@ "*.sln.docstates", "packages/", "*.nupkg", + # Go / .NET / C# + "bin/", # Version control ".git", ".svn", @@ -112,12 +112,9 @@ ".idea", ".vscode", ".vs", - "*.swp", "*.swo", "*.swn", ".settings", - ".project", - ".classpath", "*.sublime-*", # Temporary and cache files "*.log", @@ -140,9 +137,6 @@ "*.egg", "*.whl", "*.so", - "*.dylib", - "*.dll", - "*.class", # Documentation "site-packages", ".docusaurus", @@ -159,4 +153,4 @@ "*.tfstate*", ## Dependencies in various languages "vendor/", -] +} diff --git a/src/gitingest/query_ingestion.py b/src/gitingest/query_ingestion.py index 2e1f292..a6f94d2 100644 --- a/src/gitingest/query_ingestion.py +++ b/src/gitingest/query_ingestion.py @@ -6,6 +6,7 @@ import tiktoken +from config import MAX_DIRECTORY_DEPTH, MAX_FILES, MAX_TOTAL_SIZE_BYTES from gitingest.exceptions import ( AlreadyVisitedError, InvalidNotebookError, @@ -13,14 +14,10 @@ MaxFilesReachedError, ) from gitingest.notebook_utils import process_notebook +from gitingest.query_parser import ParsedQuery -MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB -MAX_DIRECTORY_DEPTH = 20 # Maximum depth of directory traversal -MAX_FILES = 10_000 # Maximum number of files to process -MAX_TOTAL_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB - -def _should_include(path: Path, base_path: Path, include_patterns: list[str]) -> bool: +def _should_include(path: Path, base_path: Path, include_patterns: set[str]) -> bool: """ Determine if the given file or directory path matches any of the include patterns. @@ -33,8 +30,8 @@ def _should_include(path: Path, base_path: Path, include_patterns: list[str]) -> The absolute path of the file or directory to check. base_path : Path The base directory from which the relative path is calculated. - include_patterns : list[str] - A list of patterns to check against the relative path. + include_patterns : set[str] + A set of patterns to check against the relative path. Returns ------- @@ -54,7 +51,7 @@ def _should_include(path: Path, base_path: Path, include_patterns: list[str]) -> return False -def _should_exclude(path: Path, base_path: Path, ignore_patterns: list[str]) -> bool: +def _should_exclude(path: Path, base_path: Path, ignore_patterns: set[str]) -> bool: """ Determine if the given file or directory path matches any of the ignore patterns. @@ -68,8 +65,8 @@ def _should_exclude(path: Path, base_path: Path, ignore_patterns: list[str]) -> The absolute path of the file or directory to check. base_path : Path The base directory from which the relative path is calculated. - ignore_patterns : list[str] - A list of patterns to check against the relative path. + ignore_patterns : set[str] + A set of patterns to check against the relative path. Returns ------- @@ -221,7 +218,7 @@ def _sort_children(children: list[dict[str, Any]]) -> list[dict[str, Any]]: def _scan_directory( path: Path, - query: dict[str, Any], + query: ParsedQuery, seen_paths: set[Path] | None = None, depth: int = 0, stats: dict[str, int] | None = None, @@ -237,8 +234,8 @@ def _scan_directory( ---------- path : Path The path of the directory to scan. - query : dict[str, Any] - A dictionary containing the query parameters, such as include and ignore patterns. + query : ParsedQuery + The parsed query object containing information about the repository and query parameters. seen_paths : set[Path] | None, optional A set to track already visited paths, by default None. depth : int @@ -287,23 +284,9 @@ def _scan_directory( "ignore_content": False, } - ignore_patterns = query["ignore_patterns"] - base_path = query["local_path"] - include_patterns = query["include_patterns"] - try: for item in path.iterdir(): - _process_item( - item=item, - query=query, - result=result, - seen_paths=seen_paths, - stats=stats, - depth=depth, - ignore_patterns=ignore_patterns, - base_path=base_path, - include_patterns=include_patterns, - ) + _process_item(item=item, query=query, result=result, seen_paths=seen_paths, stats=stats, depth=depth) except MaxFilesReachedError: print(f"Maximum file limit ({MAX_FILES}) reached.") except PermissionError: @@ -315,13 +298,11 @@ def _scan_directory( def _process_symlink( item: Path, - query: dict[str, Any], + query: ParsedQuery, result: dict[str, Any], seen_paths: set[Path], stats: dict[str, int], depth: int, - base_path: Path, - include_patterns: list[str], ) -> None: """ Process a symlink in the file system. @@ -333,8 +314,8 @@ def _process_symlink( ---------- item : Path The full path of the symlink. - query : dict[str, Any] - The query dictionary containing the parameters. + query : ParsedQuery + The parsed query object containing information about the repository and query parameters. result : dict[str, Any] The dictionary to accumulate the results. seen_paths : set[str] @@ -343,10 +324,6 @@ def _process_symlink( The dictionary to track statistics such as file count and size. depth : int The current depth in the directory traversal. - base_path : Path - The base path used for validation of the symlink. - include_patterns : list[str] - A list of include patterns for file filtering. Raises ------ @@ -357,7 +334,8 @@ def _process_symlink( MaxFilesReachedError If the number of files exceeds the maximum limit. """ - if not _is_safe_symlink(item, base_path): + + if not _is_safe_symlink(item, query.local_path): raise AlreadyVisitedError(str(item)) real_path = item.resolve() @@ -398,7 +376,7 @@ def _process_symlink( depth=depth + 1, stats=stats, ) - if subdir and (not include_patterns or subdir["file_count"] > 0): + if subdir and (not query.include_patterns or subdir["file_count"] > 0): # rename the subdir to reflect the symlink name subdir["name"] = item.name subdir["path"] = str(item) @@ -460,14 +438,11 @@ def _process_file(item: Path, result: dict[str, Any], stats: dict[str, int]) -> def _process_item( item: Path, - query: dict[str, Any], + query: ParsedQuery, result: dict[str, Any], seen_paths: set[Path], stats: dict[str, int], depth: int, - ignore_patterns: list[str], - base_path: Path, - include_patterns: list[str], ) -> None: """ Process a file or directory item within a directory. @@ -479,8 +454,8 @@ def _process_item( ---------- item : Path The full path of the file or directory to process. - query : dict[str, Any] - A dictionary of query parameters, including the base path and patterns. + query : ParsedQuery + The parsed query object containing information about the repository and query parameters. result : dict[str, Any] The result dictionary to accumulate processed file/directory data. seen_paths : set[Path] @@ -489,39 +464,29 @@ def _process_item( A dictionary of statistics like the total file count and size. depth : int The current depth of directory traversal. - ignore_patterns : list[str] - A list of patterns to exclude files or directories. - base_path : Path - The base directory used for relative path calculations. - include_patterns : list[str] - A list of patterns to include files or directories. """ - if _should_exclude(item, base_path, ignore_patterns): + + if not query.ignore_patterns or _should_exclude(item, query.local_path, query.ignore_patterns): return - if item.is_file() and query["include_patterns"] and not _should_include(item, base_path, include_patterns): + if ( + item.is_file() + and query.include_patterns + and not _should_include(item, query.local_path, query.include_patterns) + ): result["ignore_content"] = True return try: if item.is_symlink(): - _process_symlink( - item=item, - query=query, - result=result, - seen_paths=seen_paths, - stats=stats, - depth=depth, - base_path=base_path, - include_patterns=include_patterns, - ) + _process_symlink(item=item, query=query, result=result, seen_paths=seen_paths, stats=stats, depth=depth) if item.is_file(): _process_file(item=item, result=result, stats=stats) elif item.is_dir(): subdir = _scan_directory(path=item, query=query, seen_paths=seen_paths, depth=depth + 1, stats=stats) - if subdir and (not include_patterns or subdir["file_count"] > 0): + if subdir and (not query.include_patterns or subdir["file_count"] > 0): result["children"].append(subdir) result["size"] += subdir["size"] result["file_count"] += subdir["file_count"] @@ -532,9 +497,8 @@ def _process_item( def _extract_files_content( - query: dict[str, Any], + query: ParsedQuery, node: dict[str, Any], - max_file_size: int, files: list[dict[str, Any]] | None = None, ) -> list[dict[str, Any]]: """ @@ -545,12 +509,10 @@ def _extract_files_content( Parameters ---------- - query : dict[str, Any] - A dictionary containing the query parameters, including the base path of the repository. + query : ParsedQuery + The parsed query object containing information about the repository and query parameters. node : dict[str, Any] The current directory or file node being processed. - max_file_size : int - The maximum file size in bytes for which content should be extracted. files : list[dict[str, Any]] | None, optional A list to collect the extracted files' information, by default None. @@ -563,12 +525,12 @@ def _extract_files_content( files = [] if node["type"] == "file" and node["content"] != "[Non-text file]": - if node["size"] > max_file_size: + if node["size"] > query.max_file_size: content = None else: content = node["content"] - relative_path = Path(node["path"]).relative_to(query["local_path"]) + relative_path = Path(node["path"]).relative_to(query.local_path) files.append( { @@ -579,7 +541,7 @@ def _extract_files_content( ) elif node["type"] == "directory": for child in node["children"]: - _extract_files_content(query=query, node=child, max_file_size=max_file_size, files=files) + _extract_files_content(query=query, node=child, files=files) return files @@ -588,7 +550,7 @@ def _create_file_content_string(files: list[dict[str, Any]]) -> str: """ Create a formatted string of file contents with separators. - This function takes a list of files and generates a formatted string where each file’s + This function takes a list of files and generates a formatted string where each file's content is separated by a divider. Parameters @@ -617,7 +579,7 @@ def _create_file_content_string(files: list[dict[str, Any]]) -> str: return output -def _create_summary_string(query: dict[str, Any], nodes: dict[str, Any]) -> str: +def _create_summary_string(query: ParsedQuery, nodes: dict[str, Any]) -> str: """ Create a summary string with file counts and content size. @@ -626,8 +588,8 @@ def _create_summary_string(query: dict[str, Any], nodes: dict[str, Any]) -> str: Parameters ---------- - query : dict[str, Any] - Dictionary containing query parameters like repository name, commit, branch, and subpath. + query : ParsedQuery + The parsed query object containing information about the repository and query parameters. nodes : dict[str, Any] Dictionary representing the directory structure, including file and directory counts. @@ -636,24 +598,24 @@ def _create_summary_string(query: dict[str, Any], nodes: dict[str, Any]) -> str: str Summary string containing details such as repository name, file count, and other query-specific information. """ - if "user_name" in query: - summary = f"Repository: {query['user_name']}/{query['repo_name']}\n" + if query.user_name: + summary = f"Repository: {query.user_name}/{query.repo_name}\n" else: - summary = f"Repository: {query['slug']}\n" + summary = f"Repository: {query.slug}\n" summary += f"Files analyzed: {nodes['file_count']}\n" - if "subpath" in query and query["subpath"] != "/": - summary += f"Subpath: {query['subpath']}\n" - if "commit" in query and query["commit"]: - summary += f"Commit: {query['commit']}\n" - elif "branch" in query and query["branch"] != "main" and query["branch"] != "master" and query["branch"]: - summary += f"Branch: {query['branch']}\n" + if query.subpath != "/": + summary += f"Subpath: {query.subpath}\n" + if query.commit: + summary += f"Commit: {query.commit}\n" + elif query.branch and query.branch not in ("main", "master"): + summary += f"Branch: {query.branch}\n" return summary -def _create_tree_structure(query: dict[str, Any], node: dict[str, Any], prefix: str = "", is_last: bool = True) -> str: +def _create_tree_structure(query: ParsedQuery, node: dict[str, Any], prefix: str = "", is_last: bool = True) -> str: """ Create a tree-like string representation of the file structure. @@ -662,8 +624,8 @@ def _create_tree_structure(query: dict[str, Any], node: dict[str, Any], prefix: Parameters ---------- - query : dict[str, Any] - A dictionary containing query parameters like repository name and subpath. + query : ParsedQuery + The parsed query object containing information about the repository and query parameters. node : dict[str, Any] The current directory or file node being processed. prefix : str @@ -679,7 +641,7 @@ def _create_tree_structure(query: dict[str, Any], node: dict[str, Any], prefix: tree = "" if not node["name"]: - node["name"] = query["slug"] + node["name"] = query.slug if node["name"]: current_prefix = "└── " if is_last else "├── " @@ -729,7 +691,7 @@ def _generate_token_string(context_string: str) -> str | None: return str(total_tokens) -def _ingest_single_file(path: Path, query: dict[str, Any]) -> tuple[str, str, str]: +def _ingest_single_file(path: Path, query: ParsedQuery) -> tuple[str, str, str]: """ Ingest a single file and return its summary, directory structure, and content. @@ -740,8 +702,8 @@ def _ingest_single_file(path: Path, query: dict[str, Any]) -> tuple[str, str, st ---------- path : Path The path of the file to ingest. - query : dict[str, Any] - A dictionary containing query parameters, such as the maximum file size. + query : ParsedQuery + The parsed query object containing information about the repository and query parameters. Returns ------- @@ -760,12 +722,12 @@ def _ingest_single_file(path: Path, query: dict[str, Any]) -> tuple[str, str, st raise ValueError(f"File {path} is not a text file") file_size = path.stat().st_size - if file_size > query["max_file_size"]: + if file_size > query.max_file_size: content = "[Content ignored: file too large]" else: content = _read_file_content(path) - relative_path = path.relative_to(query["local_path"]) + relative_path = path.relative_to(query.local_path) file_info = { "path": str(relative_path), @@ -774,7 +736,7 @@ def _ingest_single_file(path: Path, query: dict[str, Any]) -> tuple[str, str, st } summary = ( - f"Repository: {query['user_name']}/{query['repo_name']}\n" + f"Repository: {query.user_name}/{query.repo_name}\n" f"File: {path.name}\n" f"Size: {file_size:,} bytes\n" f"Lines: {len(content.splitlines()):,}\n" @@ -790,7 +752,7 @@ def _ingest_single_file(path: Path, query: dict[str, Any]) -> tuple[str, str, st return summary, tree, files_content -def _ingest_directory(path: Path, query: dict[str, Any]) -> tuple[str, str, str]: +def _ingest_directory(path: Path, query: ParsedQuery) -> tuple[str, str, str]: """ Ingest an entire directory and return its summary, directory structure, and file contents. @@ -801,8 +763,8 @@ def _ingest_directory(path: Path, query: dict[str, Any]) -> tuple[str, str, str] ---------- path : Path The path of the directory to ingest. - query : dict[str, Any] - A dictionary containing query parameters, including maximum file size. + query : ParsedQuery + The parsed query object containing information about the repository and query parameters. Returns ------- @@ -818,7 +780,7 @@ def _ingest_directory(path: Path, query: dict[str, Any]) -> tuple[str, str, str] if not nodes: raise ValueError(f"No files found in {path}") - files = _extract_files_content(query=query, node=nodes, max_file_size=query["max_file_size"]) + files = _extract_files_content(query=query, node=nodes) summary = _create_summary_string(query, nodes) tree = "Directory structure:\n" + _create_tree_structure(query, nodes) files_content = _create_file_content_string(files) @@ -830,17 +792,18 @@ def _ingest_directory(path: Path, query: dict[str, Any]) -> tuple[str, str, str] return summary, tree, files_content -def run_ingest_query(query: dict[str, Any]) -> tuple[str, str, str]: +def run_ingest_query(query: ParsedQuery) -> tuple[str, str, str]: """ - Main entry point for analyzing a codebase directory or single file. + Run the ingestion process for a parsed query. - This function processes a file or directory based on the provided query, extracting its contents - and generating a summary, directory structure, and file content, along with token estimations. + This is the main entry point for analyzing a codebase directory or single file. It processes the query + parameters, reads the file or directory content, and generates a summary, directory structure, and file content, + along with token estimations. Parameters ---------- - query : dict[str, Any] - A dictionary containing parameters like local path, subpath, file type, etc. + query : ParsedQuery + The parsed query object containing information about the repository and query parameters. Returns ------- @@ -852,11 +815,11 @@ def run_ingest_query(query: dict[str, Any]) -> tuple[str, str, str]: ValueError If the specified path cannot be found or if the file is not a text file. """ - path = query["local_path"] / query["subpath"].lstrip("/") + path = query.local_path / query.subpath.lstrip("/") if not path.exists(): - raise ValueError(f"{query['slug']} cannot be found") + raise ValueError(f"{query.slug} cannot be found") - if query.get("type") == "blob": + if query.type and query.type == "blob": return _ingest_single_file(path, query) return _ingest_directory(path, query) diff --git a/src/gitingest/query_parser.py b/src/gitingest/query_parser.py index 809070d..435a799 100644 --- a/src/gitingest/query_parser.py +++ b/src/gitingest/query_parser.py @@ -5,11 +5,11 @@ import string import uuid import warnings +from dataclasses import dataclass from pathlib import Path -from typing import Any from urllib.parse import unquote, urlparse -from config import TMP_BASE_PATH +from config import MAX_FILE_SIZE, TMP_BASE_PATH from gitingest.exceptions import InvalidPatternError from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS from gitingest.repository_clone import _check_repo_exists, fetch_remote_branch_list @@ -26,19 +26,41 @@ ] +@dataclass +class ParsedQuery: # pylint: disable=too-many-instance-attributes + """ + Dataclass to store the parsed details of the repository or file path. + """ + + user_name: str | None + repo_name: str | None + subpath: str + local_path: Path + url: str | None + slug: str + id: str + type: str | None = None + branch: str | None = None + commit: str | None = None + max_file_size: int = MAX_FILE_SIZE + ignore_patterns: set[str] | None = None + include_patterns: set[str] | None = None + pattern_type: str | None = None + + async def parse_query( source: str, max_file_size: int, from_web: bool, - include_patterns: list[str] | str | None = None, - ignore_patterns: list[str] | str | None = None, -) -> dict[str, Any]: + include_patterns: set[str] | str | None = None, + ignore_patterns: set[str] | str | None = None, +) -> ParsedQuery: """ - Parse the input source to construct a query dictionary with specified parameters. + Parse the input source (URL or path) to extract relevant details for the query. - This function processes the provided source (either a URL or file path) and builds a - query dictionary that includes information such as the source URL, maximum file size, - and any patterns to include or ignore. It handles both web and file-based sources. + This function parses the input source to extract details such as the username, repository name, + commit hash, branch name, and other relevant information. It also processes the include and ignore + patterns to filter the files and directories to include or exclude from the query. Parameters ---------- @@ -48,49 +70,55 @@ async def parse_query( The maximum file size in bytes to include. from_web : bool Flag indicating whether the source is a web URL. - include_patterns : list[str] | str | None, optional - Patterns to include, by default None. Can be a list of strings or a single string. - ignore_patterns : list[str] | str | None, optional - Patterns to ignore, by default None. Can be a list of strings or a single string. + include_patterns : set[str] | str | None, optional + Patterns to include, by default None. Can be a set of strings or a single string. + ignore_patterns : set[str] | str | None, optional + Patterns to ignore, by default None. Can be a set of strings or a single string. Returns ------- - dict[str, Any] - A dictionary containing the parsed query parameters, including 'max_file_size', - 'ignore_patterns', and 'include_patterns'. + ParsedQuery + A dataclass object containing the parsed details of the repository or file path. """ # Determine the parsing method based on the source type if from_web or urlparse(source).scheme in ("https", "http") or any(h in source for h in KNOWN_GIT_HOSTS): # We either have a full URL or a domain-less slug - query = await _parse_repo_source(source) + parsed_query = await _parse_repo_source(source) else: # Local path scenario - query = _parse_path(source) + parsed_query = _parse_path(source) - # Combine ignore patterns - ignore_patterns_list = DEFAULT_IGNORE_PATTERNS.copy() + # Combine default ignore patterns + custom patterns + ignore_patterns_set = DEFAULT_IGNORE_PATTERNS.copy() if ignore_patterns: - ignore_patterns_list += _parse_patterns(ignore_patterns) + ignore_patterns_set.update(_parse_patterns(ignore_patterns)) # Process include patterns and override ignore patterns accordingly if include_patterns: parsed_include = _parse_patterns(include_patterns) - ignore_patterns_list = _override_ignore_patterns(ignore_patterns_list, include_patterns=parsed_include) + ignore_patterns_set = _override_ignore_patterns(ignore_patterns_set, include_patterns=parsed_include) else: parsed_include = None - query.update( - { - "max_file_size": max_file_size, - "ignore_patterns": ignore_patterns_list, - "include_patterns": parsed_include, - } + return ParsedQuery( + user_name=parsed_query.user_name, + repo_name=parsed_query.repo_name, + url=parsed_query.url, + subpath=parsed_query.subpath, + local_path=parsed_query.local_path, + slug=parsed_query.slug, + id=parsed_query.id, + type=parsed_query.type, + branch=parsed_query.branch, + commit=parsed_query.commit, + max_file_size=max_file_size, + ignore_patterns=ignore_patterns_set, + include_patterns=parsed_include, ) - return query -async def _parse_repo_source(source: str) -> dict[str, Any]: +async def _parse_repo_source(source: str) -> ParsedQuery: """ Parse a repository URL into a structured query dictionary. @@ -106,9 +134,8 @@ async def _parse_repo_source(source: str) -> dict[str, Any]: Returns ------- - dict[str, Any] - A dictionary containing the parsed details of the repository, including the username, - repository name, commit, branch, and other relevant information. + ParsedQuery + A dictionary containing the parsed details of the repository. """ source = unquote(source) @@ -139,18 +166,15 @@ async def _parse_repo_source(source: str) -> dict[str, Any]: local_path = Path(TMP_BASE_PATH) / _id / slug url = f"https://{host}/{user_name}/{repo_name}" - parsed = { - "user_name": user_name, - "repo_name": repo_name, - "type": None, - "branch": None, - "commit": None, - "subpath": "/", - "local_path": local_path, - "url": url, - "slug": slug, # e.g. "pandas-dev-pandas" - "id": _id, - } + parsed = ParsedQuery( + user_name=user_name, + repo_name=repo_name, + url=url, + subpath="/", + local_path=local_path, + slug=slug, + id=_id, + ) remaining_parts = parsed_url.path.strip("/").split("/")[2:] @@ -167,16 +191,20 @@ async def _parse_repo_source(source: str) -> dict[str, Any]: if remaining_parts and possible_type in ("issues", "pull"): return parsed - parsed["type"] = possible_type + parsed.type = possible_type # Commit or branch commit_or_branch = remaining_parts[0] if _is_valid_git_commit_hash(commit_or_branch): - parsed["commit"] = commit_or_branch - parsed["subpath"] += "/".join(remaining_parts[1:]) + parsed.commit = commit_or_branch + remaining_parts.pop(0) else: - parsed["branch"] = await _configure_branch_and_subpath(remaining_parts, url) - parsed["subpath"] += "/".join(remaining_parts) + parsed.branch = await _configure_branch_and_subpath(remaining_parts, url) + + # Subpath if anything left + if remaining_parts: + parsed.subpath += "/".join(remaining_parts) + return parsed @@ -199,7 +227,7 @@ async def _configure_branch_and_subpath(remaining_parts: list[str], url: str) -> # Fetch the list of branches from the remote repository branches: list[str] = await fetch_remote_branch_list(url) except RuntimeError as e: - warnings.warn(f"Warning: Failed to fetch branch list: {str(e)}") + warnings.warn(f"Warning: Failed to fetch branch list: {e}") return remaining_parts.pop(0) branch = [] @@ -255,22 +283,22 @@ def _normalize_pattern(pattern: str) -> str: return pattern -def _parse_patterns(pattern: list[str] | str) -> list[str]: +def _parse_patterns(pattern: set[str] | str) -> set[str]: """ Parse and validate file/directory patterns for inclusion or exclusion. - Takes either a single pattern string or list of pattern strings and processes them into a normalized list. + Takes either a single pattern string or set of pattern strings and processes them into a normalized list. Patterns are split on commas and spaces, validated for allowed characters, and normalized. Parameters ---------- - pattern : list[str] | str - Pattern(s) to parse - either a single string or list of strings + pattern : set[str] | str + Pattern(s) to parse - either a single string or set of strings Returns ------- - list[str] - List of normalized pattern strings + set[str] + A set of normalized patterns. Raises ------ @@ -279,49 +307,45 @@ def _parse_patterns(pattern: list[str] | str) -> list[str]: dash (-), underscore (_), dot (.), forward slash (/), plus (+), and asterisk (*) are allowed. """ - patterns = pattern if isinstance(pattern, list) else [pattern] + patterns = pattern if isinstance(pattern, set) else {pattern} - parsed_patterns = [] + parsed_patterns: set[str] = set() for p in patterns: - parsed_patterns.extend(re.split(",| ", p)) + parsed_patterns = parsed_patterns.union(set(re.split(",| ", p))) - # Filter out any empty strings - parsed_patterns = [p for p in parsed_patterns if p != ""] + # Remove empty string if present + parsed_patterns = parsed_patterns - {""} # Validate and normalize each pattern for p in parsed_patterns: if not _is_valid_pattern(p): raise InvalidPatternError(p) - return [_normalize_pattern(p) for p in parsed_patterns] + return {_normalize_pattern(p) for p in parsed_patterns} -def _override_ignore_patterns(ignore_patterns: list[str], include_patterns: list[str]) -> list[str]: +def _override_ignore_patterns(ignore_patterns: set[str], include_patterns: set[str]) -> set[str]: """ Remove patterns from ignore_patterns that are present in include_patterns using set difference. Parameters ---------- - ignore_patterns : list[str] - The list of patterns to potentially remove. - include_patterns : list[str] - The list of patterns to exclude from ignore_patterns. + ignore_patterns : set[str] + The set of ignore patterns to filter. + include_patterns : set[str] + The set of include patterns to remove from ignore_patterns. Returns ------- - list[str] - A new list of ignore_patterns with specified patterns removed. + set[str] + The filtered set of ignore patterns. """ - return list(set(ignore_patterns) - set(include_patterns)) + return set(ignore_patterns) - set(include_patterns) -def _parse_path(path_str: str) -> dict[str, Any]: +def _parse_path(path_str: str) -> ParsedQuery: """ - Parse a file path into a structured query dictionary. - - This function takes a file path and constructs a query dictionary that includes - relevant details such as the absolute path and the slug (a combination of the - directory and file names). + Parse the given file path into a structured query dictionary. Parameters ---------- @@ -330,18 +354,19 @@ def _parse_path(path_str: str) -> dict[str, Any]: Returns ------- - dict[str, Any] - A dictionary containing parsed details such as the local file path and slug. + ParsedQuery + A dictionary containing the parsed details of the file path. """ path_obj = Path(path_str).resolve() - query = { - "url": None, - "local_path": path_obj, - "slug": f"{path_obj.parent.name}/{path_obj.name}", - "subpath": "/", - "id": str(uuid.uuid4()), - } - return query + return ParsedQuery( + user_name=None, + repo_name=None, + url=None, + subpath="/", + local_path=path_obj, + slug=f"{path_obj.parent.name}/{path_obj.name}", + id=str(uuid.uuid4()), + ) def _is_valid_pattern(pattern: str) -> bool: diff --git a/src/gitingest/repository_ingest.py b/src/gitingest/repository_ingest.py index c7efa94..64b33eb 100644 --- a/src/gitingest/repository_ingest.py +++ b/src/gitingest/repository_ingest.py @@ -6,15 +6,15 @@ from config import TMP_BASE_PATH from gitingest.query_ingestion import run_ingest_query -from gitingest.query_parser import parse_query +from gitingest.query_parser import ParsedQuery, parse_query from gitingest.repository_clone import CloneConfig, clone_repo async def ingest( source: str, max_file_size: int = 10 * 1024 * 1024, # 10 MB - include_patterns: list[str] | str | None = None, - exclude_patterns: list[str] | str | None = None, + include_patterns: set[str] | str | None = None, + exclude_patterns: set[str] | str | None = None, output: str | None = None, ) -> tuple[str, str, str]: """ @@ -31,10 +31,10 @@ async def ingest( max_file_size : int Maximum allowed file size for file ingestion. Files larger than this size are ignored, by default 10*1024*1024 (10 MB). - include_patterns : list[str] | str | None, optional - Pattern or list of patterns specifying which files to include. If `None`, all files are included. - exclude_patterns : list[str] | str | None, optional - Pattern or list of patterns specifying which files to exclude. If `None`, no files are excluded. + include_patterns : set[str] | str | None, optional + Pattern or set of patterns specifying which files to include. If `None`, all files are included. + exclude_patterns : set[str] | str | None, optional + Pattern or set of patterns specifying which files to exclude. If `None`, no files are excluded. output : str | None, optional File path where the summary and content should be written. If `None`, the results are not written to a file. @@ -52,21 +52,21 @@ async def ingest( If `clone_repo` does not return a coroutine, or if the `source` is of an unsupported type. """ try: - query = await parse_query( + parsed_query: ParsedQuery = await parse_query( source=source, max_file_size=max_file_size, from_web=False, include_patterns=include_patterns, ignore_patterns=exclude_patterns, ) - if query["url"]: + if parsed_query.url: # Extract relevant fields for CloneConfig clone_config = CloneConfig( - url=query["url"], - local_path=str(query["local_path"]), - commit=query.get("commit"), - branch=query.get("branch"), + url=parsed_query.url, + local_path=str(parsed_query.local_path), + commit=parsed_query.commit, + branch=parsed_query.branch, ) clone_result = clone_repo(clone_config) @@ -75,7 +75,7 @@ async def ingest( else: raise TypeError("clone_repo did not return a coroutine as expected.") - summary, tree, content = run_ingest_query(query) + summary, tree, content = run_ingest_query(parsed_query) if output is not None: with open(output, "w", encoding="utf-8") as f: @@ -84,6 +84,6 @@ async def ingest( return summary, tree, content finally: # Clean up the temporary directory if it was created - if query["url"]: + if parsed_query.url: # Clean up the temporary directory shutil.rmtree(TMP_BASE_PATH, ignore_errors=True) diff --git a/src/main.py b/src/main.py index f2b63fd..556b3e1 100644 --- a/src/main.py +++ b/src/main.py @@ -57,7 +57,7 @@ async def remove_old_repositories(): await process_folder(folder) except Exception as e: - print(f"Error in remove_old_repositories: {str(e)}") + print(f"Error in remove_old_repositories: {e}") await asyncio.sleep(60) @@ -83,13 +83,13 @@ async def process_folder(folder: Path) -> None: history.write(f"{repo_url}\n") except Exception as e: - print(f"Error logging repository URL for {folder}: {str(e)}") + print(f"Error logging repository URL for {folder}: {e}") # Delete the folder try: shutil.rmtree(folder) except Exception as e: - print(f"Error deleting {folder}: {str(e)}") + print(f"Error deleting {folder}: {e}") @asynccontextmanager diff --git a/src/query_processor.py b/src/query_processor.py index a66bdd3..62f1c83 100644 --- a/src/query_processor.py +++ b/src/query_processor.py @@ -8,7 +8,7 @@ from config import EXAMPLE_REPOS, MAX_DISPLAY_SIZE from gitingest.query_ingestion import run_ingest_query -from gitingest.query_parser import parse_query +from gitingest.query_parser import ParsedQuery, parse_query from gitingest.repository_clone import CloneConfig, clone_repo from server_utils import Colors, log_slider_to_size @@ -77,27 +77,30 @@ async def process_query( } try: - query = await parse_query( + parsed_query: ParsedQuery = await parse_query( source=input_text, max_file_size=max_file_size, from_web=True, include_patterns=include_patterns, ignore_patterns=exclude_patterns, ) + if not parsed_query.url: + raise ValueError("The 'url' parameter is required.") + clone_config = CloneConfig( - url=query["url"], - local_path=str(query["local_path"]), - commit=query.get("commit"), - branch=query.get("branch"), + url=parsed_query.url, + local_path=str(parsed_query.local_path), + commit=parsed_query.commit, + branch=parsed_query.branch, ) await clone_repo(clone_config) - summary, tree, content = run_ingest_query(query) + summary, tree, content = run_ingest_query(parsed_query) with open(f"{clone_config.local_path}.txt", "w", encoding="utf-8") as f: f.write(tree + "\n" + content) except Exception as e: # hack to print error message when query is not defined - if "query" in locals() and query is not None and isinstance(query, dict): - _print_error(query["url"], e, max_file_size, pattern_type, pattern) + if "query" in locals() and parsed_query is not None and isinstance(parsed_query, dict): + _print_error(parsed_query["url"], e, max_file_size, pattern_type, pattern) else: print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="") print(f"{Colors.RED}{e}{Colors.END}") @@ -112,7 +115,7 @@ async def process_query( ) _print_success( - url=query["url"], + url=parsed_query.url, max_file_size=max_file_size, pattern_type=pattern_type, pattern=pattern, @@ -125,7 +128,7 @@ async def process_query( "summary": summary, "tree": tree, "content": content, - "ingest_id": query["id"], + "ingest_id": parsed_query.id, } ) diff --git a/tests/conftest.py b/tests/conftest.py index 87b8a4e..c11ee72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,22 +6,25 @@ import pytest +from gitingest.query_parser import ParsedQuery + @pytest.fixture -def sample_query() -> dict[str, Any]: - return { - "user_name": "test_user", - "repo_name": "test_repo", - "local_path": Path("/tmp/test_repo").resolve(), - "subpath": "/", - "branch": "main", - "commit": None, - "max_file_size": 1_000_000, - "slug": "test_user/test_repo", - "ignore_patterns": ["*.pyc", "__pycache__", ".git"], - "include_patterns": None, - "pattern_type": "exclude", - } +def sample_query() -> ParsedQuery: + return ParsedQuery( + user_name="test_user", + repo_name="test_repo", + url=None, + subpath="/", + local_path=Path("/tmp/test_repo").resolve(), + slug="test_user/test_repo", + id="id", + branch="main", + max_file_size=1_000_000, + ignore_patterns={"*.pyc", "__pycache__", ".git"}, + include_patterns=None, + pattern_type="exclude", + ) @pytest.fixture diff --git a/tests/query_parser/test_git_host_agnostic.py b/tests/query_parser/test_git_host_agnostic.py index 8e86355..9831362 100644 --- a/tests/query_parser/test_git_host_agnostic.py +++ b/tests/query_parser/test_git_host_agnostic.py @@ -68,14 +68,13 @@ async def test_parse_query_without_host( expected_url: str, ) -> None: for url in urls: - result = await parse_query(url, max_file_size=50, from_web=True) - # Common assertions for all cases - assert result["user_name"] == expected_user - assert result["repo_name"] == expected_repo - assert result["url"] == expected_url - assert result["slug"] == f"{expected_user}-{expected_repo}" - assert result["id"] is not None - assert result["subpath"] == "/" - assert result["branch"] is None - assert result["commit"] is None - assert result["type"] is None + parsed_query = await parse_query(url, max_file_size=50, from_web=True) + assert parsed_query.user_name == expected_user + assert parsed_query.repo_name == expected_repo + assert parsed_query.url == expected_url + assert parsed_query.slug == f"{expected_user}-{expected_repo}" + assert parsed_query.id is not None + assert parsed_query.subpath == "/" + assert parsed_query.branch is None + assert parsed_query.commit is None + assert parsed_query.type is None diff --git a/tests/query_parser/test_query_parser.py b/tests/query_parser/test_query_parser.py index b2187a4..fc5fb0a 100644 --- a/tests/query_parser/test_query_parser.py +++ b/tests/query_parser/test_query_parser.py @@ -23,10 +23,10 @@ async def test_parse_url_valid_https() -> None: "https://gitingest.com/user/repo", ] for url in test_cases: - result = await _parse_repo_source(url) - assert result["user_name"] == "user" - assert result["repo_name"] == "repo" - assert result["url"] == url + parsed_query = await _parse_repo_source(url) + assert parsed_query.user_name == "user" + assert parsed_query.repo_name == "repo" + assert parsed_query.url == url async def test_parse_url_valid_http() -> None: @@ -43,10 +43,10 @@ async def test_parse_url_valid_http() -> None: "http://gitingest.com/user/repo", ] for url in test_cases: - result = await _parse_repo_source(url) - assert result["user_name"] == "user" - assert result["repo_name"] == "repo" - assert result["slug"] == "user-repo" + parsed_query = await _parse_repo_source(url) + assert parsed_query.user_name == "user" + assert parsed_query.repo_name == "repo" + assert parsed_query.slug == "user-repo" async def test_parse_url_invalid() -> None: @@ -66,11 +66,12 @@ async def test_parse_query_basic() -> None: """ test_cases = ["https://github.com/user/repo", "https://gitlab.com/user/repo"] for url in test_cases: - result = await parse_query(url, max_file_size=50, from_web=True, ignore_patterns="*.txt") - assert result["user_name"] == "user" - assert result["repo_name"] == "repo" - assert result["url"] == url - assert "*.txt" in result["ignore_patterns"] + parsed_query = await parse_query(url, max_file_size=50, from_web=True, ignore_patterns="*.txt") + assert parsed_query.user_name == "user" + assert parsed_query.repo_name == "repo" + assert parsed_query.url == url + assert parsed_query.ignore_patterns + assert "*.txt" in parsed_query.ignore_patterns async def test_parse_query_mixed_case() -> None: @@ -78,9 +79,9 @@ async def test_parse_query_mixed_case() -> None: Test `parse_query` with mixed case URLs. """ url = "Https://GitHub.COM/UsEr/rEpO" - result = await parse_query(url, max_file_size=50, from_web=True) - assert result["user_name"] == "user" - assert result["repo_name"] == "repo" + parsed_query = await parse_query(url, max_file_size=50, from_web=True) + assert parsed_query.user_name == "user" + assert parsed_query.repo_name == "repo" async def test_parse_query_include_pattern() -> None: @@ -89,9 +90,9 @@ async def test_parse_query_include_pattern() -> None: Verifies that the include pattern is set correctly and default ignore patterns are applied. """ url = "https://github.com/user/repo" - result = await parse_query(url, max_file_size=50, from_web=True, include_patterns="*.py") - assert result["include_patterns"] == ["*.py"] - assert set(result["ignore_patterns"]) == set(DEFAULT_IGNORE_PATTERNS) + parsed_query = await parse_query(url, max_file_size=50, from_web=True, include_patterns="*.py") + assert parsed_query.include_patterns == {"*.py"} + assert parsed_query.ignore_patterns == DEFAULT_IGNORE_PATTERNS async def test_parse_query_invalid_pattern() -> None: @@ -116,11 +117,11 @@ async def test_parse_url_with_subpaths() -> None: "gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock ) as mock_fetch_branches: mock_fetch_branches.return_value = ["main", "dev", "feature-branch"] - result = await _parse_repo_source(url) - assert result["user_name"] == "user" - assert result["repo_name"] == "repo" - assert result["branch"] == "main" - assert result["subpath"] == "/subdir/file" + parsed_query = await _parse_repo_source(url) + assert parsed_query.user_name == "user" + assert parsed_query.repo_name == "repo" + assert parsed_query.branch == "main" + assert parsed_query.subpath == "/subdir/file" async def test_parse_url_invalid_repo_structure() -> None: @@ -139,8 +140,8 @@ def test_parse_patterns_valid() -> None: Verifies that the patterns are correctly parsed into a list. """ patterns = "*.py, *.md, docs/*" - result = _parse_patterns(patterns) - assert result == ["*.py", "*.md", "docs/*"] + parsed_patterns = _parse_patterns(patterns) + assert parsed_patterns == {"*.py", "*.md", "docs/*"} def test_parse_patterns_invalid_characters() -> None: @@ -159,9 +160,9 @@ async def test_parse_query_with_large_file_size() -> None: Verifies that the file size limit and default ignore patterns are set correctly. """ url = "https://github.com/user/repo" - result = await parse_query(url, max_file_size=10**9, from_web=True) - assert result["max_file_size"] == 10**9 - assert result["ignore_patterns"] == DEFAULT_IGNORE_PATTERNS + parsed_query = await parse_query(url, max_file_size=10**9, from_web=True) + assert parsed_query.max_file_size == 10**9 + assert parsed_query.ignore_patterns == DEFAULT_IGNORE_PATTERNS async def test_parse_query_empty_patterns() -> None: @@ -170,9 +171,9 @@ async def test_parse_query_empty_patterns() -> None: Verifies that the include patterns are set to None and default ignore patterns are applied. """ url = "https://github.com/user/repo" - result = await parse_query(url, max_file_size=50, from_web=True, include_patterns="", ignore_patterns="") - assert result["include_patterns"] is None - assert result["ignore_patterns"] == DEFAULT_IGNORE_PATTERNS + parsed_query = await parse_query(url, max_file_size=50, from_web=True, include_patterns="", ignore_patterns="") + assert parsed_query.include_patterns is None + assert parsed_query.ignore_patterns == DEFAULT_IGNORE_PATTERNS async def test_parse_query_include_and_ignore_overlap() -> None: @@ -181,16 +182,17 @@ async def test_parse_query_include_and_ignore_overlap() -> None: Verifies that overlapping patterns are removed from the ignore patterns. """ url = "https://github.com/user/repo" - result = await parse_query( + parsed_query = await parse_query( url, max_file_size=50, from_web=True, include_patterns="*.py", - ignore_patterns=["*.py", "*.txt"], + ignore_patterns={"*.py", "*.txt"}, ) - assert result["include_patterns"] == ["*.py"] - assert "*.py" not in result["ignore_patterns"] - assert "*.txt" in result["ignore_patterns"] + assert parsed_query.include_patterns == {"*.py"} + assert parsed_query.ignore_patterns is not None + assert "*.py" not in parsed_query.ignore_patterns + assert "*.txt" in parsed_query.ignore_patterns async def test_parse_query_local_path() -> None: @@ -199,11 +201,11 @@ async def test_parse_query_local_path() -> None: Verifies that the local path is set, a unique ID is generated, and the slug is correctly created. """ path = "/home/user/project" - result = await parse_query(path, max_file_size=100, from_web=False) + parsed_query = await parse_query(path, max_file_size=100, from_web=False) tail = Path("home/user/project") - assert result["local_path"].parts[-len(tail.parts) :] == tail.parts - assert result["id"] is not None - assert result["slug"] == "user/project" + assert parsed_query.local_path.parts[-len(tail.parts) :] == tail.parts + assert parsed_query.id is not None + assert parsed_query.slug == "user/project" async def test_parse_query_relative_path() -> None: @@ -212,10 +214,10 @@ async def test_parse_query_relative_path() -> None: Verifies that the local path and slug are correctly resolved. """ path = "./project" - result = await parse_query(path, max_file_size=100, from_web=False) + parsed_query = await parse_query(path, max_file_size=100, from_web=False) tail = Path("project") - assert result["local_path"].parts[-len(tail.parts) :] == tail.parts - assert result["slug"].endswith("project") + assert parsed_query.local_path.parts[-len(tail.parts) :] == tail.parts + assert parsed_query.slug.endswith("project") async def test_parse_query_empty_source() -> None: @@ -242,23 +244,24 @@ async def test_parse_url_branch_and_commit_distinction() -> None: ) as mock_fetch_branches: mock_fetch_branches.return_value = ["main", "dev", "feature-branch"] - result_branch = await _parse_repo_source(url_branch) - result_commit = await _parse_repo_source(url_commit) - assert result_branch["branch"] == "main" - assert result_branch["commit"] is None + parsed_query_with_branch = await _parse_repo_source(url_branch) + parsed_query_with_commit = await _parse_repo_source(url_commit) - assert result_commit["branch"] is None - assert result_commit["commit"] == "abcd1234abcd1234abcd1234abcd1234abcd1234" + assert parsed_query_with_branch.branch == "main" + assert parsed_query_with_branch.commit is None + + assert parsed_query_with_commit.branch is None + assert parsed_query_with_commit.commit == "abcd1234abcd1234abcd1234abcd1234abcd1234" async def test_parse_query_uuid_uniqueness() -> None: """ - Test `parse_query` to ensure that each call generates a unique UUID for the query result. + Test `parse_query` to ensure that each call generates a unique UUID for the query. """ path = "/home/user/project" - result1 = await parse_query(path, max_file_size=100, from_web=False) - result2 = await parse_query(path, max_file_size=100, from_web=False) - assert result1["id"] != result2["id"] + parsed_query_1 = await parse_query(path, max_file_size=100, from_web=False) + parsed_query_2 = await parse_query(path, max_file_size=100, from_web=False) + assert parsed_query_1.id != parsed_query_2.id async def test_parse_url_with_query_and_fragment() -> None: @@ -267,10 +270,10 @@ async def test_parse_url_with_query_and_fragment() -> None: Verifies that the URL is cleaned and other fields are correctly extracted. """ url = "https://github.com/user/repo?arg=value#fragment" - result = await _parse_repo_source(url) - assert result["user_name"] == "user" - assert result["repo_name"] == "repo" - assert result["url"] == "https://github.com/user/repo" # URL should be cleaned + parsed_query = await _parse_repo_source(url) + assert parsed_query.user_name == "user" + assert parsed_query.repo_name == "repo" + assert parsed_query.url == "https://github.com/user/repo" # URL should be cleaned async def test_parse_url_unsupported_host() -> None: @@ -281,18 +284,16 @@ async def test_parse_url_unsupported_host() -> None: async def test_parse_query_with_branch() -> None: url = "https://github.com/pandas-dev/pandas/blob/2.2.x/.github/ISSUE_TEMPLATE/documentation_improvement.yaml" - result = await parse_query(url, max_file_size=10**9, from_web=True) - assert result["user_name"] == "pandas-dev" - assert result["repo_name"] == "pandas" - assert result["url"] == "https://github.com/pandas-dev/pandas" - assert result["slug"] == "pandas-dev-pandas" - assert result["id"] is not None - print('result["subpath"]', result["subpath"]) - print("/.github/ISSUE_TEMPLATE/documentation_improvement.yaml") - assert result["subpath"] == "/.github/ISSUE_TEMPLATE/documentation_improvement.yaml" - assert result["branch"] == "2.2.x" - assert result["commit"] is None - assert result["type"] == "blob" + parsed_query = await parse_query(url, max_file_size=10**9, from_web=True) + assert parsed_query.user_name == "pandas-dev" + assert parsed_query.repo_name == "pandas" + assert parsed_query.url == "https://github.com/pandas-dev/pandas" + assert parsed_query.slug == "pandas-dev-pandas" + assert parsed_query.id is not None + assert parsed_query.subpath == "/.github/ISSUE_TEMPLATE/documentation_improvement.yaml" + assert parsed_query.branch == "2.2.x" + assert parsed_query.commit is None + assert parsed_query.type == "blob" @pytest.mark.asyncio @@ -312,10 +313,10 @@ async def test_parse_repo_source_with_failed_git_command(url, expected_branch, e with patch("gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches: mock_fetch_branches.side_effect = Exception("Failed to fetch branch list") - result = await _parse_repo_source(url) + parsed_query = await _parse_repo_source(url) - assert result["branch"] == expected_branch - assert result["subpath"] == expected_subpath + assert parsed_query.branch == expected_branch + assert parsed_query.subpath == expected_subpath @pytest.mark.asyncio @@ -342,6 +343,6 @@ async def test_parse_repo_source_with_various_url_patterns(url, expected_branch, ) mock_fetch_branches.return_value = ["feature/fix1", "main", "feature-branch"] - result = await _parse_repo_source(url) - assert result["branch"] == expected_branch - assert result["subpath"] == expected_subpath + parsed_query = await _parse_repo_source(url) + assert parsed_query.branch == expected_branch + assert parsed_query.subpath == expected_subpath diff --git a/tests/test_query_ingestion.py b/tests/test_query_ingestion.py index 9d2b826..0907658 100644 --- a/tests/test_query_ingestion.py +++ b/tests/test_query_ingestion.py @@ -1,14 +1,14 @@ """ Tests for the query_ingestion module """ from pathlib import Path -from typing import Any from unittest.mock import patch from gitingest.query_ingestion import _extract_files_content, _read_file_content, _scan_directory, run_ingest_query +from gitingest.query_parser import ParsedQuery -def test_scan_directory(temp_directory: Path, sample_query: dict[str, Any]) -> None: - sample_query["local_path"] = temp_directory +def test_scan_directory(temp_directory: Path, sample_query: ParsedQuery) -> None: + sample_query.local_path = temp_directory result = _scan_directory(temp_directory, query=sample_query) if result is None: assert False, "Result is None" @@ -19,12 +19,13 @@ def test_scan_directory(temp_directory: Path, sample_query: dict[str, Any]) -> N assert len(result["children"]) == 5 # file1.txt, file2.py, src, dir1, dir2 -def test_extract_files_content(temp_directory: Path, sample_query: dict[str, Any]) -> None: - sample_query["local_path"] = temp_directory +def test_extract_files_content(temp_directory: Path, sample_query: ParsedQuery) -> None: + sample_query.local_path = temp_directory + nodes = _scan_directory(temp_directory, query=sample_query) if nodes is None: assert False, "Nodes is None" - files = _extract_files_content(query=sample_query, node=nodes, max_file_size=1_000_000) + files = _extract_files_content(query=sample_query, node=nodes) assert len(files) == 8 # All .txt and .py files # Check for presence of key files @@ -58,14 +59,14 @@ def test_read_file_content_with_non_notebook(tmp_path: Path): # Test that when using a ['*.txt'] as include pattern, only .txt files are processed & .py files are excluded -def test_include_txt_pattern(temp_directory: Path, sample_query: dict[str, Any]) -> None: - sample_query["local_path"] = temp_directory - sample_query["include_patterns"] = ["*.txt"] +def test_include_txt_pattern(temp_directory: Path, sample_query: ParsedQuery) -> None: + sample_query.local_path = temp_directory + sample_query.include_patterns = {"*.txt"} result = _scan_directory(temp_directory, query=sample_query) assert result is not None, "Result should not be None" - files = _extract_files_content(query=sample_query, node=result, max_file_size=1_000_000) + files = _extract_files_content(query=sample_query, node=result) file_paths = [f["path"] for f in files] assert len(files) == 5, "Should have found exactly 5 .txt files" assert all(path.endswith(".txt") for path in file_paths), "Should only include .txt files" @@ -77,15 +78,15 @@ def test_include_txt_pattern(temp_directory: Path, sample_query: dict[str, Any]) assert not any(path.endswith(".py") for path in file_paths), "Should not include .py files" -def test_include_nonexistent_extension(temp_directory: Path, sample_query: dict[str, Any]) -> None: - sample_query["local_path"] = temp_directory - sample_query["include_patterns"] = ["*.query"] # Is a Non existant extension ? +def test_include_nonexistent_extension(temp_directory: Path, sample_query: ParsedQuery) -> None: + sample_query.local_path = temp_directory + sample_query.include_patterns = {"*.query"} # Is a Non existant extension ? result = _scan_directory(temp_directory, query=sample_query) assert result is not None, "Result should not be None" # Extract the files content & set file limit cap - files = _extract_files_content(query=sample_query, node=result, max_file_size=1_000_000) + files = _extract_files_content(query=sample_query, node=result) # Verify no file processed with wrong extension assert len(files) == 0, "Should not find any files with .qwerty extension" @@ -96,70 +97,70 @@ def test_include_nonexistent_extension(temp_directory: Path, sample_query: dict[ # single folder patterns -def test_include_src_star_pattern(temp_directory: Path, sample_query: dict[str, Any]) -> None: +def test_include_src_star_pattern(temp_directory: Path, sample_query: ParsedQuery) -> None: """ Test that when using 'src/*' as include pattern, files under the src directory are included. Note: Windows is not supported - test converts Windows paths to Unix-style for validation. """ - sample_query["local_path"] = temp_directory - sample_query["include_patterns"] = ["src/*"] + sample_query.local_path = temp_directory + sample_query.include_patterns = {"src/*"} result = _scan_directory(temp_directory, query=sample_query) assert result is not None, "Result should not be None" - files = _extract_files_content(query=sample_query, node=result, max_file_size=1_000_000) + files = _extract_files_content(query=sample_query, node=result) # Convert Windows paths to Unix-style for test validation file_paths = {f["path"].replace("\\", "/") for f in files} expected_paths = {"src/subfile1.txt", "src/subfile2.py", "src/subdir/file_subdir.txt", "src/subdir/file_subdir.py"} assert file_paths == expected_paths, "Missing or unexpected files in result" -def test_include_src_recursive(temp_directory: Path, sample_query: dict[str, Any]) -> None: +def test_include_src_recursive(temp_directory: Path, sample_query: ParsedQuery) -> None: """ Test that when using 'src/**' as include pattern, all files under src directory are included recursively. Note: Windows is not supported - test converts Windows paths to Unix-style for validation. """ - sample_query["local_path"] = temp_directory - sample_query["include_patterns"] = ["src/**"] + sample_query.local_path = temp_directory + sample_query.include_patterns = {"src/**"} result = _scan_directory(temp_directory, query=sample_query) assert result is not None, "Result should not be None" - files = _extract_files_content(query=sample_query, node=result, max_file_size=1_000_000) + files = _extract_files_content(query=sample_query, node=result) # Convert Windows paths to Unix-style for test validation file_paths = {f["path"].replace("\\", "/") for f in files} expected_paths = {"src/subfile1.txt", "src/subfile2.py", "src/subdir/file_subdir.txt", "src/subdir/file_subdir.py"} assert file_paths == expected_paths, "Missing or unexpected files in result" -def test_include_src_wildcard_prefix(temp_directory: Path, sample_query: dict[str, Any]) -> None: +def test_include_src_wildcard_prefix(temp_directory: Path, sample_query: ParsedQuery) -> None: """ Test that when using 'src*' as include pattern, it matches the src directory and any paths that start with 'src'. Note: Windows is not supported - test converts Windows paths to Unix-style for validation. """ - sample_query["local_path"] = temp_directory - sample_query["include_patterns"] = ["src*"] + sample_query.local_path = temp_directory + sample_query.include_patterns = {"src*"} result = _scan_directory(temp_directory, query=sample_query) assert result is not None, "Result should not be None" - files = _extract_files_content(query=sample_query, node=result, max_file_size=1_000_000) + files = _extract_files_content(query=sample_query, node=result) # Convert Windows paths to Unix-style for test validation file_paths = {f["path"].replace("\\", "/") for f in files} expected_paths = {"src/subfile1.txt", "src/subfile2.py", "src/subdir/file_subdir.txt", "src/subdir/file_subdir.py"} assert file_paths == expected_paths, "Missing or unexpected files in result" -def test_run_ingest_query(temp_directory: Path, sample_query: dict[str, Any]) -> None: +def test_run_ingest_query(temp_directory: Path, sample_query: ParsedQuery) -> None: """ Test the run_ingest_query function to ensure it processes the directory correctly. """ - sample_query["local_path"] = temp_directory - sample_query["subpath"] = "/" - sample_query["type"] = None + sample_query.local_path = temp_directory + sample_query.subpath = "/" + sample_query.type = None summary, _, content = run_ingest_query(sample_query)