Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstream/2 write a predictor class to mimic the gec model #172

Open
wants to merge 77 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
aa05578
Create Python package & bump to Python 3.8
Sep 14, 2022
03ab20e
Address pr comments
Sep 14, 2022
ad1d4c0
Address pr comments
Sep 14, 2022
47b9779
Merge pull request #2 from EducationalTestingService/create_python_pa…
damien2012eng Sep 14, 2022
4a2ec3d
add unit tests for tokenization file
Sep 7, 2022
87a4997
Address pr comments
Sep 14, 2022
67d2256
add unit tests for token_indexer
Sep 14, 2022
6e57109
Address pr comments
Sep 15, 2022
f311813
Address pr comments
Sep 15, 2022
5c0cf10
Merge pull request #1 from EducationalTestingService/features/add_uni…
damien2012eng Sep 16, 2022
b00e946
Add unit tests for pretrained BERT embedder
Frost45 Sep 9, 2022
05407f3
Add unit tests for pretrained RoBERTa embedder
Frost45 Sep 14, 2022
b649aeb
Add unit tests for seq2labels model
Frost45 Sep 20, 2022
6801c42
Update tokenization tests to use AllenNLP test modules
Frost45 Sep 23, 2022
2964b05
Address PR comments
Frost45 Sep 26, 2022
a3c74c2
Add CI plan
Frost45 Sep 26, 2022
dbc39e2
Addressed PR comments
Frost45 Sep 27, 2022
ca950c3
Unit test for GecModel Prediction.
Sep 27, 2022
e3ea139
Removing duplicate import.
Sep 27, 2022
8291809
Adding readme file so that fixtures dir exists for downloading gector
ksteimel Sep 27, 2022
9ba0aeb
Minor changes to make these tests pass if a cuda device is available.
ksteimel Sep 27, 2022
34386bd
Adding registered names for use by predictor
ksteimel Sep 27, 2022
36ab359
Adding expected test output.
ksteimel Sep 27, 2022
9bd6a7b
Added WIP docstring to GecBERTModel
ksteimel Sep 27, 2022
6f3f7ae
WIP Gec Predictor.
ksteimel Sep 27, 2022
9dd6bf1
words metadata is getting filled if unspecified when text_to_instance…
ksteimel Sep 28, 2022
64cbc8d
Using JustSpacesWordSplitter so that tokenization matches that used b…
ksteimel Sep 28, 2022
10e8069
Decode now adds the corrected sentence to the output dict.
ksteimel Sep 28, 2022
f6a9185
Updating gitignore to prevent adding .th files
ksteimel Sep 28, 2022
657f72b
Adding directory fixture as analogue to model archive
ksteimel Sep 28, 2022
e3cca59
Fixing errors in modeling code now that model.decode adds the origina…
ksteimel Sep 28, 2022
1c0025c
Adding conditional so that no correction is performed in decode if no…
ksteimel Sep 28, 2022
0f204ea
Appending start token when creating instances from json or string.
ksteimel Sep 28, 2022
67ef511
Start token is expected in ouptut.
ksteimel Sep 28, 2022
53a94dd
Drop START_TOKEN from output_dict["words"]. This interferes with the …
ksteimel Sep 28, 2022
86212b7
The outputs now no longer have $START_TOKEN in the corrected sentence…
ksteimel Sep 28, 2022
40619fe
Merge pull request #5 from EducationalTestingService/feature/add_inte…
ksteimel Sep 28, 2022
20a0692
Handling multiple iterations of correction in predictor now.
ksteimel Sep 28, 2022
0b88028
Changed location of weights file so it can be used by gec_predictor a…
ksteimel Sep 28, 2022
0b6b508
setup is now downloading weights file if it does not already exist.
ksteimel Sep 28, 2022
c535c53
Add how to run unit tests on README
Sep 29, 2022
86a7aff
Add regression data for raw and predictions
Sep 29, 2022
469d162
Merge pull request #7 from EducationalTestingService/features/add_reg…
damien2012eng Oct 3, 2022
bf4e209
Add regression test file
Frost45 Oct 6, 2022
2b4513e
Update CI plan to run regression tests
Frost45 Oct 6, 2022
68d900e
Addressed PR comments
Frost45 Oct 10, 2022
1b60b1b
Addressed PR comments
Frost45 Oct 12, 2022
f1db9a9
Addressed PR comments
Frost45 Oct 12, 2022
6f8f23c
Apply suggestions from code review
ksteimel Oct 12, 2022
7427f93
Removing unused imports, adding docstrings.
Oct 12, 2022
73aef1b
Removing unused predictions to labeled_instances method.
Oct 12, 2022
944993c
Updated docstring for decode()
Oct 12, 2022
709ba28
Removed unused imports.
Oct 12, 2022
323d61a
Add environment.yml
damien2012eng Oct 17, 2022
21d496c
versioning starting with 1.0.0
damien2012eng Oct 17, 2022
5da5955
Address PR comments
damien2012eng Oct 17, 2022
b04376c
Adding back import of gec_predictor that shouldn't have been removed
ksteimel Oct 22, 2022
ba588e2
Add how to run unit tests on README
Sep 29, 2022
f9fb4d7
Add regression data for raw and predictions
Sep 29, 2022
fbbcf10
Add regression test file
Frost45 Oct 6, 2022
d523ed9
Update CI plan to run regression tests
Frost45 Oct 6, 2022
e2014bc
Addressed PR comments
Frost45 Oct 10, 2022
9e84cb8
Addressed PR comments
Frost45 Oct 12, 2022
93f2722
Addressed PR comments
Frost45 Oct 12, 2022
bc89564
Add environment.yml
damien2012eng Oct 17, 2022
e96c38d
versioning starting with 1.0.0
damien2012eng Oct 17, 2022
618d686
Address PR comments
damien2012eng Oct 17, 2022
41a1cf0
Modify gec_predictor and seq2labels to work as gec_model does
Frost45 Oct 20, 2022
1bfc956
Add regression test script for predictor
Frost45 Oct 20, 2022
d952b18
Modify gec_predictor and seq2labels to work as gec_model does
Frost45 Oct 24, 2022
5d7ef1d
Update CI plan to run on all PRs and add regression tests
Frost45 Oct 24, 2022
d96fb80
Addressed PR comments
Frost45 Oct 24, 2022
c937981
Merge pull request #6 from EducationalTestingService/feature/predicto…
ksteimel Oct 24, 2022
3dabb0f
Merge branch 'master' into feature/fix-gec-predictor
Frost45 Oct 24, 2022
7f46ba1
Addressed PR comments
Frost45 Oct 24, 2022
7c6882b
Merge branch 'feature/fix-gec-predictor' of github.com:EducationalTes…
Frost45 Oct 24, 2022
7c5610d
Merge pull request #13 from EducationalTestingService/feature/fix-gec…
Frost45 Oct 24, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python test

