-
-
Notifications
You must be signed in to change notification settings - Fork 528
/
Copy pathrecommender.py
19 lines (18 loc) · 1.07 KB
/
recommender.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import numpy as np
from sentence_transformers import SentenceTransformer
from paper import ArxivPaper
from datetime import datetime
def rerank_paper(candidate:list[ArxivPaper],corpus:list[dict],model:str='avsolatorio/GIST-small-Embedding-v0') -> list[ArxivPaper]:
    """Score candidate papers against a reference corpus and sort them by score.

    Each candidate's ``summary`` and each corpus entry's
    ``entry['data']['abstractNote']`` are embedded with a SentenceTransformer
    model. A candidate's score is the weighted sum of its similarities to all
    corpus entries, multiplied by 10. The weights favour recently added corpus
    entries: the entry at recency rank ``i`` (0 = newest by
    ``entry['data']['dateAdded']``) gets weight ``1 / (1 + log10(i + 1))``,
    and the weights are normalised to sum to 1.

    Args:
        candidate: Papers to rank. Each must expose a ``summary`` attribute;
            a ``score`` attribute is written on every paper as a side effect.
        corpus: Reference entries. Each must provide ``['data']['dateAdded']``
            in ``'%Y-%m-%dT%H:%M:%SZ'`` format and ``['data']['abstractNote']``.
        model: Name or path handed to ``SentenceTransformer``.

    Returns:
        A new list holding the candidates ordered from highest to lowest
        score. The input list itself is not reordered.

    Raises:
        KeyError: If a corpus entry lacks one of the required keys.
        ValueError: If a ``dateAdded`` value does not match the expected format.
    """
    # Nothing to rank: skip the (expensive) model load and the encoder calls,
    # which are not meaningful for an empty batch.
    if not candidate:
        return []

    # No reference corpus means no preference signal. Give every candidate a
    # neutral score and keep the incoming order instead of failing inside the
    # similarity computation on an empty feature matrix.
    if not corpus:
        for paper in candidate:
            paper.score = 0.0
        return list(candidate)

    encoder = SentenceTransformer(model)

    #sort corpus by date, from newest to oldest
    corpus = sorted(
        corpus,
        key=lambda x: datetime.strptime(x['data']['dateAdded'], '%Y-%m-%dT%H:%M:%SZ'),
        reverse=True,
    )

    # Logarithmic decay over recency rank, normalised so the weights sum to 1.
    time_decay_weight = 1 / (1 + np.log10(np.arange(len(corpus)) + 1))
    time_decay_weight = time_decay_weight / time_decay_weight.sum()

    corpus_feature = encoder.encode([paper['data']['abstractNote'] for paper in corpus])
    candidate_feature = encoder.encode([paper.summary for paper in candidate])

    sim = encoder.similarity(candidate_feature, corpus_feature)  # [n_candidate, n_corpus]

    # Weighted average similarity per candidate, scaled by a constant factor.
    score_scale = 10
    scores = (sim * time_decay_weight).sum(axis=1) * score_scale  # [n_candidate]

    for s, c in zip(scores, candidate):
        c.score = s.item()

    return sorted(candidate, key=lambda x: x.score, reverse=True)