From 3462732a09feec9ac56705794f10d8a168bb48ba Mon Sep 17 00:00:00 2001 From: Desmond Cheong Date: Thu, 30 Jan 2025 13:39:05 -0800 Subject: [PATCH] ci: Allow upstream git refs to be used for benchmarking (#3730) Had trouble running the benchmarking tool when local branch names didn't match remote branch names. Fix it so that we check for upstream branch names and use them. E.g. if I run the tool on `local-branch-name`, it finds `origin/user/remote-branch-name` then runs the action on `user/remote-branch-name`. --- tools/git_utils.py | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/tools/git_utils.py b/tools/git_utils.py index 0e75301646..1d577d164f 100644 --- a/tools/git_utils.py +++ b/tools/git_utils.py @@ -62,17 +62,36 @@ def get_latest_run(workflow: Workflow) -> WorkflowRun: raise RuntimeError("Unable to list all workflow invocations") -def get_name_and_commit_hash(branch_name: Optional[str]) -> tuple[str, str]: - branch_name = branch_name or "HEAD" - name = ( - subprocess.check_output(["git", "rev-parse", "--abbrev-ref", branch_name], stderr=subprocess.STDOUT) +def get_name_and_commit_hash(local_branch_name: Optional[str]) -> tuple[str, str]: + local_branch_name = local_branch_name or "HEAD" + remote_branch_name = local_branch_name + + try: + # Check if the branch has a remote tracking branch. + local_branch_name = ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", f"{local_branch_name}@{{upstream}}"], stderr=subprocess.STDOUT + ) + .strip() + .decode("utf-8") + ) + # Strip the upstream name from the branch to get the branch name on the remote repo. + remote_branch_name = local_branch_name.split("/", 1)[1] + except subprocess.CalledProcessError: + local_branch_name = ( + subprocess.check_output(["git", "rev-parse", "--abbrev-ref", local_branch_name], stderr=subprocess.STDOUT) + .strip() + .decode("utf-8") + ) + remote_branch_name = local_branch_name + + commit_hash = ( + subprocess.check_output(["git", "rev-parse", local_branch_name], stderr=subprocess.STDOUT) .strip() .decode("utf-8") ) - commit_hash = ( - subprocess.check_output(["git", "rev-parse", branch_name], stderr=subprocess.STDOUT).strip().decode("utf-8") - ) - return name, commit_hash + # Return the remote branch name for the github action. + return remote_branch_name, commit_hash def parse_questions(questions: Optional[str], total_number_of_questions: int) -> list[int]: