Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-3488: Support for writing a ColumnCorpus instance to files #3497

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
30 changes: 30 additions & 0 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,36 @@ def remove_labels(self, typename: str):
# delete labels at object itself
super().remove_labels(typename)

def _get_token_level_label_of_each_token(self, label_type: str) -> List[str]:
    """Generate one label per token of this sentence for a token-level label type.

    Tokens that carry no label of ``label_type`` are assigned ``"O"``. This method
    requires that the labels of ``label_type`` are attached to tokens, not to spans.

    Args:
        label_type: a string representing the type of the labels, e.g., "pos"

    Returns:
        A list with one label value per token, in token order.
    """
    list_of_labels = ["O"] * len(self.tokens)
    for label in self.get_labels(label_type):
        # the token's _internal_index is shifted by one to obtain its list position
        list_of_labels[label.data_point._internal_index - 1] = label.value
    return list_of_labels

def _get_span_level_label_of_each_token(self, label_type: str) -> List[str]:
    """Generate one BIO label per token of this sentence for a span-level label type.

    The first token of every labeled span is tagged ``B-<value>``, each following
    token of that span ``I-<value>``, and every token outside of any span ``"O"``.
    This method requires that the labels of ``label_type`` are attached to spans.

    Args:
        label_type: a string representing the type of the labels, e.g., "ner"

    Returns:
        A list with one BIO label per token, in token order.
    """
    list_of_labels = ["O"] * len(self.tokens)
    for label in self.get_labels(label_type):
        for position, token in enumerate(label.data_point.tokens):
            prefix = "B" if position == 0 else "I"
            # the token's _internal_index is shifted by one to obtain its list position
            list_of_labels[token._internal_index - 1] = f"{prefix}-{label.value}"
    return list_of_labels


class DataPair(DataPoint, typing.Generic[DT, DT2]):
def __init__(self, first: DT, second: DT2) -> None:
Expand Down
180 changes: 167 additions & 13 deletions flair/datasets/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,9 @@
import shutil
from collections import defaultdict
from pathlib import Path
from typing import (
Any,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
cast,
)
from typing import Any, DefaultDict, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union, cast

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data import ConcatDataset, Dataset, Subset

import flair
from flair.data import (
Expand All @@ -28,7 +17,9 @@
MultiCorpus,
Relation,
Sentence,
Span,
Token,
_iter_dataset,
get_spans_from_bio,
)
from flair.datasets.base import find_train_dev_test_files
Expand Down Expand Up @@ -443,6 +434,149 @@ def __init__(
**corpusargs,
)

@staticmethod
def _get_level_of_label(dataset: Optional[Dataset], label_type: str) -> Optional[Union[Type[Token], Type[Span]]]:
    """Determine whether ``label_type`` is a token-level or a span-level label type.

    The decision is based solely on the first label of ``label_type`` encountered
    while iterating over the dataset.

    Args:
        dataset: the dataset whose sentences are inspected
        label_type: the label type to look up, e.g., "ner"

    Returns:
        ``Token`` or ``Span`` depending on what the first label is attached to, or
        ``None`` (after logging a warning) if no label of this type is found.

    Raises:
        NotImplementedError: if level of label_type is neither Token nor Span
    """
    for sentence in _iter_dataset(dataset):
        for label in sentence.get_labels(label_type):
            annotated = label.data_point
            # Token is checked before Span, the first match decides
            for level in (Token, Span):
                if isinstance(annotated, level):
                    return level
            raise NotImplementedError(
                f"The level of {label_type} is neither token nor span. Only token level labels and span level labels can be handled now."
            )

    log.warning(f"There is no label of type {label_type} in this dataset.")
    return None

