Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TD] Historical edited files and profiling heuristics #4590

Merged
merged 3 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/update_test_file_ratings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ on:
paths:
- ".github/workflows/update_test_file_ratings.yml"
- "torchci/scripts/calculate_file_test_rating.py"
- "torchci/scripts/test_calculate_file_test_rating.py"
- "torchci/scripts/td_heuristic_historical_edited_files.py"
- "torchci/scripts/td_heuristic_profiling.py"
- "torchci/scripts/get_merge_base_info.py"
schedule:
- cron: 5 11 * * * # At 11:05 UTC every day or about 4am PT
Expand Down Expand Up @@ -46,6 +49,10 @@ jobs:
- name: Generate file test ratings
run: |
python3 test-infra/torchci/scripts/calculate_file_test_rating.py
python3 test-infra/torchci/scripts/td_heuristic_historical_edited_files.py
# Do not run this one, it won't change
# python3 test-infra/torchci/scripts/td_heuristic_profiling.py

env:
ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }}

Expand Down Expand Up @@ -76,3 +83,17 @@ jobs:
user_email: "[email protected]"
user_name: "Pytorch Test Infra"
commit_message: "Updating file to test class correlations"

- name: Push historical edited files heuristic to test-infra repository
if: github.event_name != 'pull_request'
uses: dmnemec/copy_file_to_another_repo_action@eebb594efdf52bc12e1b461988d7254322dac131
env:
API_TOKEN_GITHUB: ${{ secrets.GITHUB_TOKEN }}
with:
source_file: "td_heuristic_historical_edited_files.json"
destination_repo: "pytorch/test-infra"
destination_folder: "stats"
destination_branch: generated-stats
user_email: "[email protected]"
user_name: "Pytorch Test Infra"
commit_message: "Updating TD heuristic: historical edited files"
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,6 @@ docs/_build/

# Pyenv
.python-version

# torchci caching utils
.torchci_python_utils_cache
44 changes: 31 additions & 13 deletions torchci/scripts/get_merge_base_info.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import subprocess
from datetime import datetime
from multiprocessing import Pool
from pathlib import Path
from typing import List

from pathlib import Path

from utils_td_heuristics import (
list_past_year_shas,
run_command,
)

from rockset_utils import query_rockset, remove_from_rockset, upload_to_rockset

