-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdar_type.py
70 lines (40 loc) · 1.65 KB
/
dar_type.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import typing
import numpy as np
import transformers
from datasets.arrow_dataset import Dataset
# Metric score containers: a flat list of float scores, and a mapping from a
# string key (presumably the score/metric name — confirm with callers) to
# such a list.
MetricScoreList = typing.List[float]
MetricScoreDict = typing.Dict[str, MetricScoreList]

# Text containers.
TextList = typing.List[str]
TextSegments = typing.List[str]
# A list of segment lists (one inner list of strings per outer item).
TextListSegments = typing.List[TextSegments]
class Embedder(typing.Protocol):
    """Structural (static-typing) interface for a named sentence encoder.

    A conforming object exposes a ``__name__`` string and is callable with a
    list of sentences, returning a NumPy array (presumably one embedding row
    per sentence — TODO confirm).
    """

    # Originally noted as modelled on sentence_transformers.SentenceTransformer.
    # NOTE(review): SentenceTransformer's own __call__ is the torch forward
    # pass over tokenized features; raw strings are encoded via .encode().
    # Confirm callers pass a wrapper whose __call__ accepts strings.
    __name__: str

    def __call__(self, sentences: typing.List[str]) -> np.ndarray:
        ...
class SimilarityMatrixFunc(typing.Protocol):
    """Structural (static-typing) interface for a named segment-similarity function.

    Called with candidate segments and reference segments; returns a NumPy
    array (presumably a candidate-by-reference similarity matrix — TODO
    confirm), or ``None``.
    """

    __name__: str

    def __call__(
        self,
        cand_segments: TextSegments,
        ref_segments: TextSegments,
    ) -> typing.Optional[np.ndarray]:
        ...
# Plain-callable counterpart of SimilarityMatrixFunc: no ``__name__`` is
# declared, and the return type is a NumPy array (not Optional).
SimpleSimilarityMatrixFunc = typing.Callable[
    [TextSegments, TextSegments], np.ndarray
]

# Module-local alias for the transformers pipeline base class.
Pipeline = transformers.Pipeline
class PipelinesList:
    """A named group of transformers pipelines.

    Only attribute annotations are declared; the class defines no
    ``__init__`` and assigns no values. Presumably the code that builds
    instances sets ``__name__`` and ``pipelines`` itself — verify against
    callers.
    """

    __name__: str
    pipelines: typing.List[Pipeline]


# String-keyed mapping of PipelinesList objects (key meaning not shown here;
# presumably a name — confirm with callers).
PipelinesDict = typing.Dict[str, PipelinesList]
class MetricComputeFunc(typing.Protocol):
    """Structural (static-typing) interface for a metric computation.

    Called with a list of prediction texts and a list of reference texts
    (presumably aligned index-by-index — confirm), returning a mapping from
    string keys to lists of float scores.
    """

    def __call__(
        self, predictions: TextList, references: TextList
    ) -> MetricScoreDict:
        ...


# String-keyed registry of metric callables (key presumably the metric name).
MetricDict = typing.Dict[str, MetricComputeFunc]
class DocWarning(Warning):
    """Warning category for a malformed document.

    For example, an empty document, which is meaningless to a metric.
    """
class MNLICategory(typing.TypedDict):
    """One ``{"label": str, "score": float}`` entry.

    Presumably a single class result from an MNLI classifier (inferred from
    the name — confirm with the producer of these dicts).
    """

    label: str
    score: float
# A list of label/score entries.
MNLICategories = typing.List[MNLICategory]

# Pluggable callable signatures; the interpretations below come from the
# alias names and should be confirmed against their implementations.

# Reduces a list of MNLI label/score entries to one float.
MNLISimilarityExpression = typing.Callable[[MNLICategories], float]

# Given an optional Dataset, returns an int (presumably a summary length)
# or None.
SummaryLengthExpression = typing.Callable[
    [typing.Optional[Dataset]], typing.Optional[int]
]

# Maps a NumPy array to a single float (presumably a sentence weight).
SentenceWeightFunction = typing.Callable[[np.ndarray], float]

# Takes text segments plus a NumPy array and returns a NumPy array
# (presumably IDF-style scores).
IdfScoreFunction = typing.Callable[[TextSegments, np.ndarray], np.ndarray]