Refer to this repo.
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k
Note that there are more APIs coming up in torchtext.experimental, which will be released soon in 0.11.0 and 0.12.0, hopefully with good documentation :) and I'll be a part of creating that documentation 😁
SRC_LANGUAGE, TGT_LANGUAGE = 'de', 'en'
token_transform, vocab_transform = {}, {}  # per-language transform registries
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')
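For example (assuming the spaCy models have been downloaded, e.g. via python -m spacy download de_core_news_sm), the tokenizer turns a raw string into a list of tokens, roughly like this:

print(token_transform[SRC_LANGUAGE]("Ein kleines Mädchen klettert."))
# ['Ein', 'kleines', 'Mädchen', 'klettert', '.']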
# Training, Validation and Test data Iterator
train_iter, val_iter, test_iter = Multi30k(split=('train', 'valid', 'test'), language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
train_list, val_list, test_list = list(train_iter), list(val_iter), list(test_iter)
A thing to note here is that list(iter) is costly for large datasets, so it's preferable to keep it as an iterator, make a DataLoader out of it, and use the DataLoader for whatever you want to do. Once the iterator is exhausted, you'll need to call the Multi30k function again to rebuild it, which seems like a waste of CPU cycles.
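For illustration, a sketch of that iterator-friendly alternative (it uses the collate_fn defined later in this post, and the exact dataset behaviour varies a bit across torchtext versions):

from torch.utils.data import DataLoader

# Hand the iterable dataset straight to a DataLoader instead of copying it into a list;
# re-create the (cheap) iterator whenever you need a fresh pass over the data.
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
train_dataloader = DataLoader(train_iter, batch_size=128, collate_fn=collate_fn)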
Vocab
from typing import Iterable, List

# helper function to yield a list of tokens per sample
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])
# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_list, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

# Set UNK_IDX as the default index. This index is returned when the token is not found.
# If not set, it throws a RuntimeError when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)
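A quick sanity check of the lookup behaviour (the non-zero indices below are illustrative, since they depend on the corpus):

# The Vocab object maps a list of tokens to a list of indices;
# unknown tokens fall back to the default index we just set.
print(vocab_transform[TGT_LANGUAGE](['a', 'man', 'qwerty123']))
# e.g. [6, 10, 0] -- 'qwerty123' is out-of-vocabulary, so it maps to UNK_IDX (0)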
collate_fn
import torch
from torch.nn.utils.rnn import pad_sequence
# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func
# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))
# src and tgt language text transforms to convert raw strings into tensors of indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln],  # Tokenization
                                               vocab_transform[ln],  # Numericalization
                                               tensor_transform)     # Add BOS/EOS and create tensor
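Chaining the three steps means a raw string goes in and an index tensor comes out. A quick illustration (all indices except BOS/EOS are made up):

# Raw string -> tokens -> vocab indices -> tensor wrapped in <bos>/<eos>
print(text_transform[SRC_LANGUAGE]("Ein Mann schläft."))
# e.g. tensor([ 2, 21, 15, 47,  5,  3]) -- starts with BOS_IDX (2), ends with EOS_IDX (3)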
# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, src_len, tgt_batch = [], [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        src_len.append(len(src_batch[-1]))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, torch.LongTensor(src_len), tgt_batch
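Note that pad_sequence defaults to batch_first=False, so the collated batches come out time-major. A quick illustration of the shapes (token counts are approximate, since they depend on the spaCy tokenizer):

# src_batch: (max_src_len, batch_size), tgt_batch: (max_tgt_len, batch_size)
# src_len:   (batch_size,) true lengths before padding, handy for pack_padded_sequence
src_batch, src_len, tgt_batch = collate_fn([("Ein Mann schläft.\n", "A man is sleeping.\n")])
print(src_batch.shape, src_len, tgt_batch.shape)
# e.g. torch.Size([6, 1]) tensor([6]) torch.Size([7, 1])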
DataLoader
from torch.utils.data import DataLoader

BATCH_SIZE = 128
train_dataloader = DataLoader(train_list, batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_dataloader = DataLoader(val_list, batch_size=BATCH_SIZE, collate_fn=collate_fn)
test_dataloader = DataLoader(test_list, batch_size=BATCH_SIZE, collate_fn=collate_fn)
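From here, training is just iterating over the DataLoader; a minimal sketch of the loop (the model step is whatever you plug in):

# Each batch is exactly what collate_fn returns: padded sources, lengths, padded targets
for src_batch, src_len, tgt_batch in train_dataloader:
    # src_batch: (max_src_len, BATCH_SIZE), tgt_batch: (max_tgt_len, BATCH_SIZE)
    ...  # forward pass / loss / backward go here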
The main aim of the new API is to be consistent with torch, which uses the classic DataLoader object, and torchtext is moving towards it. UNIFY EVERYTHING (╯°□°)╯︵ ┻━┻
torchtext is evolving a lot; consequently there have been a lot of breaking changes, and sadly not much documentation on them (┬┬﹏┬┬). Below are some of the official references which use the new API:
torchtext legacy migration tutorial
Apart from these, there are some useful GitHub Issues that must be looked at. Some of them are my contributions ψ(`∇´)ψ
Vector with the new torchtext API