REPO_ROOT = Path(__file__).resolve().parent.parent.parent
Expand All @@ -19,6 +25,16 @@
mb.merge_base is null
"""

NOT_IN_MERGE_BASES_TABLE = """
select
shas.sha as head_sha
from
unnest(SPLIT(:shas, ',') as sha) as shas
left outer join commons.merge_bases mb on mb.sha = shas.sha
where
mb.sha is null
or mb.repo is null
"""

DUP_MERGE_BASE_INFO = """
select
Expand Down Expand Up @@ -99,18 +115,20 @@ def upload_merge_base_info(shas: List[str]) -> None:
print(
f"There are {len(failed_test_shas)} shas, uploading in intervals of {interval}"
)
pool = Pool(20)
errors = []
for i in range(0, len(failed_test_shas), interval):
pull_shas(failed_test_shas[i : i + interval])
errors.append(
pool.apply_async(
upload_merge_base_info, args=(failed_test_shas[i : i + interval],)
upload_merge_base_info(failed_test_shas[i : i + interval])

interval = 500
main_branch_shas = list_past_year_shas()
print(f"There are {len(main_branch_shas)} shas, uploading in batches of {interval}")
for i in range(0, len(main_branch_shas), interval):
shas = [
x["head_sha"]
for x in query_rockset(
NOT_IN_MERGE_BASES_TABLE,
{"shas": ",".join(main_branch_shas[i : i + interval])},
)
)
print("done pulling")
pool.close()
pool.join()
for i in errors:
if i.get() is not None:
print(i.get())
]
upload_merge_base_info(shas)
print(f"{i} to {i + interval} done")
22 changes: 12 additions & 10 deletions torchci/scripts/rockset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os
from typing import Any, Dict, List, Optional

import rockset # type: ignore[import]
import rockset

from utils import cache_json # type: ignore[import]


@lru_cache
Expand All @@ -13,16 +15,16 @@ def get_rockset_client():


def query_rockset(
query: str, params: Optional[Dict[str, Any]] = None
query: str, params: Optional[Dict[str, Any]] = None, use_cache: bool = False
) -> List[Dict[str, Any]]:
res: List[Dict[str, Any]] = (
rockset.RocksetClient(
host="api.rs2.usw2.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
)
.sql(query, params=params)
.results
)
return res
if not use_cache:
return get_rockset_client().sql(query, params=params).results

@cache_json
def cache_query_rockset(query, params):
return get_rockset_client().sql(query, params=params).results

return cache_query_rockset(query, params)


def upload_to_rockset(
Expand Down
64 changes: 64 additions & 0 deletions torchci/scripts/td_heuristic_historical_edited_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import json
from collections import defaultdict
from typing import Dict

from utils_td_heuristics import (
cache_json,
evaluate,
get_all_invoking_files,
get_filtered_failed_tests,
get_merge_bases_dict,
list_past_year_shas,
query_rockset,
)

CHANGED_FILES_QUERY = """
select
sha,
changed_files
from
commons.merge_bases
where
ARRAY_CONTAINS(SPLIT(:shas, ','), sha)
"""


@cache_json
def gen_correlation_dict() -> Dict[str, Dict[str, float]]:
    """Build a mapping from each changed file to the test files it has
    historically been edited alongside, weighted by commit size.

    For every main-branch sha from the past year, each (changed file,
    test file) co-occurrence in a commit contributes 1 / <number of files
    changed in that commit>, so sweeping commits count less per file than
    small focused ones.
    """
    shas = list_past_year_shas()

    # Fetch changed-file lists in batches so the query parameter stays a
    # reasonable size.
    batch_size = 500
    commits = []
    for start in range(0, len(shas), batch_size):
        batch = shas[start : start + batch_size]
        commits.extend(
            query_rockset(
                CHANGED_FILES_QUERY,
                params={"shas": ",".join(batch)},
                use_cache=True,
            )
        )

    invoking_files = get_all_invoking_files()

    scores: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
    for commit in commits:
        changed_files = commit["changed_files"]
        # Fullnames of test files look like test/<file>.py, but invoking
        # files from rockset omit the test/ prefix and the .py extension,
        # so strip both before matching.
        touched_tests = [
            f[5:-3] for f in changed_files if f[5:-3] in invoking_files
        ]
        if not touched_tests:
            continue
        weight = 1 / len(changed_files)
        for test_file in touched_tests:
            for changed_file in changed_files:
                scores[changed_file][test_file] += weight
    return scores


if __name__ == "__main__":
    # Build the correlation heuristic, score it against historical test
    # failures, then persist it for consumption by target determination.
    heuristic = gen_correlation_dict()
    merge_bases = get_merge_bases_dict()
    failed_tests = get_filtered_failed_tests()

    evaluate(failed_tests, merge_bases, heuristic)

    with open("td_heuristic_historical_edited_files.json", mode="w") as out:
        json.dump(heuristic, out, indent=2, sort_keys=True)
26 changes: 26 additions & 0 deletions torchci/scripts/td_heuristic_profiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import json

import requests
from utils_td_heuristics import evaluate, get_filtered_failed_tests, get_merge_bases_dict


def get_profiling_dict():
    """Fetch the precomputed profiling-based TD heuristic.

    The dict is generated elsewhere; this function only retrieves the
    published copy from the generated-stats branch of test-infra.

    Returns:
        The parsed JSON heuristic (file -> test -> score mapping).

    Raises:
        requests.HTTPError: If the fetch returns a non-2xx status.
    """
    url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/td_heuristic_profiling.json"
    # Without a timeout a hung connection would stall the job forever, and
    # without raise_for_status an error page would be handed to the JSON
    # parser as if it were the heuristic.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    return response.json()


def main() -> None:
    """Evaluate the profiling heuristic and write it out as JSON."""
    heuristic = get_profiling_dict()
    merge_bases = get_merge_bases_dict()
    failed_tests = get_filtered_failed_tests()

    evaluate(failed_tests, merge_bases, heuristic)

    with open("td_heuristic_profiling.json", mode="w") as out:
        json.dump(heuristic, out, indent=2, sort_keys=True)


# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
68 changes: 68 additions & 0 deletions torchci/scripts/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import datetime
import functools
import json
import os
import pathlib
import subprocess
from hashlib import sha256
from typing import List, Union


FILE_CACHE_LIFESPAN_SECONDS = 60 * 60 * 24 # 1 day
REPO_ROOT = pathlib.Path(__file__).parent.parent.parent
CACHE_FOLDER = REPO_ROOT / ".torchci_python_utils_cache"


def js_beautify(obj):
    """Serialize obj like json.dumps(obj, indent=2) but indenting only the
    first level.  Nice for dictionaries of str -> really long list."""
    # Imported lazily so the module still imports when jsbeautifier is not
    # installed, as long as this helper isn't used.
    import jsbeautifier

    options = jsbeautifier.default_options()
    options.indent_size = 2
    return jsbeautifier.beautify(json.dumps(obj), options)


def run_command(command: Union[str, List[str]]) -> str:
    """Run a command in the pytorch checkout and return its stripped stdout.

    Assumes test-infra and pytorch are cloned side by side in the same
    parent folder.

    Args:
        command: Either an argv list, or a single string that is split on
            whitespace (no shell quoting is interpreted).

    Returns:
        The command's stdout, decoded as UTF-8 and stripped.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    if isinstance(command, str):
        # split() with no argument collapses runs of whitespace, so a
        # doubled space can't produce an empty argv entry the way
        # split(" ") would.
        command = command.split()
    cwd = REPO_ROOT / ".." / "pytorch"
    return (
        subprocess.check_output(
            command,
            cwd=cwd,
        )
        .decode("utf-8")
        .strip()
    )


def cache_json(func):
    """Decorator caching a function's result in a JSON file on disk.

    Requires that both the arguments and the return value be JSON
    serializable.  Results are stored under CACHE_FOLDER, keyed by the
    function name and hashes of the arguments, and reused between runs for
    up to FILE_CACHE_LIFESPAN_SECONDS.
    """
    os.makedirs(CACHE_FOLDER, exist_ok=True)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Recreate the folder in case it was removed after import time.
        os.makedirs(CACHE_FOLDER, exist_ok=True)
        args_key = sha256(json.dumps(args).encode("utf-8")).hexdigest()
        kwargs_key = sha256(
            json.dumps(kwargs, sort_keys=True).encode("utf-8")
        ).hexdigest()
        file_name = f"{func.__name__} args={args_key} kwargs={kwargs_key}.json"
        cache_file = CACHE_FOLDER / file_name

        if cache_file.exists():
            now = datetime.datetime.now()
            mtime = datetime.datetime.fromtimestamp(cache_file.stat().st_mtime)
            if (now - mtime).total_seconds() < FILE_CACHE_LIFESPAN_SECONDS:
                # Use a context manager so the file handle is closed
                # (json.load(open(...)) would leak it).
                with open(cache_file) as f:
                    return json.load(f)

        res = func(*args, **kwargs)
        with open(cache_file, "w") as f:
            json.dump(res, f)
        return res

    return wrapper
Loading