on:
push:
branches: [ "master" ]
pull_request:

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
- name: Unit Testing
run: |
pytest -v tests
- name: Regression Testing
run: |
python regression_tests/test_gector_roberta.py
python regression_tests/test_regression_data_predictor.py
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,7 @@ dmypy.json
# PyCharm
.idea

*.sh
*.sh

# pytorch weights files
*.th
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ This repository provides code for training and testing state-of-the-art models f
It is mainly based on `AllenNLP` and `transformers`.
## Installation
The following command installs all necessary packages:
```.bash
pip install -r requirements.txt
```bash
conda create --name <Environment_name> python=3.8
conda activate <Environment_name>
pip install -e .
```
The project was tested using Python 3.7.
The project was tested using Python 3.8.

## Unit tests
After activating the conda environment, simply run the code below:
`pytest -v tests`

## Datasets
All the public GEC datasets used in the paper can be downloaded from [here](https://www.cl.cam.ac.uk/research/nl/bea2019st/#data).<br>
Expand Down
Empty file added data/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: gector
dependencies:
- python=3.8
- pytorch=1.10.0
- python-Levenshtein
- transformers
- scikit-learn
- sentencepiece
- overrides=4.1.2
- numpy
- pip:
- allennlp==0.9.0

Empty file added gector/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion gector/bert_token_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def forward(
return util.uncombine_initial_dims(selected_embeddings, offsets.size())


# @TokenEmbedder.register("bert-pretrained")
@TokenEmbedder.register("gec-bert-pretrained")
class PretrainedBertEmbedder(BertEmbedder):

"""
Expand Down
93 changes: 61 additions & 32 deletions gector/datareader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import TextField, SequenceLabelField, MetadataField, Field
from allennlp.data.fields import (
TextField,
SequenceLabelField,
MetadataField,
Field,
)
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
Expand Down Expand Up @@ -37,23 +42,28 @@ class Seq2LabelsDatasetReader(DatasetReader):
are pre-tokenised in the data file.
max_len: if set than will truncate long sentences
"""

# fix broken sentences mostly in Lang8
BROKEN_SENTENCES_REGEXP = re.compile(r'\.[a-zA-RT-Z]')

def __init__(self,
token_indexers: Dict[str, TokenIndexer] = None,
delimeters: dict = SEQ_DELIMETERS,
skip_correct: bool = False,
skip_complex: int = 0,
lazy: bool = False,
max_len: int = None,
test_mode: bool = False,
tag_strategy: str = "keep_one",
tn_prob: float = 0,
tp_prob: float = 0,
broken_dot_strategy: str = "keep") -> None:
BROKEN_SENTENCES_REGEXP = re.compile(r"\.[a-zA-RT-Z]")

def __init__(
self,
token_indexers: Dict[str, TokenIndexer] = None,
delimeters: dict = SEQ_DELIMETERS,
skip_correct: bool = False,
skip_complex: int = 0,
lazy: bool = False,
max_len: int = None,
test_mode: bool = False,
tag_strategy: str = "keep_one",
tn_prob: float = 0,
tp_prob: float = 0,
broken_dot_strategy: str = "keep",
) -> None:
super().__init__(lazy)
self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
self._token_indexers = token_indexers or {
"tokens": SingleIdTokenIndexer()
}
self._delimeters = delimeters
self._max_len = max_len
self._skip_correct = skip_correct
Expand All @@ -69,16 +79,23 @@ def _read(self, file_path):
# if `file_path` is a URL, redirect to the cache
file_path = cached_path(file_path)
with open(file_path, "r") as data_file:
logger.info("Reading instances from lines in file at: %s", file_path)
logger.info(
"Reading instances from lines in file at: %s", file_path
)
for line in data_file:
line = line.strip("\n")
# skip blank and broken lines
if not line or (not self._test_mode and self._broken_dot_strategy == 'skip'
and self.BROKEN_SENTENCES_REGEXP.search(line) is not None):
if not line or (
not self._test_mode
and self._broken_dot_strategy == "skip"
and self.BROKEN_SENTENCES_REGEXP.search(line) is not None
):
continue

tokens_and_tags = [pair.rsplit(self._delimeters['labels'], 1)
for pair in line.split(self._delimeters['tokens'])]
tokens_and_tags = [
pair.rsplit(self._delimeters["labels"], 1)
for pair in line.split(self._delimeters["tokens"])
]
try:
tokens = [Token(token) for token, tag in tokens_and_tags]
tags = [tag for token, tag in tokens_and_tags]
Expand All @@ -91,14 +108,14 @@ def _read(self, file_path):

words = [x.text for x in tokens]
if self._max_len is not None:
tokens = tokens[:self._max_len]
tags = None if tags is None else tags[:self._max_len]
tokens = tokens[: self._max_len]
tags = None if tags is None else tags[: self._max_len]
instance = self.text_to_instance(tokens, tags, words)
if instance:
yield instance

def extract_tags(self, tags: List[str]):
op_del = self._delimeters['operations']
op_del = self._delimeters["operations"]

labels = [x.split(op_del) for x in tags]

Expand All @@ -117,18 +134,28 @@ def extract_tags(self, tags: List[str]):
else:
raise Exception("Incorrect tag strategy")

detect_tags = ["CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels]
detect_tags = [
"CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels
]
return labels, detect_tags, comlex_flag_dict

def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
words: List[str] = None) -> Instance: # type: ignore
def text_to_instance(
self,
tokens: List[Token],
tags: List[str] = None,
words: List[str] = None,
) -> Instance: # type: ignore
"""
We take `pre-tokenized` input here, because we don't have a tokenizer in this class.
"""
# pylint: disable=arguments-differ
fields: Dict[str, Field] = {}
sequence = TextField(tokens, self._token_indexers)
# Set size of tokens to _max_len + 1 since $START token is being added
sequence = TextField(tokens[: self._max_len + 1], self._token_indexers)
fields["tokens"] = sequence
# If words has not been explicitly passed in, create it from tokens.
if words is None:
words = [token.text for token in tokens]
fields["metadata"] = MetadataField({"words": words})
if tags is not None:
labels, detect_tags, complex_flag_dict = self.extract_tags(tags)
Expand All @@ -144,8 +171,10 @@ def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
if rnd > self._tp_prob:
return None

fields["labels"] = SequenceLabelField(labels, sequence,
label_namespace="labels")
fields["d_tags"] = SequenceLabelField(detect_tags, sequence,
label_namespace="d_tags")
fields["labels"] = SequenceLabelField(
labels, sequence, label_namespace="labels"
)
fields["d_tags"] = SequenceLabelField(
detect_tags, sequence, label_namespace="d_tags"
)
return Instance(fields)
Loading