diff --git a/TODO b/TODO new file mode 100644 index 00000000..40801e3c --- /dev/null +++ b/TODO @@ -0,0 +1,11 @@ +* do not include commented-out lines in the commit message (like #Conflict...) +* validate commit message - validate the expected format and whether people can be found in github +* let the user edit the CHANGES message, make sure it is one line message + +* prepare commit message +- list of commit messages +- * use GPT to summarize the changes +- list all users who interacted on any of the PRs or reviewers from Jira +- get the current user +- prepare the message +- open editor with tmp file containing the message diff --git a/dev/scripts/__init__.py b/dev/scripts/__init__.py new file mode 100644 index 00000000..a437d36c --- /dev/null +++ b/dev/scripts/__init__.py @@ -0,0 +1,5 @@ +import os +import sys + +PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(PROJECT_PATH) diff --git a/dev/scripts/digest-circleci-workflow.py b/dev/scripts/digest-circleci-workflow.py new file mode 100644 index 00000000..20f0ca70 --- /dev/null +++ b/dev/scripts/digest-circleci-workflow.py @@ -0,0 +1,129 @@ +# https://app.circleci.com/pipelines/github/jacek-lewandowski/cassandra/1252/workflows/b10132a7-1b4f-44d0-8808-f19a3b5fde69/jobs/63797 +# https://circleci.com/api/v2/project/gh/jacek-lewandowski/cassandra/63797/tests +# { +# "items": [ +# { +# "classname": "org.apache.cassandra.distributed.test.LegacyCASTest", +# "name": "testRepairIncompletePropose-_jdk17", +# "result": "success", +# "message": "", +# "run_time": 15.254, +# "source": "unknown" +# } +# ,{ +# "classname": "org.apache.cassandra.distributed.test.NativeTransportEncryptionOptionsTest", +# "name": "testEndpointVerificationEnabledIpNotInSAN-cassandra.testtag_IS_UNDEFINED", +# "result": "failure", +# "message": "junit.framework.AssertionFailedError: Forked Java VM exited abnormally. 
Please note the time in the report does not reflect the time until the VM exit.\n\tat jdk.internal.reflect.GeneratedMethodAccessor4.invoke(Unknown Source)\n\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\tat java.base/java.util.Vector.forEach(Vector.java:1365)\n\tat jdk.internal.reflect.GeneratedMethodAccessor4.invoke(Unknown Source)\n\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\tat jdk.internal.reflect.GeneratedMethodAccessor4.invoke(Unknown Source)\n\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\tat java.base/java.util.Vector.forEach(Vector.java:1365)\n\tat jdk.internal.reflect.GeneratedMethodAccessor4.invoke(Unknown Source)\n\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\tat java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\n\tat java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)\n\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\tat java.base/java.util.Vector.forEach(Vector.java:1365)\n\tat java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\n\tat java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)\n\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\tat org.apache.cassandra.anttasks.TestHelper.execute(TestHelper.java:53)\n\tat jdk.internal.reflect.GeneratedMethodAccessor4.invoke(Unknown Source)\n\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\tat java.base/java.util.Vector.forEach(Vector.java:1365)\n\tat jdk.internal.reflect.GeneratedMethodAccessor4.invoke(Unknown Source)\n\tat 
java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\tat jdk.internal.reflect.GeneratedMethodAccessor4.invoke(Unknown Source)\n\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)", +# "run_time": 0.001, +# "source": "unknown" +# } +# ] +# } +import csv + +# So here is the plan: +# I have a link to the pipeline: https://app.circleci.com/pipelines/github/jacek-lewandowski/cassandra/1252 +# The program goes through all the workflow jobs and list the failed tests along with the workflow, job, etc. +# Then: +# - separate failures into 3 groups: +# 1. flaky - if a test was repeated in mulitple jobs and failred in some of them +# 2. failure - if a test was repeated in multiple jobs and failed in all of them +# 3. suspected - if a test was not repeated + +# Then for each failure list Jira tickets that mention the test name. + +# Having that information, let the user decide what to do with each failure: +# - select a jira ticket +# - create a new ticket +# - do not associate with any ticket +# - report on the PR + +# Eventually, the user can create the script which can perform the planned operations + +from lib.circleci_utils import * + +class TestFailure(NamedTuple): + file: str + classname: str + name: str + jobs_comp: str + jobs_list: list + +class TestFailureComparison(NamedTuple): + file: str + classname: str + name: str + feature_jobs: set + base_jobs: set + jobs_comp: str + +if len(sys.argv) != 4 and len(sys.argv) != 6: + print("Usage: %s " % sys.argv[0]) + print("Usage: %s " % sys.argv[0]) + sys.exit(1) + +if len(sys.argv) == 4: + repo = sys.argv[1] + workflow_id = sys.argv[2] + output_file = sys.argv[3] + failed_tests_dict = get_failed_tests(repo, workflow_id) + failed_tests = [] + for file in failed_tests_dict: + for classname in failed_tests_dict[file]: + for name in failed_tests_dict[file][classname]: + jobs = 
list(failed_tests_dict[file][classname][name]) + jobs.sort() + failed_tests.append(TestFailure(file, classname, name, ",".join(failed_tests_dict[file][classname][name]), jobs)) + + # sort failed tests by jobs, file, classname, name + failed_tests.sort(key=lambda test: (test.jobs_comp, test.file, test.classname, test.name)) + + # save failed_tests to csv file + with open(output_file, 'w') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['file', 'classname', 'name', 'jobs']) + for test in failed_tests: + writer.writerow([test.file, test.classname, test.name, test.jobs_comp]) + +else: + feature_repo = sys.argv[1] + feature_workflow_id = sys.argv[2] + base_repo = sys.argv[3] + base_workflow_id = sys.argv[4] + output_file = sys.argv[5] + feature_failed_tests_dict = get_failed_tests(feature_repo, feature_workflow_id) + base_failed_tests_dict = get_failed_tests(base_repo, base_workflow_id) + + failed_tests = [] + all_files = set(feature_failed_tests_dict.keys()).union(set(base_failed_tests_dict.keys())) + for file in all_files: + feature_classnames = feature_failed_tests_dict[file] if file in feature_failed_tests_dict else {} + base_classnames = base_failed_tests_dict[file] if file in base_failed_tests_dict else {} + all_classnames = set(feature_classnames.keys()).union(set(base_classnames.keys())) + for classname in all_classnames: + feature_names = feature_classnames[classname] if classname in feature_classnames else {} + base_names = base_classnames[classname] if classname in base_classnames else {} + all_names = set(feature_names.keys()).union(set(base_names.keys())) + for name in all_names: + feature_jobs = feature_names[name] if name in feature_names else set() + base_jobs = base_names[name] if name in base_names else set() + jobs_comp = list(feature_jobs.union(base_jobs)) + jobs_comp.sort() + failed_tests.append(TestFailureComparison(file, classname, name, feature_jobs, base_jobs, ",".join(jobs_comp))) + + # sort failed tests by jobs, file, classname, 
name + failed_tests.sort(key=lambda test: (test.jobs_comp, test.file, test.classname, test.name)) + + # save failed_tests to csv file + with open(output_file, 'w') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['file', 'classname', 'name', 'failed in feature only', 'failed in base only', 'failed in both']) + for test in failed_tests: + feature_only_jobs = list(test.feature_jobs.difference(test.base_jobs)) + feature_only_jobs.sort() + base_only_jobs = list(test.base_jobs.difference(test.feature_jobs)) + base_only_jobs.sort() + common_jobs = list(test.feature_jobs.intersection(test.base_jobs)) + common_jobs.sort() + writer.writerow([test.file, test.classname, test.name, ",".join(feature_only_jobs), ",".join(base_only_jobs), ",".join(common_jobs)]) diff --git a/dev/scripts/lib/circleci_utils.py b/dev/scripts/lib/circleci_utils.py new file mode 100644 index 00000000..34c08687 --- /dev/null +++ b/dev/scripts/lib/circleci_utils.py @@ -0,0 +1,128 @@ +import json +import sys +from enum import Enum +from typing import NamedTuple + +import urllib3 + +class PipelineInfo(NamedTuple): + id: str + number: int + +def get_pipelines_from_circleci(repo, branch): + http = urllib3.PoolManager() + url = "https://circleci.com/api/v2/project/gh/%s/cassandra/pipeline?branch=%s" % (repo, branch) + r = http.request('GET', url) + if r.status == 200: + items = json.loads(r.data.decode('utf-8'))['items'] + return [PipelineInfo(id=item['id'], number=item['number']) for item in items] + return None + +class WorkflowInfo(NamedTuple): + id: str + name: str + status: str + +def get_pipeline_workflows(pipeline_id): + http = urllib3.PoolManager() + url = "https://circleci.com/api/v2/pipeline/%s/workflow" % (pipeline_id) + r = http.request('GET', url) + if r.status == 200: + items = json.loads(r.data.decode('utf-8'))['items'] + return [WorkflowInfo(id=item['id'], name=item['name'], status=item['status']) for item in items] + +class JobType(Enum): + BUILD = "build" + APPROVAL = 
"approval" + +class JobStatus(Enum): + SUCCESS = "success" + RUNNING = "running" + NOT_RUN = "not_run" + FAILED = "failed" + RETRIED = "retried" + QUEUED = "queued" + NOT_RUNNING = "not_running" + INFRASTRUCTURE_FAIL = "infrastructure_fail" + TIMEDOUT = "timedout" + ON_HOLD = "on_hold" + TERMINATED_UNKNOWN = "terminated-unknown" + BLOCKED = "blocked" + CANCELED = "canceled" + UNAUTHORIZED = "unauthorized" + +class JobInfo(NamedTuple): + id: str + name: str + status: JobStatus + job_number: str + type: JobType + +def job_info_from_json(json): + return JobInfo(id=json['id'], name=json['name'], status=JobStatus(json['status']), job_number=json['job_number'] if 'job_number' in json else None , type=JobType(json['type'])) + +def get_workflow_jobs(workflow_id): + http = urllib3.PoolManager() + url = "https://circleci.com/api/v2/workflow/%s/job" % (workflow_id) + r = http.request('GET', url) + if r.status == 200: + items = json.loads(r.data.decode('utf-8'))['items'] + print("Found %d jobs" % len(items)) + return [job_info_from_json(item) for item in items] + return None + +def get_failed_jobs(workflow_id): + jobs = get_workflow_jobs(workflow_id) + failed_jobs = [] + for job in jobs: + if job.status == JobStatus.FAILED and job.job_number is not None: + failed_jobs.append(job) + else: + print("Skipping job %s" % str(job)) + return failed_jobs + +class TestResult(Enum): + SUCCESS = "success" + FAILURE = "failure" + SKIPPED = "skipped" + ERROR = "error" + UNKNOWN = "unknown" + +class TestInfo(NamedTuple): + message: str + source: str + run_time: float + file: str + result: TestResult + name: str + classname: str + +def get_job_tests(repo, job_number): + http = urllib3.PoolManager() + url = "https://circleci.com/api/v2/project/gh/%s/cassandra/%s/tests" % (repo, job_number) + r = http.request('GET', url) + if r.status == 200: + tests = [TestInfo(t['message'], t['source'], t['run_time'], t['file'] if 'file' in t else "", TestResult(t['result']), t['name'], t['classname']) for t 
in json.loads(r.data.decode('utf-8'))['items']] + return tests + return None + + +def get_failed_tests(repo, workflow_id): + failed_jobs = get_failed_jobs(workflow_id) + failed_tests = {} + for job in failed_jobs: + print("Getting tests for job %s" % str(job)) + tests = get_job_tests(repo, job.job_number) + for test in tests: + if test.result == TestResult.FAILURE: + if test.file not in failed_tests: + failed_tests[test.file] = {} + if test.classname not in failed_tests[test.file]: + failed_tests[test.file][test.classname] = {} + test_name = test.name.split("-", 2)[0] + test_name = test_name.split("[", 2)[0] + if test_name not in failed_tests[test.file][test.classname]: + failed_tests[test.file][test.classname][test_name] = set() + failed_tests[test.file][test.classname][test_name].add(job.name) + + return failed_tests diff --git a/dev/scripts/lib/git_utils.py b/dev/scripts/lib/git_utils.py new file mode 100644 index 00000000..0e324a24 --- /dev/null +++ b/dev/scripts/lib/git_utils.py @@ -0,0 +1,323 @@ +import re +import subprocess +import sys +from typing import NamedTuple, Tuple, Optional + + +class VersionedBranch(NamedTuple): + version: Tuple[int, int] + version_string: str + name: str + +class Commit(NamedTuple): + sha: str + author: str + email: str + title: str + body: str + +class BranchMergeInfo(NamedTuple): + release_branch: VersionedBranch + feature_branch: Optional[VersionedBranch] + commits: list[Commit] + +class TicketMergeInfo(NamedTuple): + ticket: str + update_changes: bool + upstream_repo: str + feature_repo: str + merges: list[BranchMergeInfo] + keep_changes_in_circleci: bool + commit_msg_file: str + +NO_VERSION = (-1, -1) +TRUNK_VERSION = (255, 255) + +CASSANRA_BRANCH_VERSION_RE = re.compile(r"cassandra-(\d+)\.(\d+)") +VERSION_RE = re.compile(r"(\d+)\.(\d+)") + +def version_from_re(re, string): + match = re.match(string) + if match: + return (int(match.group(1)), int(match.group(2))) + return None + + +def version_from_branch(branch): + return 
version_from_re(CASSANRA_BRANCH_VERSION_RE, branch) + + +def version_from_string(version_string): + if version_string == "trunk": + return TRUNK_VERSION + return version_from_re(VERSION_RE, version_string) + + +def version_as_string(version): + if version is None: + return None + if version == NO_VERSION: + return None + if version == TRUNK_VERSION: + return "trunk" + return "%s.%s" % version + + +### GIT functions ### +def guess_base_version(repo, remote_repo, branch): + version = NO_VERSION + + merge_base = None + for l in subprocess.check_output(["git", "log", "--decorate", "--simplify-by-decoration", "--oneline", "%s/%s" % (repo, branch)], text=True).split("\n"): + if "(HEAD" not in l and "(%s/%s" % (repo, branch) not in l: + merge_base = l.split(" ")[0] + break + + matching_versions = [] + if merge_base: + branch_regex = re.compile(r"\s*" + re.escape(remote_repo) + r"/((cassandra-(\d+)\.(\d+))|(trunk))$") + for l in subprocess.check_output(["git", "branch", "-r", "--contains", merge_base], text=True).split("\n"): + match = branch_regex.match(l) + if match: + if match.group(5): + matching_versions.append(TRUNK_VERSION) + elif match.group(2): + matching_versions.append((int(match.group(3)), int(match.group(4)))) + matching_versions.sort() + + if len(matching_versions) == 1: + version = matching_versions[0] + else: + branch_regex = re.compile(r".*?([-/]((\d+)\.(\d+))|(trunk))?$", flags=re.IGNORECASE) + match = branch_regex.match(branch) + if match: + if match.group(5) == "trunk": + version = TRUNK_VERSION + elif match.group(2): + version = (int(match.group(3)), int(match.group(4))) + else: + print("No match for %s" % branch) + if len(matching_versions) > 0: + version = matching_versions[0] + + return version + + +def guess_feature_branches(repo, remote_repo, ticket): + """ + Get the list of branches from the given repository that contain the given ticket, sorted by version ascending. 
+ :param repo: configured apache repository name + :param ticket: ticket number + :return: list of branch names + """ + output = subprocess.check_output(["git", "ls-remote", "--refs", "-h", "-q", repo], text=True) + branch_regex = re.compile(r".*refs/heads/(" + re.escape(ticket) + r"(-(\d+)\.(\d+))?.*)$", flags=re.IGNORECASE) + print(r".*refs/heads/(" + re.escape(ticket) + r"((\d+)\.(\d+))?.*)$") + matching_branches = [] + for line in output.split("\n"): + match = branch_regex.match(line) + if match: + branch = match.group(1) + version = guess_base_version(repo, remote_repo, branch) + matching_branches.append(VersionedBranch(version, match.group(2), branch)) + + matching_branches.sort(key=lambda x: x.version) + return matching_branches + + +def guess_feature_repo_and_ticket(): + """ + Get the remote repository and ticket number from the current git branch. + :return: a tuple (remote_repository, ticket_number) or None if the current branch does not look like a feature branch + """ + output = subprocess.check_output(["git", "status", "-b", "--porcelain=v2"], shell=False).decode("utf-8") + regex = re.compile(r"# branch\.upstream ([^/]+)/([^ ]+)") + match = regex.search(output) + if match: + ticket_regex = re.compile(r"CASSANDRA-(\d+)", flags=re.IGNORECASE) + ticket_match = ticket_regex.search(match.group(2)) + if ticket_match: + return (match.group(1), int(ticket_match.group(1))) + return (match.group(1), None) + return (None, None) + + +def guess_upstream_repo(): + """ + Get the name of the remote repository that points to the apache cassandra repository. Prefers "apache" over "asf". 
+ :return: the remote name + """ + output = subprocess.check_output(["git", "remote", "show"], shell=False) + apache_remote_name = None + for remote_name in output.decode("utf-8").split("\n"): + url = subprocess.check_output(["git", "remote", "get-url", remote_name], shell=False).decode("utf-8").strip() + if "apache/cassandra.git" in url: + return remote_name + if "asf/cassandra/git" in url: + apache_remote_name = remote_name + return apache_remote_name + + +def get_release_branches(repo): + """ + Get the list of main cassandra branches from the given repo, sorted by version ascending. + :param repo: configured apache repository name + :return: list of VersionedBranch objects + """ + output = subprocess.check_output(["git", "ls-remote", "--refs", "-h", "-q", repo], text=True) + branch_regex = re.compile(r".*refs/heads/(cassandra-((\d+)\.(\d+)))$") + + branches = [] + for line in output.split("\n"): + match = branch_regex.match(line) + if match: + branches.append(VersionedBranch((int(match.group(3)), int(match.group(4))), match.group(2), match.group(1))) + + branches.append(VersionedBranch(TRUNK_VERSION, "", "trunk")) + branches.sort(key=lambda x: x.version) + + return branches + + +def get_commits(from_repo, from_branch, to_repo, to_branch): + """ + Get the commit history between two branches, sorted by commit date ascending. 
+ :param from_repo: start repository name or None for local branch + :param from_branch: start branch name + :param to_repo: end repository name or None for local branch + :param to_branch: end branch name + :return: a list of Commit objects + """ + def coordinates(repo, branch): + if repo: + return "%s/%s" % (repo, branch) + else: + return branch + output = subprocess.check_output(["git", "log", "--pretty=format:%h%n%aN%n%ae%n%s%n%b%n%x00", "--reverse", "%s..%s" % (coordinates(from_repo, from_branch), coordinates(to_repo, to_branch))], text=True) + commits = [] + for commit_block in output.split("\0"): + if not commit_block: + continue + match = commit_block.strip("\n").split(sep = "\n", maxsplit = 4) + commits.append(Commit(match[0], match[1], match[2], match[3], match[4] if len(match) > 4 else "")) + return commits + + +def parse_merge_commit_msg(msg): + """ + Parse a merge commit message and return the source and destination branches. + :param msg: a commit message + :return: a tuple of (source_branch, destination_branch) or None if the message is not a merge commit + """ + msg_regex = re.compile(r"Merge branch '(cassandra-\d+\.\d+)' into ((cassandra-(\d+\.\d+))|trunk)") + match = msg_regex.match(msg) + if match: + return (match.group(1), match.group(2)) + return None + + +def ensure_clean_git_tree(): + output = subprocess.check_output(["git", "status", "--porcelain"], text=True) + if output.strip(): + print("Your git tree is not clean. 
Please commit or stash your changes before running this script.") + sys.exit(1) + + +def get_push_ranges(repo, branches): + """ + Parse the output of git push --atomic -n and return a list of tuples (label, start_commit, end_commit) + :param repo: configured apache repository name + :param branches: list of branch names + :return: list of tuples (label, start_commit, end_commit) + """ + output = subprocess.check_output(["git", "push", "--atomic", "-n", "--porcelain", repo] + branches, shell=False) + range_regex = re.compile(r"^\s+refs/heads/\S+:refs/heads/(\S+)\s+([0-9a-f]+)\.\.([0-9a-f]+)$") + ranges = [] + for line in output.decode("utf-8").split("\n"): + match = range_regex.match(line) + if match: + ranges.append((match.group(1), match.group(2), match.group(3))) + return ranges + + +def check_remote_exists(remote): + try: + return subprocess.check_call(["git", "remote", "get-url", remote], stderr=sys.stderr, stdout=None) == 0 + except subprocess.CalledProcessError: + return False + + +def check_remote_branch_exists(remote, branch): + return subprocess.check_call(["git", "ls-remote", "--exit-code", remote, branch], stderr=sys.stderr, stdout=None) == 0 + + +### User input functions ### + + +def read_with_default(prompt, default): + if default: + value = input("%s [default: %s]: " % (prompt, default)) + else: + value = input("%s: " % prompt) + if not value: + value = default + return value + + +def read_remote_repository(prompt, default): + repo = None + + while not repo: + repo = read_with_default(prompt, default) + if not check_remote_exists(repo): + repo = None + + return repo + + +def read_positive_int(prompt, default): + value = None + while not value: + try: + if default: + value = input("%s [default: %s]: " % (prompt, default)) + else: + value = input(prompt) + if value: + v = int(value) + if v > 0: + return v + else: + return default + except ValueError: + print("Invalid integer value") + value = None + return value + +# from 
https://gist.github.com/rene-d/9e584a7dd2935d0f461904b9f2950007 +class Colors: + """ ANSI color codes """ + BLACK = "\033[0;30m" + RED = "\033[0;31m" + GREEN = "\033[0;32m" + BROWN = "\033[0;33m" + BLUE = "\033[0;34m" + PURPLE = "\033[0;35m" + CYAN = "\033[0;36m" + LIGHT_GRAY = "\033[0;37m" + DARK_GRAY = "\033[1;30m" + LIGHT_RED = "\033[1;31m" + LIGHT_GREEN = "\033[1;32m" + YELLOW = "\033[1;33m" + LIGHT_BLUE = "\033[1;34m" + LIGHT_PURPLE = "\033[1;35m" + LIGHT_CYAN = "\033[1;36m" + LIGHT_WHITE = "\033[1;37m" + BOLD = "\033[1m" + FAINT = "\033[2m" + ITALIC = "\033[3m" + UNDERLINE = "\033[4m" + BLINK = "\033[5m" + NEGATIVE = "\033[7m" + CROSSED = "\033[9m" + END = "\033[0m" diff --git a/dev/scripts/lib/jira_utils.py b/dev/scripts/lib/jira_utils.py new file mode 100644 index 00000000..3038bc82 --- /dev/null +++ b/dev/scripts/lib/jira_utils.py @@ -0,0 +1,34 @@ +import json + +import urllib3 + + +def get_assignee_from_jira(ticket): + """ + Get the assignee for the given JIRA ticket. + :param ticket: + :return: + """ + http = urllib3.PoolManager() + r = http.request('GET', 'https://issues.apache.org/jira/rest/api/latest/issue/' + ticket) + if r.status == 200: + data = json.loads(r.data.decode('utf-8')) + if data['fields']['assignee']: + return data['fields']['assignee']['displayName'] + return None + + +def get_reviewers_from_jira(ticket): + """ + Get the reviewers for the given JIRA ticket. 
+ :param ticket: + :return: + """ + http = urllib3.PoolManager() + r = http.request('GET', 'https://issues.apache.org/jira/rest/api/latest/issue/' + ticket) + if r.status == 200: + data = json.loads(r.data.decode('utf-8')) + reviewers = data['fields']['customfield_12313420'] + if reviewers: + return [reviewer['displayName'] for reviewer in reviewers] + return None diff --git a/dev/scripts/lib/script_generator.py b/dev/scripts/lib/script_generator.py new file mode 100644 index 00000000..b7cdbd8e --- /dev/null +++ b/dev/scripts/lib/script_generator.py @@ -0,0 +1,125 @@ +import os + +from lib.git_utils import * + +def resolve_version_and_merge_sections(idx: int, merges: list[Tuple[VersionedBranch, bool]]) -> Tuple[Optional[VersionedBranch], list[VersionedBranch]]: + """ + Compute the version and merge sections for a given index in the CHANGES.txt file. + See the unit tests for examples. + + :param idx: the index of the merge + :param merges: list of merges + :return: the version and merge sections + """ + + version_section = None + merge_sections = [] + release_branch, is_patch_defined = merges[idx] + + assert idx > 0 or is_patch_defined, "The first merge must be a patch" + + if idx == 0: # which means that we are in the oldest version + # in this case we just add the title for the version + version_section = release_branch + # no merge section in this case + + elif idx == (len(merges) - 1): # which means that this is a merge for trunk + # in this case version section is either len(merges) - 2 or None + before_last_release_branch, is_patch_defined_for_before_last = merges[idx - 1] + if is_patch_defined_for_before_last: + # version section is defined only if the before last release branch is a patch + version_section = before_last_release_branch + for i in range(idx - 2, -1, -1): + release_branch, _ = merges[i] + merge_sections.append(release_branch) + + elif is_patch_defined: + # otherwise, version section is defined only if there is a patch for the release branch + 
version_section = release_branch + for i in range(idx - 1, -1, -1): + release_branch, _ = merges[i] + merge_sections.append(release_branch) + + return version_section, merge_sections + +def generate_script(ticket_merge_info: TicketMergeInfo): + assert ticket_merge_info.merges[0].feature_branch is not None + + script_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + + script = ["#!/bin/bash", "", "set -xe", "", "[[ -z $(git status --porcelain) ]] # worktree must be clean"] + + script.append("") + if ticket_merge_info.update_changes: + script.append("# Edit the commit message, the first will be used as the change title to update CHANGES.txt") + else: + script.append("# Edit the commit message") + script.append("$(git config --get core.editor) %s" % ticket_merge_info.commit_msg_file) + script.append("") + + merges = ticket_merge_info.merges + # index of first merge with undefined feature branch + for idx in range(0, len(merges)): + merge = merges[idx] + script.append("") + script.append("") + script.append("") + if merge.feature_branch is not None: + script.append("# Commands for merging %s -> %s" % (merge.feature_branch.name, merge.release_branch.name)) + else: + script.append("# Commands for skipping -> %s" % merge.release_branch.name) + script.append("#" * 80) + + if merge.feature_branch: + # ensure that there is at least one non-merge commit in the feature branch + assert len([c for c in merge.commits if parse_merge_commit_msg(c.title) is None]) > 0 + + closed = True + script.append("git switch %s" % merge.release_branch.name) + script.append("git reset --hard %s/%s" % (ticket_merge_info.upstream_repo, merge.release_branch.name)) + commits = [] + if idx == 0: + # oldest version + script.append("git cherry-pick %s # %s - %s" % ( + merge.commits[0].sha, merge.commits[0].author, merge.commits[0].title)) + commits = merge.commits[1:] + else: + script.append("git merge -s ours --log --no-edit %s" % merges[idx - 1].release_branch.name) + commits = 
merge.commits + + for commit in commits: + merge_msg = parse_merge_commit_msg(commit.title) + if merge_msg: + script.append("# skipping merge commit %s %s - %s" % (commit.sha, commit.author, commit.title)) + else: + script.append("git cherry-pick -n %s # %s - %s" % (commit.sha, commit.author, commit.title)) + closed = False + + version_section, merge_sections = resolve_version_and_merge_sections(idx, [(m.release_branch, m.feature_branch is not None) for m in merges]) + if ticket_merge_info.update_changes and version_section: + script.append("python3 %s/update_changes.py '%s' '%s' '%s' %s" % (script_dir, + ticket_merge_info.ticket, + version_as_string(version_section.version), + ",".join([version_as_string(m.version) for m in merge_sections]), + '"$(head -n 1 %s)"' % ticket_merge_info.commit_msg_file)) + + script.append("git add CHANGES.txt") + closed = False + + if not closed: + script.append("git commit --amend --no-edit") + + if idx == 0: + script.append("git commit --allow-empty --amend --file %s" % ticket_merge_info.commit_msg_file) + + if not ticket_merge_info.keep_changes_in_circleci: + script.append("[[ -n \"$(git diff --name-only %s/%s..HEAD -- .circleci/)\" ]] && (git diff %s/%s..HEAD -- .circleci/ | git apply -R --index) && git commit -a --amend --no-edit # Remove all changes in .circleci directory if you need to" % (ticket_merge_info.upstream_repo, merge.release_branch.name, ticket_merge_info.upstream_repo, merge.release_branch.name)) + script.append("git diff --name-only %s/%s..HEAD # print a list of all changes files" % (ticket_merge_info.upstream_repo, merge.release_branch.name)) + + script.append("") + script.append("") + script.append("") + script.append("# After executing the above commands, please run the following verification, and manually inspect the results of the commands it generates") + script.append("python3 %s/verify_git_history.py '%s' '%s'" % (script_dir, ticket_merge_info.upstream_repo, ",".join([m.release_branch.name for m in 
merges]))) + + return script diff --git a/dev/scripts/prepare_merge_commands.py b/dev/scripts/prepare_merge_commands.py new file mode 100644 index 00000000..752301e0 --- /dev/null +++ b/dev/scripts/prepare_merge_commands.py @@ -0,0 +1,145 @@ +import os +import tempfile + +from lib.script_generator import generate_script +from lib.git_utils import * +from lib.jira_utils import * + +ensure_clean_git_tree() + +### Read feature repo, upstream repo and ticket +print("Remote repositories:") +print("") +subprocess.check_call(["git", "remote", "show"]) +print("") + +upstream_repo = read_remote_repository("Enter the name of the remote repository that points to the upstream Apache Cassandra", guess_upstream_repo()) + +feature_repo, ticket_number = guess_feature_repo_and_ticket() +feature_repo = read_remote_repository("Enter the name of the remote repository that points to the upstream feature branch", feature_repo) + +ticket_number = read_positive_int("Enter the ticket number (for example: '12345'): ", ticket_number) +ticket = "CASSANDRA-%s" % ticket_number + +print("") +print("Fetching from %s" % upstream_repo) +subprocess.check_output(["git", "fetch", upstream_repo]) +if feature_repo != upstream_repo: + print("Fetching from %s" % feature_repo) + subprocess.check_output(["git", "fetch", feature_repo]) + + +### Get the list of release branches and feature branches ### + +release_branches = get_release_branches(upstream_repo) +if len(release_branches) == 0: + print("No release branches found in %s" % upstream_repo) + sys.exit(1) +print("Found the following release branches:\n%s" % "\n".join(["%s: %s" % (version_as_string(b.version), b.name) for b in release_branches])) +print("") + +feature_branches = guess_feature_branches(feature_repo, upstream_repo, ticket) +print("Found the following feature branches:\n%s" % "\n".join(["%s: %s" % (version_as_string(b.version), b.name) for b in feature_branches])) +print("") + +### Read the oldest release version the feature applies to ### 
+ +guessed_oldest_feature_version = feature_branches[0].version if len(feature_branches) > 0 else None +oldest_release_version = None +while not oldest_release_version: + oldest_release_version_str = read_with_default("Enter the oldest release version to merge into", version_as_string(guessed_oldest_feature_version)) + if oldest_release_version_str: + oldest_release_version = version_from_string(oldest_release_version_str) + if oldest_release_version not in [b.version for b in release_branches]: + print("Invalid release version: %s" % str(oldest_release_version)) + oldest_release_version = None + +### Read the feature branches corresponding to each release branch ### + +target_release_branches = [b for b in release_branches if b.version >= oldest_release_version] +merges = [] +for release_branch in target_release_branches: + # find first feature branch whose version is the same as the version of the release branch + guessed_matching_feature_branch = next((b for b in feature_branches if b.version == release_branch.version), None) + guessed_matching_feature_branch_name = guessed_matching_feature_branch.name if guessed_matching_feature_branch else "none" + merge = None + while merge is None: + matching_feature_branch_name = read_with_default("Enter the name of the feature branch to merge into %s or type 'none' if there is no feature branch for this release" % release_branch.name, guessed_matching_feature_branch_name) + if matching_feature_branch_name == "none": + if len(merges) == 0: + print("Feature branch for the oldest release must be provided") + continue + merge = BranchMergeInfo(release_branch, None, []) + else: + if matching_feature_branch_name in [b.name for b in feature_branches] or check_remote_branch_exists(feature_repo, matching_feature_branch_name): + merge = BranchMergeInfo(release_branch, VersionedBranch(release_branch.version, NO_VERSION, matching_feature_branch_name), get_commits(upstream_repo, release_branch.name, feature_repo, 
matching_feature_branch_name)) + else: + print("Invalid feature branch name: %s" % matching_feature_branch_name) + merges.append(merge) + + +### Read the change title ### + +need_changes_txt_entry = False +response = None +while response not in ["yes", "no"]: + response = read_with_default("Do you want the script to add a line to CHANGES.txt? (yes/no)", "yes") +if response == "yes": + update_changes = True +else: + update_changes = False + +### Keep the circleci config changes? ### +keep_changes_in_circleci = False +response = None +while response not in ["yes", "no"]: + response = read_with_default("Do you want to keep changes in .circleci directory? (yes/no)", "no") +if response == "yes": + keep_changes_in_circleci = True + +### Generate commit message ### +commit_msg = merges[0].commits[0].title + "\n\n" +commit_msg = commit_msg + merges[0].commits[0].body + "\n\n" +for commit in merges[0].commits[1:]: + commit_msg = commit_msg + " - " + commit.title + "\n" + commit.body + "\n\n" + +authors = ["%s <%s>" % (c.author, c.email) for c in merges[0].commits] +authors = list(set(authors)) +authors.sort() +assignee = get_assignee_from_jira(ticket) +reviewers = get_reviewers_from_jira(ticket) +if assignee: + commit_msg = commit_msg + "Patch by %s" % assignee +if reviewers: + commit_msg = commit_msg + "; reviewed by %s" % ", ".join(reviewers) +commit_msg = commit_msg + " for %s" % ticket +commit_msg = commit_msg + "\n\n" +for author in authors: + commit_msg = commit_msg + "Co-authored-by: %s\n" % author + +temp_dir = tempfile.gettempdir() +commit_msg_file = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) +commit_msg_file.write(commit_msg.encode('utf-8')) + +print("") +print("Commit message saved to %s - you will be asked to edit" % commit_msg_file.name) + +### Generate the script ### + +ticket_merge_info = TicketMergeInfo(ticket, update_changes, upstream_repo, feature_repo, merges, keep_changes_in_circleci, commit_msg_file.name) + +script = 
generate_script(ticket_merge_info) + +# Read the filename to save the script to from either the command line or from the user +if len(sys.argv) > 1: + filename = sys.argv[1] +else: + filename = read_with_default("Enter the filename to save the script to", "../merge_%s.sh" % ticket) + +# Save the script to the file +with open(filename, "w") as f: + for s in script: + f.write(s + "\n") + +# make the script executable +os.chmod(filename, 0o755) diff --git a/dev/scripts/update_changes.py b/dev/scripts/update_changes.py new file mode 100644 index 00000000..2b011159 --- /dev/null +++ b/dev/scripts/update_changes.py @@ -0,0 +1,161 @@ +import re +import subprocess +import sys +from typing import NamedTuple, Tuple + +from lib.git_utils import * + + +class MergeSection(NamedTuple): + version: Tuple[int, int] + messages: list[str] + + +class ReleaseSection(NamedTuple): + version: Tuple[int, int] + version_string: str + messages: list[str] + merge_sections: list[MergeSection] + + +def read_changes_file(ticket: str) -> list[ReleaseSection]: + """ + Read the changes file and return a list of release sections. 
+ :return: a list of release sections + """ + merge_section_regex = re.compile(r"^Merged from (\d+)\.(\d+):") + release_sections = [] + with open("CHANGES.txt", "r") as f: + lines = f.readlines() + + messages = [] + merge_sections = [] + release_section = None + merge_section = None + + # go through each line and record its index if it matches the pattern \d+\.\d+.* + for i in range(len(lines)): + version = version_from_string(lines[i]) + merge_version = version_from_re(merge_section_regex, lines[i]) + + if version: + if merge_section: + merge_sections.append(merge_section) + + if release_section: + release_sections.append(release_section) + + messages = [] + merge_sections = [] + merge_section = None + release_section = ReleaseSection(version, lines[i], messages, merge_sections) + + elif merge_version: + if merge_section: + merge_sections.append(merge_section) + + messages = [] + merge_section = MergeSection(merge_version, messages) + + elif lines[i].strip(): + if (ticket in lines[i]): + print("Found duplicate message in line %d: %s" % (i + 1, lines[i])) + exit(1) + messages.append(lines[i]) + + if release_section: + release_sections.append(release_section) + + return release_sections + + +# write a text file with the changes +def write_changes_file(release_sections: list[ReleaseSection]): + """ + Write the changes file. 
+ :param release_sections: the release sections to write + """ + with open("CHANGES.txt", "w") as f: + for version_section in release_sections: + f.write(version_section.version_string) + for message in version_section.messages: + f.write(message) + + for merge_section in version_section.merge_sections: + f.write("Merged from %s:\n" % version_as_string(merge_section.version)) + for message in merge_section.messages: + f.write(message) + + f.write("\n\n") + + +def get_or_insert_merge_section(target_section: ReleaseSection, target_version: Tuple[int, int]) -> MergeSection: + """ + Get the merge section for the given version in the given release section. If the merge section does not exist, it is + created and inserted in the correct position. + :param target_section: the release section to search for the merge section + :param target_version: the version of the merge section to search for + :return: found or created merge section + """ + target_merge_section = None + insertion_index = -1 + for idx in range(len(target_section.merge_sections)): + insertion_index = idx + 1 + if target_section.merge_sections[idx].version == target_version: + # merge section already exists, return it + target_merge_section = target_section.merge_sections[idx] + break + elif target_section.merge_sections[idx].version < target_version: + # merge section does not exist because we just reached the first merge section with a lower version + insertion_index = idx - 1 + break + + if not target_merge_section: + # merge section does not exist, create it and insert in the correct position + target_merge_section = MergeSection(target_version, []) + target_section.merge_sections.insert(insertion_index, target_merge_section) + + return target_merge_section + + +# check if the commond line args contain the message and a list of branches +if len(sys.argv) < 5: + print("Adds a change info to the CHANGES.txt file.") + print("Usage: %s " % sys.argv[0]) + print("") + print("Example: %s CASSANDRA-12345 '4.1' 
'3.11,4.0' 'Some awesome change'" % sys.argv[0]) + print("It adds a change info to the top of 'Merged from 3.11' section for the latest '4.1' section, ensuring that 'Merged from 4.0' is there as well.") + exit(1) + +ticket = sys.argv[1] +target_version_section_str = sys.argv[2] +target_merge_sections_strs = [s.strip() for s in sys.argv[3].split(",") if s.strip()] +title = sys.argv[4] + +release_sections = read_changes_file(ticket) + +if target_version_section_str == version_as_string(TRUNK_VERSION): + # if the target version is trunk, we prepend the message to the first encountered version + target_section = release_sections[0] +else: + target_section = None + for section in release_sections: + if version_as_string(section.version) == target_version_section_str: + target_section = section + break + +assert target_section, "Could not find target version section %s" % target_version_section_str + +merge_section = None +for merge_section_str in target_merge_sections_strs: + print("Looking for merge section %d" % len(target_merge_sections_strs)) + merge_section = get_or_insert_merge_section(target_section, version_from_string(merge_section_str)) + +new_message = " * %s (%s)\n" % (title, ticket) + +if merge_section: + merge_section.messages.insert(0, new_message) +else: + target_section.messages.insert(0, new_message) + +write_changes_file(release_sections) diff --git a/dev/scripts/verify_git_history.py b/dev/scripts/verify_git_history.py new file mode 100644 index 00000000..586d697c --- /dev/null +++ b/dev/scripts/verify_git_history.py @@ -0,0 +1,116 @@ +from lib.git_utils import * + +# The script does two things: +# 1. Check that the history of the main branches (trunk, 4.0, 4.1, etc) is valid. +# The history of the oldest branch must contain only one commit, and that commit must not be a merge commit. +# The history of each newer branch must contain the history of the previous branch and a merge commit from that +# previous branch. +# 2. 
#    Execute dry run of the push command and parse the results. Then, generate diff and show
#    commands for the user to manually inspect the changes.

# Example usage:
# python3 dev/scripts/verify_git_history.py apache cassandra-4.0,cassandra-4.1,trunk
#
# The script will check the history of local cassandra-4.0, cassandra-4.1 and trunk branches
# against their remote counterparts in the apache repository.

# Read the command line arguments and validate them
if len(sys.argv) != 3:
    print("Usage: %s <upstream-repo-name> <comma-separated-branches-to-push>" % sys.argv[0])
    exit(1)

repo = sys.argv[1]
main_branches = [s.strip() for s in sys.argv[2].split(",") if s.strip()]

if len(main_branches) == 0:
    print("No branches specified")
    exit(1)

# Fetch the unpushed commits of the oldest branch (the patch itself).
history = get_commits(repo, main_branches[0], None, main_branches[0])

print("")
print("Checking branch %s" % main_branches[0])
print("Expected merges: []")
# Fix: the original format was "\n - -%s", which rendered a stray extra dash before
# the first commit (compare the analogous "History between ..." print below).
print("History: \n - %s" % "\n - ".join(str(x) for x in history))

# history for the first branch must contain only one commit
if len(history) != 1:
    print("%sInvalid history for branch %s, must contain only one commit, but found %d: \n\n%s%s\n" % (
        Colors.RED,
        main_branches[0], len(history), "\n".join(str(x) for x in history),
        Colors.END))
    exit(1)

# check if the commit message is valid, that is, it must not be a merge commit
if parse_merge_commit_msg(history[0].title):
    print("%sInvalid commit message for branch %s, must not be a merge commit, but found: \n\n%s%s\n" % (
        Colors.RED,
        main_branches[0], history[0].title,
        Colors.END))
    exit(1)

# Check the history of the branches to confirm that each branch contains exactly one main commit
# and the rest are the merge commits from the previous branch in order
expected_merges = []
prev_branch = main_branches[0]
prev_history = history
for branch in main_branches[1:]:
    print("-" * 80)

    expected_merges.append((prev_branch, branch))
    history = get_commits(repo, branch, None, branch)

    print("")
    print("Checking branch %s" % branch)
    print("Expected merges: %s" % str(expected_merges))
    print("History between %s/%s..local %s: \n - %s" % (repo, branch, branch, "\n - ".join(str(x) for x in history)))

    # NOTE(review): a mismatch here is reported but does not exit or skip the branch;
    # presumably intentional so all problems are reported in one run -- confirm.
    if history[:-1] != prev_history:
        print("%sInvalid history for branch %s, must include the history of branch %s:\n\n%s\n\n, but found: \n\n%s%s\n" % (
            Colors.RED,
            branch, prev_branch,
            "\n".join(str(x) for x in prev_history),
            "\n".join(str(x) for x in history),
            Colors.END))

    # expect that the rest of the commits are merge commits matching the expected merges in the same order
    for i in range(1, len(history)):
        merge = parse_merge_commit_msg(history[i].title)
        if not merge:
            print("%sInvalid commit message for branch %s, must be a merge commit, but found: \n%s%s\n" % (
                Colors.RED,
                branch, history[i],
                Colors.END))
            break

        if merge != expected_merges[i - 1]:
            print(
                "%sInvalid merge commit for branch %s, expected: %s, but found: %s%s\n" % (
                    Colors.RED,
                    branch, expected_merges[i - 1], merge,
                    Colors.END))
            break

    prev_branch = branch
    prev_history = history

# finally we print the commands to explore the changes in each push range
print("=" * 80)

push_ranges = get_push_ranges(repo, main_branches)
# number of push ranges must match the number of branches we want to merge
if len(push_ranges) != len(main_branches):
    print("%sInvalid number of push ranges, expected %d, but found %d:\n%s%s" % (
        Colors.RED,
        len(main_branches), len(push_ranges), "\n".join(str(x) for x in push_ranges),
        Colors.END))
    exit(1)

for push_range in push_ranges:
    print("Push range for branch %s: %s..%s" % (push_range[0], push_range[1], push_range[2]))
    print("%sgit diff --name-only %s..%s%s" % (Colors.LIGHT_BLUE, push_range[1], push_range[2], Colors.END))
    print("%sgit show %s..%s%s" % (Colors.LIGHT_BLUE, push_range[1], push_range[2], Colors.END))
    print("")
    print("-" * 80)

# ==== dev/test/__init__.py ====
import sys
import os

PROJECT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "scripts")
sys.path.append(PROJECT_PATH)

# ==== dev/test/test_resolving_version_and_merge_sections.py ====
import unittest

from lib.script_generator import *


class MyTestCase(unittest.TestCase):
    v_50 = VersionedBranch((5, 0), "5.0", "trunk")
    v_41 = VersionedBranch((4, 1), "4.1", "cassandra-4.1")
    v_40 = VersionedBranch((4, 0), "4.0", "cassandra-4.0")
    v_311 = VersionedBranch((3, 11), "3.11", "cassandra-3.11")
    v_30 = VersionedBranch((3, 0), "3.0", "cassandra-3.0")

    # If the change is only for trunk, then:
    # - we add the entry in the trunk section (top section).
    def test_trunk(self):
        merges = [(self.v_50, True)]
        version_section, merge_sections = resolve_version_and_merge_sections(0, merges)
        self.assertEqual(version_section, self.v_50)
        self.assertEqual(merge_sections, [])

    # If the change is for 4.1 and trunk, then:
    # - in 4.1, we add the entry in the 4.1 section (top section)
    # - in trunk, we add the entry in the 4.1 section (first encountered 4.1 section)
    def test_41_trunk(self):
        merges = [(self.v_41, True), (self.v_50, True)]
        version_section, merge_sections = resolve_version_and_merge_sections(0, merges)
        self.assertEqual(version_section, self.v_41)
        self.assertEqual(merge_sections, [])

        version_section, merge_sections = resolve_version_and_merge_sections(1, merges)
        self.assertEqual(version_section, self.v_41)
        self.assertEqual(merge_sections, [])

    # If the change is for 4.0, 4.1 and trunk, then:
    # - in 4.0, we add the entry in the 4.0 section (top section)
    # - in 4.1, we add the entry in the 4.1 section (top section), under "Merged from 4.0" subsection
    # - in trunk, we add the entry in the 4.1 section (first encountered 4.1 section), under "Merged from 4.0" subsection
    def test_40_41_trunk(self):
        merges = [(self.v_40, True), (self.v_41, True), (self.v_50, True)]
        version_section, merge_sections = resolve_version_and_merge_sections(0, merges)
        self.assertEqual(version_section, self.v_40)
        self.assertEqual(merge_sections, [])

        version_section, merge_sections = resolve_version_and_merge_sections(1, merges)
        self.assertEqual(version_section, self.v_41)
        self.assertEqual(merge_sections, [self.v_40])

        version_section, merge_sections = resolve_version_and_merge_sections(2, merges)
        self.assertEqual(version_section, self.v_41)
        self.assertEqual(merge_sections, [self.v_40])

    # If the change is for 4.0 and not for 4.1 or trunk, then:
    # - in 4.0, we add the entry in the 4.0 section (top section)
    # - in 4.1, no changes
    # - in trunk, no changes
    def test_40(self):
        merges = [(self.v_40, True), (self.v_41, False), (self.v_50, False)]
        version_section, merge_sections = resolve_version_and_merge_sections(0, merges)
        self.assertEqual(version_section, self.v_40)
        self.assertEqual(merge_sections, [])

        version_section, merge_sections = resolve_version_and_merge_sections(1, merges)
        self.assertEqual(version_section, None)
        self.assertEqual(merge_sections, [])

        version_section, merge_sections = resolve_version_and_merge_sections(2, merges)
        self.assertEqual(version_section, None)
        self.assertEqual(merge_sections, [])

    # If the change is for 3.11 and 4.1 and not for 4.0 or trunk, then:
    # - in 3.11, we add the entry in the 3.11 section (top section)
    # - in 4.0, no changes
    # - in 4.1, we add the entry in the 4.1 section (top section), under "Merged from 3.11" subsection
    #   (note: the expected list below also contains 4.0, i.e. intermediate merge sections are kept)
    # - in trunk, we add the entry in the 4.1 section (first encountered 4.1 section), under "Merged from 3.11" subsection
    def test_311_41(self):
        merges = [(self.v_311, True), (self.v_40, False), (self.v_41, True), (self.v_50, False)]
        version_section, merge_sections = resolve_version_and_merge_sections(0, merges)
        self.assertEqual(version_section, self.v_311)
        self.assertEqual(merge_sections, [])

        version_section, merge_sections = resolve_version_and_merge_sections(1, merges)
        self.assertEqual(version_section, None)
        self.assertEqual(merge_sections, [])

        version_section, merge_sections = resolve_version_and_merge_sections(2, merges)
        self.assertEqual(version_section, self.v_41)
        self.assertEqual(merge_sections, [self.v_40, self.v_311])

        version_section, merge_sections = resolve_version_and_merge_sections(3, merges)
        self.assertEqual(version_section, self.v_41)
        self.assertEqual(merge_sections, [self.v_40, self.v_311])

    # If the change is for 4.0 and trunk, and not for 4.1, then:
    # - in 4.0, we add the entry in the 4.0 section (top section)
    # - in 4.1, no changes
    # - in trunk, no changes
    def test_40_trunk(self):
        merges = [(self.v_40, True), (self.v_41, False), (self.v_50, True)]
        version_section, merge_sections = resolve_version_and_merge_sections(0, merges)
        self.assertEqual(version_section, self.v_40)
        self.assertEqual(merge_sections, [])

        version_section, merge_sections = resolve_version_and_merge_sections(1, merges)
        self.assertEqual(version_section, None)
        self.assertEqual(merge_sections, [])

        version_section, merge_sections = resolve_version_and_merge_sections(2, merges)
        self.assertEqual(version_section, None)
        self.assertEqual(merge_sections, [])

    # If the change is for 3.0, 3.11, 4.0, 4.1 and trunk, then:
    # - in 3.0, we add the entry in the 3.0 section (top section)
    # - in 3.11, we add the entry in the 3.11 section (top section), under "Merged from 3.0" subsection
    # - in 4.0, we add the entry in the 4.0 section (top section), under "Merged from 3.0" subsection
    # - in 4.1, we add the entry in the 4.1 section (top section), under "Merged from 3.0" subsection
    # - in trunk, we add the entry in the 4.1 section (first encountered 4.1 section), under "Merged from 3.0" subsection
    def test_30_311_40_41_trunk(self):
        merges = [(self.v_30, True), (self.v_311, True), (self.v_40, True), (self.v_41, True), (self.v_50, True)]
        version_section, merge_sections = resolve_version_and_merge_sections(0, merges)
        self.assertEqual(version_section, self.v_30)
        self.assertEqual(merge_sections, [])

        version_section, merge_sections = resolve_version_and_merge_sections(1, merges)
        self.assertEqual(version_section, self.v_311)
        self.assertEqual(merge_sections, [self.v_30])

        version_section, merge_sections = resolve_version_and_merge_sections(2, merges)
        self.assertEqual(version_section, self.v_40)
        self.assertEqual(merge_sections, [self.v_311, self.v_30])

        version_section, merge_sections = resolve_version_and_merge_sections(3, merges)
        self.assertEqual(version_section, self.v_41)
        self.assertEqual(merge_sections, [self.v_40, self.v_311, self.v_30])

        version_section, merge_sections = resolve_version_and_merge_sections(4, merges)
        self.assertEqual(version_section, self.v_41)
        self.assertEqual(merge_sections, [self.v_40, self.v_311, self.v_30])


if __name__ == '__main__':
    unittest.main()