@staticmethod
def _write_dataset_to_file(
    dataset: Optional[Dataset],
    label_types: List[str],
    file_path: Path,
    column_delimiter: str = "\t",
    encoding: str = "utf-8",
) -> None:
    """Writes a dataset to a file.

    Following these two rules.
    (1) the text and the label(s) of every token is represented in one line separated by column_delimiter
    (2) every sentence is separated from the previous one by an empty line

    Note:
        Only labels corresponding to label_types will be written.
        Only token level or span level sequence tagging labels are supported.
        Currently, the whitespace_after attribute of each token will not be preserved in the written file.
        If dataset evaluates as falsy (e.g. it is None), no file is written and a warning is logged.

    Args:
        dataset: a dataset to write
        label_types: a list of label types to write e.g., ["ner", "pos"]
        file_path: a path to store the file
        column_delimiter: a string to separate token texts and labels in a line, the default value is a tab
        encoding: the text encoding of the written file, the default value is "utf-8"

    Raises:
        NotImplementedError: if the level of a label type is neither token nor span
    """
    if not dataset:
        log.warning("dataset is None, did not write any file.")
        return

    # determine once per label type whether it is token-level, span-level or absent
    label_type_tuples = [
        (label_type, ColumnCorpus._get_level_of_label(dataset, label_type)) for label_type in label_types
    ]

    # an explicit encoding keeps the output independent of the platform's locale default
    with open(file_path, mode="w", encoding=encoding) as output_file:
        for sentence in _iter_dataset(dataset):
            # first column holds the token texts, one further column per label type
            columns = [[token.text for token in sentence.tokens]]
            for label_type, level in label_type_tuples:
                if level is None:
                    columns.append(["O"] * len(sentence))
                elif level is Token:
                    columns.append(sentence._get_token_level_label_of_each_token(label_type))
                elif level is Span:
                    columns.append(sentence._get_span_level_label_of_each_token(label_type))
                else:
                    raise NotImplementedError(f"The level of {label_type} is neither token nor span.")

            # transpose the columns into one row per token
            output_file.writelines(column_delimiter.join(row) + "\n" for row in zip(*columns))
            output_file.write("\n")

@classmethod
def load_corpus_with_meta_data(cls, directory: Union[str, Path]) -> "ColumnCorpus":
    """Creates a ColumnCorpus instance from the directory generated by 'write_to_directory'.

    Args:
        directory: the directory that contains "meta_data.json" and the dataset files

    Returns:
        A corpus constructed from the directory, using the stored meta data as keyword arguments.
    """
    # accept plain strings as well, the "/" operator below needs a Path
    directory = Path(directory)

    with open(directory / "meta_data.json", encoding="utf-8") as file:
        meta_data = json.load(file)

    # JSON object keys are always strings, so the integer column indices have to be restored
    meta_data["column_format"] = {int(key): value for key, value in meta_data["column_format"].items()}

    return cls(
        data_folder=directory,
        autofind_splits=True,
        skip_first_line=False,
        **meta_data,
    )

def _write_corpus_meta_data(
    self, label_types: List[str], file_path: Path, column_delimiter: str, max_depth: int = 5
) -> None:
    """Writes meta data of this corpus to a json file.

    Note:
        Currently, the whitespace_after attribute of each token will not be preserved. Only default_whitespace_after attribute of each dataset will be written to the file.

    Args:
        label_types: the label types that occupy the columns after the text column, in column order
        file_path: a path to store the json file
        column_delimiter: the column delimiter to record in the meta data
        max_depth: the maximum number of ConcatDataset / Subset wrappers that are unwrapped
            while searching for the underlying ColumnDataset

    Raises:
        NotImplementedError: if no ColumnDataset is found within max_depth unwrapping steps,
            e.g. because all splits are empty or a wrapper of another type is used
    """
    # column 0 always holds the token text, the label types follow in the given order
    column_format = {0: "text"}
    column_format.update(enumerate(label_types, start=1))

    meta_data: Dict[str, Any] = {
        "name": self.name,
        "sample_missing_splits": False,
        "column_delimiter": column_delimiter,
        "column_format": column_format,
    }

    nonempty_dataset = self.train or self.dev or self.test
    # Sometimes, nonempty_dataset is a ConcatDataset or Subset, we need to get the original ColumnDataset
    # to access the encoding, in_memory, banned_sentences and default_whitespace_after attributes
    for _ in range(max_depth):
        if isinstance(nonempty_dataset, ColumnDataset):
            break
        if isinstance(nonempty_dataset, ConcatDataset):
            nonempty_dataset = nonempty_dataset.datasets[0]
        elif isinstance(nonempty_dataset, Subset):
            nonempty_dataset = nonempty_dataset.dataset
        else:
            raise NotImplementedError("Unsupported type")

    # the loop may also end because max_depth was exhausted
    if not isinstance(nonempty_dataset, ColumnDataset):
        raise NotImplementedError("Unsupported type")

    meta_data["encoding"] = nonempty_dataset.encoding
    meta_data["in_memory"] = nonempty_dataset.in_memory
    meta_data["banned_sentences"] = nonempty_dataset.banned_sentences
    meta_data["default_whitespace_after"] = nonempty_dataset.default_whitespace_after

    with open(file_path, mode="w", encoding="utf-8") as output_file:
        json.dump(meta_data, output_file)

