forked from osalpekar/llm-target-determinator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen_pr_items.py
77 lines (59 loc) · 1.91 KB
/
gen_pr_items.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Various ways to generate strings from a PR that are embedded and compared
against test embeddings."""
import os
import subprocess
from pathlib import Path
from typing import List
# Absolute directory containing this module, with symlinks resolved.
REPO_ROOT = Path(__file__).resolve().parents[0]
def get_merge_base() -> str:
    """Return the commit that HEAD should be diffed against.

    Both git commands run in the ``pytorch`` directory that sits alongside
    the directory containing this module (``REPO_ROOT.parent / "pytorch"``).
    The base is normally ``git merge-base origin/<branch> HEAD``. ``<branch>``
    is taken from the ``GIT_DEFAULT_BRANCH`` environment variable and falls
    back to ``main``. When that merge base equals HEAD itself, HEAD is
    already contained in the default branch. In that case ``"HEAD^"`` is
    returned so that only the most recent commit is compared.

    Raises:
        subprocess.CalledProcessError: if either git command exits non-zero.
    """
    checkout = REPO_ROOT.parent / "pytorch"

    def _git(*args: str) -> str:
        # Run one git command inside the checkout; return its trimmed stdout.
        raw = subprocess.check_output(["git", *args], cwd=checkout)
        return raw.decode().strip()

    branch_name = os.environ.get("GIT_DEFAULT_BRANCH", "main")
    fork_point = _git("merge-base", f"origin/{branch_name}", "HEAD")
    current_commit = _git("rev-parse", "HEAD")
    return "HEAD^" if fork_point == current_commit else fork_point
def query_changed_files() -> List[str]:
    """Return the contents of every file changed between the merge base and HEAD.

    Changed paths come from ``git diff --name-only`` run in the ``pytorch``
    directory alongside this module's directory, with :func:`get_merge_base`
    as the base commit. Each path is then read from the working tree. Paths
    that are not regular files there are skipped because they have no text
    to read. This covers deleted files and submodule (gitlink) entries.

    Returns:
        One string per readable changed file, in the order git lists them.
        The list is empty when nothing changed.

    Raises:
        RuntimeError: if ``git diff`` exits with a non-zero status.
    """
    pytorch_dir = REPO_ROOT.parent / "pytorch"
    base_commit = get_merge_base()
    proc = subprocess.run(
        # -z makes git emit NUL-terminated, *unquoted* paths, so names with
        # spaces, newlines or non-ASCII characters come through intact.
        ["git", "diff", "--name-only", "-z", base_commit, "HEAD"],
        cwd=pytorch_dir,
        capture_output=True,
        check=False,
    )
    if proc.returncode != 0:
        detail = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(
            f"Unable to get changed files (exit code {proc.returncode}): {detail}"
        )

    items = []
    for raw_name in proc.stdout.split(b"\0"):
        if not raw_name:
            # Trailing terminator, or no output at all when nothing changed.
            continue
        path = pytorch_dir / os.fsdecode(raw_name)
        if not path.is_file():
            # Deleted in HEAD, or a submodule directory: nothing to read.
            continue
        # Changed files are not guaranteed to be UTF-8 text (binary assets,
        # legacy encodings). Substitute undecodable bytes instead of crashing.
        with open(path, encoding="utf-8", errors="replace") as f:
            items.append(f.read())
    return items
def get_git_diff() -> List[str]:
    """Return the full ``git diff`` between the merge base and HEAD.

    The diff is computed in the ``pytorch`` directory alongside this module's
    directory, with :func:`get_merge_base` as the base commit.

    Returns:
        A single-element list holding the whole diff text, with surrounding
        whitespace stripped. The element is ``""`` when there is no
        difference. The list form matches ``query_changed_files``, which
        also returns ``List[str]``.

    Raises:
        RuntimeError: if ``git diff`` exits with a non-zero status.
    """
    pytorch_dir = REPO_ROOT.parent / "pytorch"
    base_commit = get_merge_base()
    proc = subprocess.run(
        ["git", "diff", base_commit, "HEAD"],
        cwd=pytorch_dir,
        capture_output=True,
        check=False,
    )
    if proc.returncode != 0:
        detail = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(
            f"Unable to get git diff (exit code {proc.returncode}): {detail}"
        )
    # Diffed files may contain bytes that are not valid UTF-8. A strict
    # decode would abort the whole run, so substitute U+FFFD instead.
    diff_text = proc.stdout.decode("utf-8", errors="replace").strip()
    return [diff_text]
# Maps an item-kind name to the zero-argument function that produces that
# kind of PR item (a list of strings).
PR_ITEMS = dict(
    GITDIFF=get_git_diff,
    CHANGEDFILES=query_changed_files,
)