diff --git a/py_css/models/base.py b/py_css/models/base.py index fdc2273..06ff8a1 100644 --- a/py_css/models/base.py +++ b/py_css/models/base.py @@ -211,12 +211,19 @@ def gen_context_docs(context: Context) -> Generator[Document, None, None]: # if in df there are multiple rows that have the same qid and docno, keep the one with the highest score. For the ones removed, add a row each with the EMPTY_PLACEHOLDER_DOC rank_size_per_qid: int = df.groupby("qid").size().max() + print(f"Rank size per qid: {rank_size_per_qid}") df = df.sort_values(["qid", "docno", "score"], ascending=[True, True, False]) + total_size = df.shape[0] df = df.drop_duplicates(subset=["qid", "docno"], keep="first") + dropped_any: bool = total_size != df.shape[0] + print(f"Dropped any: {dropped_any}") df = df.reset_index(drop=True) df = self.pad_empty_documents( df, df["qid"].unique(), rank_size_per_qid, df[["qid", "query"]] ) + print(f"Number of max rank size per qid: {df.groupby('qid').size().max()}") + df = df.reset_index(drop=True) + df = df.sort_values(["qid", "rank"], ascending=[True, True]) for query, context in context_list: # check if there is a row in the df with "qid" == query.query_id, where "docno" == EMPTY_PLACEHOLDER_DOC.docno diff --git a/py_css/models/baseline.py b/py_css/models/baseline.py index a536fd6..1125a6e 100644 --- a/py_css/models/baseline.py +++ b/py_css/models/baseline.py @@ -48,6 +48,7 @@ def __init__( bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"]) self.top_docs = ((bm25 % bm25_docs).compile(), bm25_docs) self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE), mono_t5_docs) + # self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE, model="castorini/monot5-large-msmarco"), mono_t5_docs) self.duo_t5 = (DuoT5ReRanker(batch_size=BATCH_SIZE), duo_t5_docs) def transform_input( diff --git a/py_css/models/baseline_prf.py b/py_css/models/baseline_prf.py index f8876ff..5a9990b 100644 --- a/py_css/models/baseline_prf.py +++ b/py_css/models/baseline_prf.py @@ -55,6 +55,7 @@ def __init__( rm3 = pt.rewrite.RM3(index, fb_docs=rm3_fb_docs, fb_terms=rm3_fb_terms) self.top_docs = ((bm25 >> rm3 >> bm25) % bm25_docs, bm25_docs) self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE), mono_t5_docs) + # self.mono_t5 = (MonoT5ReRanker(batch_size=BATCH_SIZE, model="castorini/monot5-large-msmarco"), mono_t5_docs) self.duo_t5 = (DuoT5ReRanker(batch_size=BATCH_SIZE), duo_t5_docs) def transform_input(