Skip to content

Commit

Permalink
Fix initialize and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ines committed Feb 7, 2021
1 parent 7586325 commit b80070c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 23 deletions.
21 changes: 20 additions & 1 deletion sense2vec/component.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Tuple, Union, List, Dict
from typing import Tuple, Union, List, Dict, Callable, Iterable, Optional
from spacy.language import Language
from spacy.tokens import Doc, Token, Span
from spacy.training import Example
from spacy.vocab import Vocab
from spacy.util import SimpleFrozenDict
from pathlib import Path
Expand Down Expand Up @@ -215,6 +216,24 @@ def s2v_other_senses(self, obj: Union[Token, Span]) -> List[str]:
key = self.s2v_key(obj)
return obj.doc._._s2v.get_other_senses(key)

def initialize(
    self,
    get_examples: Callable[[], Iterable[Example]],
    *,
    nlp: Optional[Language] = None,
    data_path: Optional[str] = None
):
    """Set up the component, optionally loading sense2vec data from disk.
    Makes it possible to add the component, together with its vectors, to a
    pipeline before training.

    get_examples (Callable[[], Iterable[Example]]): Function that returns a
        representative sample of gold-standard Example objects. Accepted
        but not used by this method.
    nlp (Language): The current nlp object the component is part of.
        Accepted but not used by this method.
    data_path (Optional[str]): Optional path to a sense2vec model. When it
        is None, nothing is loaded.
    """
    # Nothing to load unless a model location was supplied.
    if data_path is None:
        return
    self.from_disk(data_path)

def to_bytes(self) -> bytes:
"""Serialize the component to a bytestring.
Expand Down
22 changes: 1 addition & 21 deletions sense2vec/sense2vec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Tuple, List, Union, Sequence, Dict, Callable, Any, Iterable
from typing import Optional
from typing import Tuple, List, Union, Sequence, Dict, Callable, Any
from pathlib import Path
from spacy.language import Language
from spacy.vectors import Vectors
from spacy.strings import StringStore
from spacy.util import SimpleFrozenDict
Expand Down Expand Up @@ -297,24 +295,6 @@ def to_bytes(self, exclude: Sequence[str] = tuple()) -> bytes:
data["cache"] = self.cache
return srsly.msgpack_dumps(data)

def initialize(
    self,
    get_examples: Callable[[], Iterable],
    *,
    nlp: Optional[Language] = None,
    data_path: Optional[str] = None
):
    """Prepare the object for use and, if a path is given, read its data
    from disk. Intended for adding the component with vectors to a pipeline
    ahead of training.

    get_examples (Callable[[], Iterable[Example]]): Function that returns a
        representative sample of gold-standard Example objects. Not read
        by this method.
    nlp (Language): The current nlp object the component is part of. Not
        read by this method.
    data_path (Optional[str]): Optional path to a sense2vec model; loading
        is skipped when it is None.
    """
    path_given = data_path is not None
    if not path_given:
        return
    # Delegate the actual loading to the regular deserialization routine.
    self.from_disk(data_path)

def from_bytes(self, bytes_data: bytes, exclude: Sequence[str] = tuple()):
"""Load a Sense2Vec object from a bytestring.
Expand Down
32 changes: 32 additions & 0 deletions sense2vec/tests/test_component.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
import numpy
import spacy
from spacy.vocab import Vocab
from spacy.tokens import Doc, Span
from sense2vec import Sense2VecComponent
from pathlib import Path


@pytest.fixture
Expand Down Expand Up @@ -103,3 +105,33 @@ def test_component_to_from_bytes(doc):
assert doc[0]._.in_s2v is False
new_doc = new_s2v(doc)
assert new_doc[0]._.in_s2v is True


def test_component_initialize():
    """Check that the component yields the same lookups whether its data is
    loaded directly with from_disk or through nlp.initialize() configured
    with a data_path for the "sense2vec" component.
    """
    data_path = Path(__file__).parent / "data"

    def check_beekeepers(pipeline, component):
        # Run the component on a one-token doc and verify the key and the
        # two most similar entries.
        doc = Doc(pipeline.vocab, words=["beekeepers"], pos=["NOUN"])
        component(doc)
        assert doc[0]._.s2v_key == "beekeepers|NOUN"
        neighbours = [entry for entry, _ in doc[0]._.s2v_most_similar(2)]
        assert neighbours[0] == ("honey bees", "NOUN")
        assert neighbours[1] == ("Beekeepers", "NOUN")

    # With from_disk
    nlp = spacy.blank("en")
    s2v = nlp.add_pipe("sense2vec")
    extensions_registered = Doc.has_extension("s2v_phrases")
    if extensions_registered:
        s2v.first_run = False  # don't set up extensions again
    s2v.from_disk(data_path)
    check_beekeepers(nlp, s2v)

    # With initialize
    nlp = spacy.blank("en")
    s2v = nlp.add_pipe("sense2vec")
    s2v.first_run = False  # don't set up extensions again
    nlp.config["initialize"]["components"] = {
        "sense2vec": {"data_path": str(data_path)}
    }
    nlp.initialize()
    check_beekeepers(nlp, s2v)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ install_requires =

[options.entry_points]
spacy_factories =
sense2vec = sense2vec:make_sense2vec
sense2vec = sense2vec:component.make_sense2vec
prodigy_recipes =
sense2vec.teach = sense2vec:prodigy_recipes.teach
sens2vec.to-patterns = sense2vec:prodigy_recipes.to_patterns
Expand Down

0 comments on commit b80070c

Please sign in to comment.