def write_to_directory(self, label_types: List[str], output_directory: Path, column_delimiter: str = "\t") -> None:
    """Stores the train, dev and test split (where present) plus the corpus meta data in a directory.

    The splits are written to "train.conll", "dev.conll" and "test.conll", the meta data
    to "meta_data.json". The directory is created if it does not exist yet.

    Note:
        Only labels corresponding to label_types will be written.
        Only token level or span level sequence tagging labels are supported.
        Currently, the whitespace_after attribute of each token will not be preserved in the written file.

    Args:
        label_types: a list of label types to write e.g., ["ner", "pos"]
        output_directory: a directory to store the files
        column_delimiter: a string to separate token texts and labels in a line, the default value is a tab
    """
    os.makedirs(output_directory, exist_ok=True)

    splits = {"train.conll": self.train, "dev.conll": self.dev, "test.conll": self.test}
    for file_name, dataset in splits.items():
        target_file = output_directory / file_name
        ColumnCorpus._write_dataset_to_file(dataset, label_types, target_file, column_delimiter)

    meta_data_file = output_directory / "meta_data.json"
    self._write_corpus_meta_data(label_types, meta_data_file, column_delimiter)


class ColumnDataset(FlairDataset):
# special key for space after
Expand Down Expand Up @@ -797,6 +931,26 @@ def _remap_label(self, tag):
tag = self.label_name_map[tag] # for example, transforming 'PER' to 'person'
return tag

def write_dataset_to_file(self, label_types: List[str], file_path: Path, column_delimiter: str = "\t") -> None:
    """Stores this dataset in a column-formatted file.

    The parent directory of file_path is created if necessary. The file layout is:
    (1) one line per token holding its text and label(s), separated by column_delimiter
    (2) an empty line after every sentence

    Note:
        Only labels corresponding to label_types will be written.
        Only token level or span level sequence tagging labels are supported.
        Currently, the whitespace_after attribute of each token will not be preserved in the written file.

    Args:
        label_types: a list of label types to write e.g., ["ner", "pos"]
        file_path: a path to store the file
        column_delimiter: a string to separate token texts and labels in a line, the default value is a tab
    """
    target_directory = file_path.parent
    target_directory.mkdir(parents=True, exist_ok=True)
    ColumnCorpus._write_dataset_to_file(
        dataset=self,
        label_types=label_types,
        file_path=file_path,
        column_delimiter=column_delimiter,
    )

def __line_completes_sentence(self, line: str) -> bool:
    """Return True if ``line`` is empty or consists only of whitespace."""
    return line == "" or line.isspace()
Expand Down
18 changes: 18 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,24 @@ def test_load_universal_dependencies_conllu_corpus(tasks_base_path):
_assert_universal_dependencies_conllu_dataset(corpus.train)


def test_write_to_and_load_from_directory(tasks_base_path, tmp_path):
    """A corpus written with write_to_directory can be restored with load_corpus_with_meta_data."""
    corpus = ColumnCorpus(
        tasks_base_path / "column_with_whitespaces",
        train_file="eng.train",
        column_format={0: "text", 1: "ner"},
        column_delimiter=" ",
        skip_first_line=False,
        sample_missing_splits=False,
    )

    # pytest's tmp_path keeps the written files out of the working tree and isolated per test run
    corpus.write_to_directory(["ner"], tmp_path, column_delimiter="\t")
    loaded_corpus = ColumnCorpus.load_corpus_with_meta_data(tmp_path)

    assert len(loaded_corpus.train) == len(corpus.train)
    assert loaded_corpus.train[0].to_tagged_string() == corpus.train[0].to_tagged_string()


@pytest.mark.integration()
def test_hipe_2022_corpus(tasks_base_path):
# This test covers the complete HIPE 2022 dataset.
Expand Down