# Source code for forayer.transformation.word_embedding
"""Module concerned with utilizing word embeddings."""
import logging
import os
from typing import Any, Callable, Dict, List
from warnings import warn
from zipfile import ZipFile
import numpy as np
import wget
from gensim.models import KeyedVectors
from gensim.models.fasttext import load_facebook_model
from gensim.utils import tokenize
# Download metadata for every supported embedding type, keyed by the lower-case
# name accepted as ``embedding_type``. Each value is a 3-tuple of
# (download URL, name of the downloaded zip archive, name of the file inside
# the archive that holds the wanted embeddings).
_EMBEDDING_INFO = {
    "fasttext": (
        "https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.simple.zip",
        "wiki.simple.zip",  # zip name
        "wiki.simple.bin",  # file containing wanted embeddings
    ),
    "glove": (
        "http://nlp.stanford.edu/data/glove.6B.zip",
        "glove.6B.zip",
        "glove.6B.300d.txt",
    ),
}

logger = logging.getLogger(__name__)


class AttributeVectorizer:
    """Embed entity attribute values with pre-trained word embeddings.

    Attributes
    ----------
    tokenizer: Callable
        callable that tokenizes a string
    embedding_type: str
        type of pretrained embeddings (lower-cased key of ``_EMBEDDING_INFO``)
    vectors_path: str
        path to pre-trained embeddings
    default_download_dir: str
        directory below which embeddings are downloaded when no path is given
    wv: gensim.models.KeyedVectors
        word embeddings
    vocab: Dict
        initialised as an empty dict; not modified by this class
    seen_tokens: int
        number of tokens passed through :meth:`vectorize` since the last reset
    ignored_tokens: int
        number of those tokens that had no embedding
    """

    def __init__(
        self,
        tokenizer: Callable = None,
        embedding_type: str = "fasttext",
        vectors_path: str = None,
        default_download_dir: str = None,
    ):
        """Initialize an AttributeVectorizer object and load the pre-trained embeddings.

        Parameters
        ----------
        tokenizer: Callable
            callable that tokenizes a string,
            ``gensim.utils.tokenize`` is used if None
        embedding_type: str
            type of pretrained embeddings (case-insensitive)
        vectors_path: str
            path to pre-trained embeddings
        default_download_dir: str
            directory where embeddings are downloaded if they are not present
            default is "./data/word_embeddings/"

        Raises
        ------
        TypeError
            if tokenizer is not callable
        ValueError
            if embedding_type is unknown or vectors_path does not exist
        """
        warn(
            (
                "AttributeVectorizer is deprecated and will be removed in the next"
                " minor version."
            ),
            DeprecationWarning,
            stacklevel=2,
        )

        if tokenizer is None:
            self.tokenizer = tokenize
        elif callable(tokenizer):
            # callable() accepts functions, bound methods, partials and any
            # object defining __call__
            self.tokenizer = tokenizer
        else:
            raise TypeError(
                f"tokenizer should be a function, but was {type(tokenizer)}"
            )

        embedding_type = embedding_type.lower()
        if embedding_type not in _EMBEDDING_INFO:
            raise ValueError(
                f"embedding_type has to be one of {set(_EMBEDDING_INFO.keys())}, not"
                f" {embedding_type}"
            )
        self.embedding_type = embedding_type

        if vectors_path is not None and not os.path.exists(vectors_path):
            raise ValueError(f"vectors_path: {vectors_path} does not exist")
        self.vectors_path = vectors_path

        self.default_download_dir = (
            default_download_dir
            if default_download_dir is not None
            else os.path.join("data", "word_embeddings")
        )

        # Resolves self.vectors_path when none was given, then loads from it.
        self._download_embeddings_if_needed()
        self.wv = self._load_embeddings()
        self.vocab: Dict = {}
        self.seen_tokens = 0
        self.ignored_tokens = 0

    def reset_token_count(self):
        """Reset .seen_tokens and .ignored_tokens."""
        self.seen_tokens = 0
        self.ignored_tokens = 0

    def _download_embeddings_if_needed(self):
        """Make sure ``self.vectors_path`` points to an embedding file.

        Nothing is done if ``vectors_path`` was supplied by the caller.
        Otherwise the file is looked up in
        ``<default_download_dir>/<embedding_type>/``; if it is missing there,
        the zip archive is downloaded, extracted and removed afterwards.

        Returns
        -------
        bool
            True if the embeddings were downloaded, False otherwise
        """
        if self.vectors_path is not None:
            return False

        embeddings_dir = os.path.join(self.default_download_dir, self.embedding_type)
        dl_url, zip_name, embedding_file = _EMBEDDING_INFO[self.embedding_type]
        vectors_path = os.path.join(embeddings_dir, embedding_file)
        if os.path.exists(vectors_path):
            # embeddings were already downloaded
            self.vectors_path = vectors_path
            return False

        # we have to download them
        os.makedirs(embeddings_dir, exist_ok=True)
        logger.info(
            f"Downloading {self.embedding_type} embeddings to {embeddings_dir}"
        )
        wget.download(dl_url, embeddings_dir)
        zip_path = os.path.join(embeddings_dir, zip_name)
        with ZipFile(zip_path, "r") as zip_obj:
            zip_obj.extractall(embeddings_dir)
        # the archive is not needed once its content is extracted
        os.remove(zip_path)
        self.vectors_path = vectors_path
        return True

    def _load_embeddings(self):
        """Load the word embeddings found at ``self.vectors_path``.

        Returns
        -------
        gensim.models.KeyedVectors
            the loaded word vectors

        Raises
        ------
        ValueError
            if a non-glove text file cannot be parsed in word2vec format
        """
        logger.info(f"Loading word embeddings from {self.vectors_path}")
        if self.embedding_type == "fasttext":
            return load_facebook_model(self.vectors_path).wv
        try:
            return KeyedVectors.load_word2vec_format(self.vectors_path, binary=False)
        except ValueError:
            if self.embedding_type != "glove":
                # no fallback known for this type: do not hide the failure
                raise
            # retry, telling gensim that the file has no header line
            return KeyedVectors.load_word2vec_format(
                self.vectors_path, binary=False, no_header=True
            )

    def vectorize(self, sentence: str) -> List[np.ndarray]:
        """Tokenize and vectorize a sentence with the given word embeddings.

        Parameters
        ----------
        sentence : str
            sentence to vectorize

        Returns
        -------
        vectorized: List[np.ndarray]
            List of token embeddings, one entry per token

        Notes
        -----
        Tokens that are not contained in the used embeddings are not dropped:
        their entry is set to np.nan and .ignored_tokens is increased.
        .seen_tokens is increased by the number of tokens in the sentence.
        """
        vectorized = []
        _number_of_tokens = 0  # initialize before in case loop is not executed
        for _number_of_tokens, word in enumerate(self.tokenizer(sentence), start=1):
            if word in self.wv:
                vectorized.append(self.wv[word])
            else:
                # np.nan instead of the np.NaN alias, which NumPy 2.0 removed
                vectorized.append(np.nan)
                self.ignored_tokens += 1
                logger.debug(f"Did not find embedding for {word}, ignoring")
        self.seen_tokens += _number_of_tokens
        return vectorized

    def vectorize_entity_attributes(
        self, attributes: Dict[Any, Any]
    ) -> Dict[Any, List[np.ndarray]]:
        """Tokenize and vectorize values of entity attributes.

        Parameters
        ----------
        attributes : Dict[Any, Any]
            dictionary of entity attributes with attribute names as keys

        Returns
        -------
        embedded_entity_attributes: Dict[Any, List[np.ndarray]]
            entity dicts with attribute values replaced with list of token embeddings
        """
        return {
            attr_name: self.vectorize(attr_value)
            for attr_name, attr_value in attributes.items()
        }