Initial commit
This commit is contained in:
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
Tests for BLEU translation evaluation metric
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from nltk.data import find
|
||||
from nltk.translate.bleu_score import (
|
||||
SmoothingFunction,
|
||||
brevity_penalty,
|
||||
closest_ref_length,
|
||||
corpus_bleu,
|
||||
modified_precision,
|
||||
sentence_bleu,
|
||||
)
|
||||
|
||||
|
||||
class TestBLEU(unittest.TestCase):
|
||||
def test_modified_precision(self):
|
||||
"""
|
||||
Examples from the original BLEU paper
|
||||
https://www.aclweb.org/anthology/P02-1040.pdf
|
||||
"""
|
||||
# Example 1: the "the*" example.
|
||||
# Reference sentences.
|
||||
ref1 = "the cat is on the mat".split()
|
||||
ref2 = "there is a cat on the mat".split()
|
||||
# Hypothesis sentence(s).
|
||||
hyp1 = "the the the the the the the".split()
|
||||
|
||||
references = [ref1, ref2]
|
||||
|
||||
# Testing modified unigram precision.
|
||||
hyp1_unigram_precision = float(modified_precision(references, hyp1, n=1))
|
||||
assert round(hyp1_unigram_precision, 4) == 0.2857
|
||||
# With assertAlmostEqual at 4 place precision.
|
||||
self.assertAlmostEqual(hyp1_unigram_precision, 0.28571428, places=4)
|
||||
|
||||
# Testing modified bigram precision.
|
||||
assert float(modified_precision(references, hyp1, n=2)) == 0.0
|
||||
|
||||
# Example 2: the "of the" example.
|
||||
# Reference sentences
|
||||
ref1 = str(
|
||||
"It is a guide to action that ensures that the military "
|
||||
"will forever heed Party commands"
|
||||
).split()
|
||||
ref2 = str(
|
||||
"It is the guiding principle which guarantees the military "
|
||||
"forces always being under the command of the Party"
|
||||
).split()
|
||||
ref3 = str(
|
||||
"It is the practical guide for the army always to heed "
|
||||
"the directions of the party"
|
||||
).split()
|
||||
# Hypothesis sentence(s).
|
||||
hyp1 = "of the".split()
|
||||
|
||||
references = [ref1, ref2, ref3]
|
||||
# Testing modified unigram precision.
|
||||
assert float(modified_precision(references, hyp1, n=1)) == 1.0
|
||||
|
||||
# Testing modified bigram precision.
|
||||
assert float(modified_precision(references, hyp1, n=2)) == 1.0
|
||||
|
||||
# Example 3: Proper MT outputs.
|
||||
hyp1 = str(
|
||||
"It is a guide to action which ensures that the military "
|
||||
"always obeys the commands of the party"
|
||||
).split()
|
||||
hyp2 = str(
|
||||
"It is to insure the troops forever hearing the activity "
|
||||
"guidebook that party direct"
|
||||
).split()
|
||||
|
||||
references = [ref1, ref2, ref3]
|
||||
|
||||
# Unigram precision.
|
||||
hyp1_unigram_precision = float(modified_precision(references, hyp1, n=1))
|
||||
hyp2_unigram_precision = float(modified_precision(references, hyp2, n=1))
|
||||
# Test unigram precision with assertAlmostEqual at 4 place precision.
|
||||
self.assertAlmostEqual(hyp1_unigram_precision, 0.94444444, places=4)
|
||||
self.assertAlmostEqual(hyp2_unigram_precision, 0.57142857, places=4)
|
||||
# Test unigram precision with rounding.
|
||||
assert round(hyp1_unigram_precision, 4) == 0.9444
|
||||
assert round(hyp2_unigram_precision, 4) == 0.5714
|
||||
|
||||
# Bigram precision
|
||||
hyp1_bigram_precision = float(modified_precision(references, hyp1, n=2))
|
||||
hyp2_bigram_precision = float(modified_precision(references, hyp2, n=2))
|
||||
# Test bigram precision with assertAlmostEqual at 4 place precision.
|
||||
self.assertAlmostEqual(hyp1_bigram_precision, 0.58823529, places=4)
|
||||
self.assertAlmostEqual(hyp2_bigram_precision, 0.07692307, places=4)
|
||||
# Test bigram precision with rounding.
|
||||
assert round(hyp1_bigram_precision, 4) == 0.5882
|
||||
assert round(hyp2_bigram_precision, 4) == 0.0769
|
||||
|
||||
def test_brevity_penalty(self):
|
||||
# Test case from brevity_penalty_closest function in mteval-v13a.pl.
|
||||
# Same test cases as in the doctest in nltk.translate.bleu_score.py
|
||||
references = [["a"] * 11, ["a"] * 8]
|
||||
hypothesis = ["a"] * 7
|
||||
hyp_len = len(hypothesis)
|
||||
closest_ref_len = closest_ref_length(references, hyp_len)
|
||||
self.assertAlmostEqual(
|
||||
brevity_penalty(closest_ref_len, hyp_len), 0.8669, places=4
|
||||
)
|
||||
|
||||
references = [["a"] * 11, ["a"] * 8, ["a"] * 6, ["a"] * 7]
|
||||
hypothesis = ["a"] * 7
|
||||
hyp_len = len(hypothesis)
|
||||
closest_ref_len = closest_ref_length(references, hyp_len)
|
||||
assert brevity_penalty(closest_ref_len, hyp_len) == 1.0
|
||||
|
||||
def test_zero_matches(self):
|
||||
# Test case where there's 0 matches
|
||||
references = ["The candidate has no alignment to any of the references".split()]
|
||||
hypothesis = "John loves Mary".split()
|
||||
|
||||
# Test BLEU to nth order of n-grams, where n is len(hypothesis).
|
||||
for n in range(1, len(hypothesis)):
|
||||
weights = (1.0 / n,) * n # Uniform weights.
|
||||
assert sentence_bleu(references, hypothesis, weights) == 0
|
||||
|
||||
def test_full_matches(self):
|
||||
# Test case where there's 100% matches
|
||||
references = ["John loves Mary".split()]
|
||||
hypothesis = "John loves Mary".split()
|
||||
|
||||
# Test BLEU to nth order of n-grams, where n is len(hypothesis).
|
||||
for n in range(1, len(hypothesis)):
|
||||
weights = (1.0 / n,) * n # Uniform weights.
|
||||
assert sentence_bleu(references, hypothesis, weights) == 1.0
|
||||
|
||||
def test_partial_matches_hypothesis_longer_than_reference(self):
|
||||
references = ["John loves Mary".split()]
|
||||
hypothesis = "John loves Mary who loves Mike".split()
|
||||
# Since no 4-grams matches were found the result should be zero
|
||||
# exp(w_1 * 1 * w_2 * 1 * w_3 * 1 * w_4 * -inf) = 0
|
||||
self.assertAlmostEqual(sentence_bleu(references, hypothesis), 0.0, places=4)
|
||||
# Checks that the warning has been raised because len(reference) < 4.
|
||||
try:
|
||||
self.assertWarns(UserWarning, sentence_bleu, references, hypothesis)
|
||||
except AttributeError:
|
||||
pass # unittest.TestCase.assertWarns is only supported in Python >= 3.2.
|
||||
|
||||
|
||||
# @unittest.skip("Skipping fringe cases for BLEU.")
|
||||
class TestBLEUFringeCases(unittest.TestCase):
|
||||
def test_case_where_n_is_bigger_than_hypothesis_length(self):
|
||||
# Test BLEU to nth order of n-grams, where n > len(hypothesis).
|
||||
references = ["John loves Mary ?".split()]
|
||||
hypothesis = "John loves Mary".split()
|
||||
n = len(hypothesis) + 1 #
|
||||
weights = (1.0 / n,) * n # Uniform weights.
|
||||
# Since no n-grams matches were found the result should be zero
|
||||
# exp(w_1 * 1 * w_2 * 1 * w_3 * 1 * w_4 * -inf) = 0
|
||||
self.assertAlmostEqual(
|
||||
sentence_bleu(references, hypothesis, weights), 0.0, places=4
|
||||
)
|
||||
# Checks that the warning has been raised because len(hypothesis) < 4.
|
||||
try:
|
||||
self.assertWarns(UserWarning, sentence_bleu, references, hypothesis)
|
||||
except AttributeError:
|
||||
pass # unittest.TestCase.assertWarns is only supported in Python >= 3.2.
|
||||
|
||||
# Test case where n > len(hypothesis) but so is n > len(reference), and
|
||||
# it's a special case where reference == hypothesis.
|
||||
references = ["John loves Mary".split()]
|
||||
hypothesis = "John loves Mary".split()
|
||||
# Since no 4-grams matches were found the result should be zero
|
||||
# exp(w_1 * 1 * w_2 * 1 * w_3 * 1 * w_4 * -inf) = 0
|
||||
self.assertAlmostEqual(
|
||||
sentence_bleu(references, hypothesis, weights), 0.0, places=4
|
||||
)
|
||||
|
||||
def test_empty_hypothesis(self):
|
||||
# Test case where there's hypothesis is empty.
|
||||
references = ["The candidate has no alignment to any of the references".split()]
|
||||
hypothesis = []
|
||||
assert sentence_bleu(references, hypothesis) == 0
|
||||
|
||||
def test_length_one_hypothesis(self):
|
||||
# Test case where there's hypothesis is of length 1 in Smoothing method 4.
|
||||
references = ["The candidate has no alignment to any of the references".split()]
|
||||
hypothesis = ["Foo"]
|
||||
method4 = SmoothingFunction().method4
|
||||
try:
|
||||
sentence_bleu(references, hypothesis, smoothing_function=method4)
|
||||
except ValueError:
|
||||
pass # unittest.TestCase.assertWarns is only supported in Python >= 3.2.
|
||||
|
||||
def test_empty_references(self):
|
||||
# Test case where there's reference is empty.
|
||||
references = [[]]
|
||||
hypothesis = "John loves Mary".split()
|
||||
assert sentence_bleu(references, hypothesis) == 0
|
||||
|
||||
def test_empty_references_and_hypothesis(self):
|
||||
# Test case where both references and hypothesis is empty.
|
||||
references = [[]]
|
||||
hypothesis = []
|
||||
assert sentence_bleu(references, hypothesis) == 0
|
||||
|
||||
def test_reference_or_hypothesis_shorter_than_fourgrams(self):
|
||||
# Test case where the length of reference or hypothesis
|
||||
# is shorter than 4.
|
||||
references = ["let it go".split()]
|
||||
hypothesis = "let go it".split()
|
||||
# Checks that the value the hypothesis and reference returns is 0.0
|
||||
# exp(w_1 * 1 * w_2 * 1 * w_3 * 1 * w_4 * -inf) = 0
|
||||
self.assertAlmostEqual(sentence_bleu(references, hypothesis), 0.0, places=4)
|
||||
# Checks that the warning has been raised.
|
||||
try:
|
||||
self.assertWarns(UserWarning, sentence_bleu, references, hypothesis)
|
||||
except AttributeError:
|
||||
pass # unittest.TestCase.assertWarns is only supported in Python >= 3.2.
|
||||
|
||||
def test_numpy_weights(self):
|
||||
# Test case where there's 0 matches
|
||||
references = ["The candidate has no alignment to any of the references".split()]
|
||||
hypothesis = "John loves Mary".split()
|
||||
|
||||
weights = np.array([0.25] * 4)
|
||||
assert sentence_bleu(references, hypothesis, weights) == 0
|
||||
|
||||
|
||||
class TestBLEUvsMteval13a(unittest.TestCase):
|
||||
def test_corpus_bleu(self):
|
||||
ref_file = find("models/wmt15_eval/ref.ru")
|
||||
hyp_file = find("models/wmt15_eval/google.ru")
|
||||
mteval_output_file = find("models/wmt15_eval/mteval-13a.output")
|
||||
|
||||
# Reads the BLEU scores from the `mteval-13a.output` file.
|
||||
# The order of the list corresponds to the order of the ngrams.
|
||||
with open(mteval_output_file) as mteval_fin:
|
||||
# The numbers are located in the last 2nd line of the file.
|
||||
# The first and 2nd item in the list are the score and system names.
|
||||
mteval_bleu_scores = map(float, mteval_fin.readlines()[-2].split()[1:-1])
|
||||
|
||||
with open(ref_file, encoding="utf8") as ref_fin:
|
||||
with open(hyp_file, encoding="utf8") as hyp_fin:
|
||||
# Whitespace tokenize the file.
|
||||
# Note: split() automatically strip().
|
||||
hypothesis = list(map(lambda x: x.split(), hyp_fin))
|
||||
# Note that the corpus_bleu input is list of list of references.
|
||||
references = list(map(lambda x: [x.split()], ref_fin))
|
||||
# Without smoothing.
|
||||
for i, mteval_bleu in zip(range(1, 10), mteval_bleu_scores):
|
||||
nltk_bleu = corpus_bleu(
|
||||
references, hypothesis, weights=(1.0 / i,) * i
|
||||
)
|
||||
# Check that the BLEU scores difference is less than 0.005 .
|
||||
# Note: This is an approximate comparison; as much as
|
||||
# +/- 0.01 BLEU might be "statistically significant",
|
||||
# the actual translation quality might not be.
|
||||
assert abs(mteval_bleu - nltk_bleu) < 0.005
|
||||
|
||||
# With the same smoothing method used in mteval-v13a.pl
|
||||
chencherry = SmoothingFunction()
|
||||
for i, mteval_bleu in zip(range(1, 10), mteval_bleu_scores):
|
||||
nltk_bleu = corpus_bleu(
|
||||
references,
|
||||
hypothesis,
|
||||
weights=(1.0 / i,) * i,
|
||||
smoothing_function=chencherry.method3,
|
||||
)
|
||||
assert abs(mteval_bleu - nltk_bleu) < 0.005
|
||||
|
||||
|
||||
class TestBLEUWithBadSentence(unittest.TestCase):
|
||||
def test_corpus_bleu_with_bad_sentence(self):
|
||||
hyp = "Teo S yb , oe uNb , R , T t , , t Tue Ar saln S , , 5istsi l , 5oe R ulO sae oR R"
|
||||
ref = str(
|
||||
"Their tasks include changing a pump on the faulty stokehold ."
|
||||
"Likewise , two species that are very similar in morphology "
|
||||
"were distinguished using genetics ."
|
||||
)
|
||||
references = [[ref.split()]]
|
||||
hypotheses = [hyp.split()]
|
||||
try: # Check that the warning is raised since no. of 2-grams < 0.
|
||||
with self.assertWarns(UserWarning):
|
||||
# Verify that the BLEU output is undesired since no. of 2-grams < 0.
|
||||
self.assertAlmostEqual(
|
||||
corpus_bleu(references, hypotheses), 0.0, places=4
|
||||
)
|
||||
except (
|
||||
AttributeError
|
||||
): # unittest.TestCase.assertWarns is only supported in Python >= 3.2.
|
||||
self.assertAlmostEqual(corpus_bleu(references, hypotheses), 0.0, places=4)
|
||||
|
||||
|
||||
class TestBLEUWithMultipleWeights(unittest.TestCase):
|
||||
def test_corpus_bleu_with_multiple_weights(self):
|
||||
hyp1 = [
|
||||
"It",
|
||||
"is",
|
||||
"a",
|
||||
"guide",
|
||||
"to",
|
||||
"action",
|
||||
"which",
|
||||
"ensures",
|
||||
"that",
|
||||
"the",
|
||||
"military",
|
||||
"always",
|
||||
"obeys",
|
||||
"the",
|
||||
"commands",
|
||||
"of",
|
||||
"the",
|
||||
"party",
|
||||
]
|
||||
ref1a = [
|
||||
"It",
|
||||
"is",
|
||||
"a",
|
||||
"guide",
|
||||
"to",
|
||||
"action",
|
||||
"that",
|
||||
"ensures",
|
||||
"that",
|
||||
"the",
|
||||
"military",
|
||||
"will",
|
||||
"forever",
|
||||
"heed",
|
||||
"Party",
|
||||
"commands",
|
||||
]
|
||||
ref1b = [
|
||||
"It",
|
||||
"is",
|
||||
"the",
|
||||
"guiding",
|
||||
"principle",
|
||||
"which",
|
||||
"guarantees",
|
||||
"the",
|
||||
"military",
|
||||
"forces",
|
||||
"always",
|
||||
"being",
|
||||
"under",
|
||||
"the",
|
||||
"command",
|
||||
"of",
|
||||
"the",
|
||||
"Party",
|
||||
]
|
||||
ref1c = [
|
||||
"It",
|
||||
"is",
|
||||
"the",
|
||||
"practical",
|
||||
"guide",
|
||||
"for",
|
||||
"the",
|
||||
"army",
|
||||
"always",
|
||||
"to",
|
||||
"heed",
|
||||
"the",
|
||||
"directions",
|
||||
"of",
|
||||
"the",
|
||||
"party",
|
||||
]
|
||||
hyp2 = [
|
||||
"he",
|
||||
"read",
|
||||
"the",
|
||||
"book",
|
||||
"because",
|
||||
"he",
|
||||
"was",
|
||||
"interested",
|
||||
"in",
|
||||
"world",
|
||||
"history",
|
||||
]
|
||||
ref2a = [
|
||||
"he",
|
||||
"was",
|
||||
"interested",
|
||||
"in",
|
||||
"world",
|
||||
"history",
|
||||
"because",
|
||||
"he",
|
||||
"read",
|
||||
"the",
|
||||
"book",
|
||||
]
|
||||
weight_1 = (1, 0, 0, 0)
|
||||
weight_2 = (0.25, 0.25, 0.25, 0.25)
|
||||
weight_3 = (0, 0, 0, 0, 1)
|
||||
|
||||
bleu_scores = corpus_bleu(
|
||||
list_of_references=[[ref1a, ref1b, ref1c], [ref2a]],
|
||||
hypotheses=[hyp1, hyp2],
|
||||
weights=[weight_1, weight_2, weight_3],
|
||||
)
|
||||
assert bleu_scores[0] == corpus_bleu(
|
||||
[[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2], weight_1
|
||||
)
|
||||
assert bleu_scores[1] == corpus_bleu(
|
||||
[[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2], weight_2
|
||||
)
|
||||
assert bleu_scores[2] == corpus_bleu(
|
||||
[[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2], weight_3
|
||||
)
|
||||
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Tests GDFA alignments
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from nltk.translate.gdfa import grow_diag_final_and
|
||||
|
||||
|
||||
class TestGDFA(unittest.TestCase):
|
||||
def test_from_eflomal_outputs(self):
|
||||
"""
|
||||
Testing GDFA with first 10 eflomal outputs from issue #1829
|
||||
https://github.com/nltk/nltk/issues/1829
|
||||
"""
|
||||
# Input.
|
||||
forwards = [
|
||||
"0-0 1-2",
|
||||
"0-0 1-1",
|
||||
"0-0 2-1 3-2 4-3 5-4 6-5 7-6 8-7 7-8 9-9 10-10 9-11 11-12 12-13 13-14",
|
||||
"0-0 1-1 1-2 2-3 3-4 4-5 4-6 5-7 6-8 8-9 9-10",
|
||||
"0-0 14-1 15-2 16-3 20-5 21-6 22-7 5-8 6-9 7-10 8-11 9-12 10-13 11-14 12-15 13-16 14-17 17-18 18-19 19-20 20-21 23-22 24-23 25-24 26-25 27-27 28-28 29-29 30-30 31-31",
|
||||
"0-0 1-1 0-2 2-3",
|
||||
"0-0 2-2 4-4",
|
||||
"0-0 1-1 2-3 3-4 5-5 7-6 8-7 9-8 10-9 11-10 12-11 13-12 14-13 15-14 16-16 17-17 18-18 19-19 20-20",
|
||||
"3-0 4-1 6-2 5-3 6-4 7-5 8-6 9-7 10-8 11-9 16-10 9-12 10-13 12-14",
|
||||
"1-0",
|
||||
]
|
||||
backwards = [
|
||||
"0-0 1-2",
|
||||
"0-0 1-1",
|
||||
"0-0 2-1 3-2 4-3 5-4 6-5 7-6 8-7 9-8 10-10 11-12 12-11 13-13",
|
||||
"0-0 1-2 2-3 3-4 4-6 6-8 7-5 8-7 9-8",
|
||||
"0-0 1-8 2-9 3-10 4-11 5-12 6-11 8-13 9-14 10-15 11-16 12-17 13-18 14-19 15-20 16-21 17-22 18-23 19-24 20-29 21-30 22-31 23-2 24-3 25-4 26-5 27-5 28-6 29-7 30-28 31-31",
|
||||
"0-0 1-1 2-3",
|
||||
"0-0 1-1 2-3 4-4",
|
||||
"0-0 1-1 2-3 3-4 5-5 7-6 8-7 9-8 10-9 11-10 12-11 13-12 14-13 15-14 16-16 17-17 18-18 19-19 20-16 21-18",
|
||||
"0-0 1-1 3-2 4-1 5-3 6-4 7-5 8-6 9-7 10-8 11-9 12-8 13-9 14-8 15-9 16-10",
|
||||
"1-0",
|
||||
]
|
||||
source_lens = [2, 3, 3, 15, 11, 33, 4, 6, 23, 18]
|
||||
target_lens = [2, 4, 3, 16, 12, 33, 5, 6, 22, 16]
|
||||
# Expected Output.
|
||||
expected = [
|
||||
[(0, 0), (1, 2)],
|
||||
[(0, 0), (1, 1)],
|
||||
[
|
||||
(0, 0),
|
||||
(2, 1),
|
||||
(3, 2),
|
||||
(4, 3),
|
||||
(5, 4),
|
||||
(6, 5),
|
||||
(7, 6),
|
||||
(8, 7),
|
||||
(10, 10),
|
||||
(11, 12),
|
||||
],
|
||||
[
|
||||
(0, 0),
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(2, 3),
|
||||
(3, 4),
|
||||
(4, 5),
|
||||
(4, 6),
|
||||
(5, 7),
|
||||
(6, 8),
|
||||
(7, 5),
|
||||
(8, 7),
|
||||
(8, 9),
|
||||
(9, 8),
|
||||
(9, 10),
|
||||
],
|
||||
[
|
||||
(0, 0),
|
||||
(1, 8),
|
||||
(2, 9),
|
||||
(3, 10),
|
||||
(4, 11),
|
||||
(5, 8),
|
||||
(6, 9),
|
||||
(6, 11),
|
||||
(7, 10),
|
||||
(8, 11),
|
||||
(31, 31),
|
||||
],
|
||||
[(0, 0), (0, 2), (1, 1), (2, 3)],
|
||||
[(0, 0), (1, 1), (2, 2), (2, 3), (4, 4)],
|
||||
[
|
||||
(0, 0),
|
||||
(1, 1),
|
||||
(2, 3),
|
||||
(3, 4),
|
||||
(5, 5),
|
||||
(7, 6),
|
||||
(8, 7),
|
||||
(9, 8),
|
||||
(10, 9),
|
||||
(11, 10),
|
||||
(12, 11),
|
||||
(13, 12),
|
||||
(14, 13),
|
||||
(15, 14),
|
||||
(16, 16),
|
||||
(17, 17),
|
||||
(18, 18),
|
||||
(19, 19),
|
||||
],
|
||||
[
|
||||
(0, 0),
|
||||
(1, 1),
|
||||
(3, 0),
|
||||
(3, 2),
|
||||
(4, 1),
|
||||
(5, 3),
|
||||
(6, 2),
|
||||
(6, 4),
|
||||
(7, 5),
|
||||
(8, 6),
|
||||
(9, 7),
|
||||
(9, 12),
|
||||
(10, 8),
|
||||
(10, 13),
|
||||
(11, 9),
|
||||
(12, 8),
|
||||
(12, 14),
|
||||
(13, 9),
|
||||
(14, 8),
|
||||
(15, 9),
|
||||
(16, 10),
|
||||
],
|
||||
[(1, 0)],
|
||||
[
|
||||
(0, 0),
|
||||
(1, 1),
|
||||
(3, 2),
|
||||
(4, 3),
|
||||
(5, 4),
|
||||
(6, 5),
|
||||
(7, 6),
|
||||
(9, 10),
|
||||
(10, 12),
|
||||
(11, 13),
|
||||
(12, 14),
|
||||
(13, 15),
|
||||
],
|
||||
]
|
||||
|
||||
# Iterate through all 10 examples and check for expected outputs.
|
||||
for fw, bw, src_len, trg_len, expect in zip(
|
||||
forwards, backwards, source_lens, target_lens, expected
|
||||
):
|
||||
self.assertListEqual(expect, grow_diag_final_and(src_len, trg_len, fw, bw))
|
||||
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Tests for IBM Model 1 training methods
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
|
||||
from nltk.translate import AlignedSent, IBMModel, IBMModel1
|
||||
from nltk.translate.ibm_model import AlignmentInfo
|
||||
|
||||
|
||||
class TestIBMModel1(unittest.TestCase):
|
||||
def test_set_uniform_translation_probabilities(self):
|
||||
# arrange
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model1 = IBMModel1(corpus, 0)
|
||||
|
||||
# act
|
||||
model1.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# expected_prob = 1.0 / (target vocab size + 1)
|
||||
self.assertEqual(model1.translation_table["ham"]["eier"], 1.0 / 3)
|
||||
self.assertEqual(model1.translation_table["eggs"][None], 1.0 / 3)
|
||||
|
||||
def test_set_uniform_translation_probabilities_of_non_domain_values(self):
|
||||
# arrange
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model1 = IBMModel1(corpus, 0)
|
||||
|
||||
# act
|
||||
model1.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# examine target words that are not in the training data domain
|
||||
self.assertEqual(model1.translation_table["parrot"]["eier"], IBMModel.MIN_PROB)
|
||||
|
||||
def test_prob_t_a_given_s(self):
|
||||
# arrange
|
||||
src_sentence = ["ich", "esse", "ja", "gern", "räucherschinken"]
|
||||
trg_sentence = ["i", "love", "to", "eat", "smoked", "ham"]
|
||||
corpus = [AlignedSent(trg_sentence, src_sentence)]
|
||||
alignment_info = AlignmentInfo(
|
||||
(0, 1, 4, 0, 2, 5, 5),
|
||||
[None] + src_sentence,
|
||||
["UNUSED"] + trg_sentence,
|
||||
None,
|
||||
)
|
||||
|
||||
translation_table = defaultdict(lambda: defaultdict(float))
|
||||
translation_table["i"]["ich"] = 0.98
|
||||
translation_table["love"]["gern"] = 0.98
|
||||
translation_table["to"][None] = 0.98
|
||||
translation_table["eat"]["esse"] = 0.98
|
||||
translation_table["smoked"]["räucherschinken"] = 0.98
|
||||
translation_table["ham"]["räucherschinken"] = 0.98
|
||||
|
||||
model1 = IBMModel1(corpus, 0)
|
||||
model1.translation_table = translation_table
|
||||
|
||||
# act
|
||||
probability = model1.prob_t_a_given_s(alignment_info)
|
||||
|
||||
# assert
|
||||
lexical_translation = 0.98 * 0.98 * 0.98 * 0.98 * 0.98 * 0.98
|
||||
expected_probability = lexical_translation
|
||||
self.assertEqual(round(probability, 4), round(expected_probability, 4))
|
||||
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Tests for IBM Model 2 training methods
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
|
||||
from nltk.translate import AlignedSent, IBMModel, IBMModel2
|
||||
from nltk.translate.ibm_model import AlignmentInfo
|
||||
|
||||
|
||||
class TestIBMModel2(unittest.TestCase):
|
||||
def test_set_uniform_alignment_probabilities(self):
|
||||
# arrange
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model2 = IBMModel2(corpus, 0)
|
||||
|
||||
# act
|
||||
model2.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# expected_prob = 1.0 / (length of source sentence + 1)
|
||||
self.assertEqual(model2.alignment_table[0][1][3][2], 1.0 / 4)
|
||||
self.assertEqual(model2.alignment_table[2][4][2][4], 1.0 / 3)
|
||||
|
||||
def test_set_uniform_alignment_probabilities_of_non_domain_values(self):
|
||||
# arrange
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model2 = IBMModel2(corpus, 0)
|
||||
|
||||
# act
|
||||
model2.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# examine i and j values that are not in the training data domain
|
||||
self.assertEqual(model2.alignment_table[99][1][3][2], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model2.alignment_table[2][99][2][4], IBMModel.MIN_PROB)
|
||||
|
||||
def test_prob_t_a_given_s(self):
|
||||
# arrange
|
||||
src_sentence = ["ich", "esse", "ja", "gern", "räucherschinken"]
|
||||
trg_sentence = ["i", "love", "to", "eat", "smoked", "ham"]
|
||||
corpus = [AlignedSent(trg_sentence, src_sentence)]
|
||||
alignment_info = AlignmentInfo(
|
||||
(0, 1, 4, 0, 2, 5, 5),
|
||||
[None] + src_sentence,
|
||||
["UNUSED"] + trg_sentence,
|
||||
None,
|
||||
)
|
||||
|
||||
translation_table = defaultdict(lambda: defaultdict(float))
|
||||
translation_table["i"]["ich"] = 0.98
|
||||
translation_table["love"]["gern"] = 0.98
|
||||
translation_table["to"][None] = 0.98
|
||||
translation_table["eat"]["esse"] = 0.98
|
||||
translation_table["smoked"]["räucherschinken"] = 0.98
|
||||
translation_table["ham"]["räucherschinken"] = 0.98
|
||||
|
||||
alignment_table = defaultdict(
|
||||
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
|
||||
)
|
||||
alignment_table[0][3][5][6] = 0.97 # None -> to
|
||||
alignment_table[1][1][5][6] = 0.97 # ich -> i
|
||||
alignment_table[2][4][5][6] = 0.97 # esse -> eat
|
||||
alignment_table[4][2][5][6] = 0.97 # gern -> love
|
||||
alignment_table[5][5][5][6] = 0.96 # räucherschinken -> smoked
|
||||
alignment_table[5][6][5][6] = 0.96 # räucherschinken -> ham
|
||||
|
||||
model2 = IBMModel2(corpus, 0)
|
||||
model2.translation_table = translation_table
|
||||
model2.alignment_table = alignment_table
|
||||
|
||||
# act
|
||||
probability = model2.prob_t_a_given_s(alignment_info)
|
||||
|
||||
# assert
|
||||
lexical_translation = 0.98 * 0.98 * 0.98 * 0.98 * 0.98 * 0.98
|
||||
alignment = 0.97 * 0.97 * 0.97 * 0.97 * 0.96 * 0.96
|
||||
expected_probability = lexical_translation * alignment
|
||||
self.assertEqual(round(probability, 4), round(expected_probability, 4))
|
||||
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Tests for IBM Model 3 training methods
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
|
||||
from nltk.translate import AlignedSent, IBMModel, IBMModel3
|
||||
from nltk.translate.ibm_model import AlignmentInfo
|
||||
|
||||
|
||||
class TestIBMModel3(unittest.TestCase):
|
||||
def test_set_uniform_distortion_probabilities(self):
|
||||
# arrange
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model3 = IBMModel3(corpus, 0)
|
||||
|
||||
# act
|
||||
model3.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# expected_prob = 1.0 / length of target sentence
|
||||
self.assertEqual(model3.distortion_table[1][0][3][2], 1.0 / 2)
|
||||
self.assertEqual(model3.distortion_table[4][2][2][4], 1.0 / 4)
|
||||
|
||||
def test_set_uniform_distortion_probabilities_of_non_domain_values(self):
|
||||
# arrange
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model3 = IBMModel3(corpus, 0)
|
||||
|
||||
# act
|
||||
model3.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# examine i and j values that are not in the training data domain
|
||||
self.assertEqual(model3.distortion_table[0][0][3][2], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model3.distortion_table[9][2][2][4], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model3.distortion_table[2][9][2][4], IBMModel.MIN_PROB)
|
||||
|
||||
def test_prob_t_a_given_s(self):
|
||||
# arrange
|
||||
src_sentence = ["ich", "esse", "ja", "gern", "räucherschinken"]
|
||||
trg_sentence = ["i", "love", "to", "eat", "smoked", "ham"]
|
||||
corpus = [AlignedSent(trg_sentence, src_sentence)]
|
||||
alignment_info = AlignmentInfo(
|
||||
(0, 1, 4, 0, 2, 5, 5),
|
||||
[None] + src_sentence,
|
||||
["UNUSED"] + trg_sentence,
|
||||
[[3], [1], [4], [], [2], [5, 6]],
|
||||
)
|
||||
|
||||
distortion_table = defaultdict(
|
||||
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
|
||||
)
|
||||
distortion_table[1][1][5][6] = 0.97 # i -> ich
|
||||
distortion_table[2][4][5][6] = 0.97 # love -> gern
|
||||
distortion_table[3][0][5][6] = 0.97 # to -> NULL
|
||||
distortion_table[4][2][5][6] = 0.97 # eat -> esse
|
||||
distortion_table[5][5][5][6] = 0.97 # smoked -> räucherschinken
|
||||
distortion_table[6][5][5][6] = 0.97 # ham -> räucherschinken
|
||||
|
||||
translation_table = defaultdict(lambda: defaultdict(float))
|
||||
translation_table["i"]["ich"] = 0.98
|
||||
translation_table["love"]["gern"] = 0.98
|
||||
translation_table["to"][None] = 0.98
|
||||
translation_table["eat"]["esse"] = 0.98
|
||||
translation_table["smoked"]["räucherschinken"] = 0.98
|
||||
translation_table["ham"]["räucherschinken"] = 0.98
|
||||
|
||||
fertility_table = defaultdict(lambda: defaultdict(float))
|
||||
fertility_table[1]["ich"] = 0.99
|
||||
fertility_table[1]["esse"] = 0.99
|
||||
fertility_table[0]["ja"] = 0.99
|
||||
fertility_table[1]["gern"] = 0.99
|
||||
fertility_table[2]["räucherschinken"] = 0.999
|
||||
fertility_table[1][None] = 0.99
|
||||
|
||||
probabilities = {
|
||||
"p1": 0.167,
|
||||
"translation_table": translation_table,
|
||||
"distortion_table": distortion_table,
|
||||
"fertility_table": fertility_table,
|
||||
"alignment_table": None,
|
||||
}
|
||||
|
||||
model3 = IBMModel3(corpus, 0, probabilities)
|
||||
|
||||
# act
|
||||
probability = model3.prob_t_a_given_s(alignment_info)
|
||||
|
||||
# assert
|
||||
null_generation = 5 * pow(0.167, 1) * pow(0.833, 4)
|
||||
fertility = 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 2 * 0.999
|
||||
lexical_translation = 0.98 * 0.98 * 0.98 * 0.98 * 0.98 * 0.98
|
||||
distortion = 0.97 * 0.97 * 0.97 * 0.97 * 0.97 * 0.97
|
||||
expected_probability = (
|
||||
null_generation * fertility * lexical_translation * distortion
|
||||
)
|
||||
self.assertEqual(round(probability, 4), round(expected_probability, 4))
|
||||
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Tests for IBM Model 4 training methods
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
|
||||
from nltk.translate import AlignedSent, IBMModel, IBMModel4
|
||||
from nltk.translate.ibm_model import AlignmentInfo
|
||||
|
||||
|
||||
class TestIBMModel4(unittest.TestCase):
|
||||
def test_set_uniform_distortion_probabilities_of_max_displacements(self):
|
||||
# arrange
|
||||
src_classes = {"schinken": 0, "eier": 0, "spam": 1}
|
||||
trg_classes = {"ham": 0, "eggs": 1, "spam": 2}
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model4 = IBMModel4(corpus, 0, src_classes, trg_classes)
|
||||
|
||||
# act
|
||||
model4.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# number of displacement values =
|
||||
# 2 *(number of words in longest target sentence - 1)
|
||||
expected_prob = 1.0 / (2 * (4 - 1))
|
||||
|
||||
# examine the boundary values for (displacement, src_class, trg_class)
|
||||
self.assertEqual(model4.head_distortion_table[3][0][0], expected_prob)
|
||||
self.assertEqual(model4.head_distortion_table[-3][1][2], expected_prob)
|
||||
self.assertEqual(model4.non_head_distortion_table[3][0], expected_prob)
|
||||
self.assertEqual(model4.non_head_distortion_table[-3][2], expected_prob)
|
||||
|
||||
def test_set_uniform_distortion_probabilities_of_non_domain_values(self):
|
||||
# arrange
|
||||
src_classes = {"schinken": 0, "eier": 0, "spam": 1}
|
||||
trg_classes = {"ham": 0, "eggs": 1, "spam": 2}
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model4 = IBMModel4(corpus, 0, src_classes, trg_classes)
|
||||
|
||||
# act
|
||||
model4.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# examine displacement values that are not in the training data domain
|
||||
self.assertEqual(model4.head_distortion_table[4][0][0], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model4.head_distortion_table[100][1][2], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model4.non_head_distortion_table[4][0], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model4.non_head_distortion_table[100][2], IBMModel.MIN_PROB)
|
||||
|
||||
def test_prob_t_a_given_s(self):
|
||||
# arrange
|
||||
src_sentence = ["ich", "esse", "ja", "gern", "räucherschinken"]
|
||||
trg_sentence = ["i", "love", "to", "eat", "smoked", "ham"]
|
||||
src_classes = {"räucherschinken": 0, "ja": 1, "ich": 2, "esse": 3, "gern": 4}
|
||||
trg_classes = {"ham": 0, "smoked": 1, "i": 3, "love": 4, "to": 2, "eat": 4}
|
||||
corpus = [AlignedSent(trg_sentence, src_sentence)]
|
||||
alignment_info = AlignmentInfo(
|
||||
(0, 1, 4, 0, 2, 5, 5),
|
||||
[None] + src_sentence,
|
||||
["UNUSED"] + trg_sentence,
|
||||
[[3], [1], [4], [], [2], [5, 6]],
|
||||
)
|
||||
|
||||
head_distortion_table = defaultdict(
|
||||
lambda: defaultdict(lambda: defaultdict(float))
|
||||
)
|
||||
head_distortion_table[1][None][3] = 0.97 # None, i
|
||||
head_distortion_table[3][2][4] = 0.97 # ich, eat
|
||||
head_distortion_table[-2][3][4] = 0.97 # esse, love
|
||||
head_distortion_table[3][4][1] = 0.97 # gern, smoked
|
||||
|
||||
non_head_distortion_table = defaultdict(lambda: defaultdict(float))
|
||||
non_head_distortion_table[1][0] = 0.96 # ham
|
||||
|
||||
translation_table = defaultdict(lambda: defaultdict(float))
|
||||
translation_table["i"]["ich"] = 0.98
|
||||
translation_table["love"]["gern"] = 0.98
|
||||
translation_table["to"][None] = 0.98
|
||||
translation_table["eat"]["esse"] = 0.98
|
||||
translation_table["smoked"]["räucherschinken"] = 0.98
|
||||
translation_table["ham"]["räucherschinken"] = 0.98
|
||||
|
||||
fertility_table = defaultdict(lambda: defaultdict(float))
|
||||
fertility_table[1]["ich"] = 0.99
|
||||
fertility_table[1]["esse"] = 0.99
|
||||
fertility_table[0]["ja"] = 0.99
|
||||
fertility_table[1]["gern"] = 0.99
|
||||
fertility_table[2]["räucherschinken"] = 0.999
|
||||
fertility_table[1][None] = 0.99
|
||||
|
||||
probabilities = {
|
||||
"p1": 0.167,
|
||||
"translation_table": translation_table,
|
||||
"head_distortion_table": head_distortion_table,
|
||||
"non_head_distortion_table": non_head_distortion_table,
|
||||
"fertility_table": fertility_table,
|
||||
"alignment_table": None,
|
||||
}
|
||||
|
||||
model4 = IBMModel4(corpus, 0, src_classes, trg_classes, probabilities)
|
||||
|
||||
# act
|
||||
probability = model4.prob_t_a_given_s(alignment_info)
|
||||
|
||||
# assert
|
||||
null_generation = 5 * pow(0.167, 1) * pow(0.833, 4)
|
||||
fertility = 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 2 * 0.999
|
||||
lexical_translation = 0.98 * 0.98 * 0.98 * 0.98 * 0.98 * 0.98
|
||||
distortion = 0.97 * 0.97 * 1 * 0.97 * 0.97 * 0.96
|
||||
expected_probability = (
|
||||
null_generation * fertility * lexical_translation * distortion
|
||||
)
|
||||
self.assertEqual(round(probability, 4), round(expected_probability, 4))
|
||||
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Tests for IBM Model 5 training methods
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
|
||||
from nltk.translate import AlignedSent, IBMModel, IBMModel4, IBMModel5
|
||||
from nltk.translate.ibm_model import AlignmentInfo
|
||||
|
||||
|
||||
class TestIBMModel5(unittest.TestCase):
|
||||
def test_set_uniform_vacancy_probabilities_of_max_displacements(self):
|
||||
# arrange
|
||||
src_classes = {"schinken": 0, "eier": 0, "spam": 1}
|
||||
trg_classes = {"ham": 0, "eggs": 1, "spam": 2}
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model5 = IBMModel5(corpus, 0, src_classes, trg_classes)
|
||||
|
||||
# act
|
||||
model5.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# number of vacancy difference values =
|
||||
# 2 * number of words in longest target sentence
|
||||
expected_prob = 1.0 / (2 * 4)
|
||||
|
||||
# examine the boundary values for (dv, max_v, trg_class)
|
||||
self.assertEqual(model5.head_vacancy_table[4][4][0], expected_prob)
|
||||
self.assertEqual(model5.head_vacancy_table[-3][1][2], expected_prob)
|
||||
self.assertEqual(model5.non_head_vacancy_table[4][4][0], expected_prob)
|
||||
self.assertEqual(model5.non_head_vacancy_table[-3][1][2], expected_prob)
|
||||
|
||||
def test_set_uniform_vacancy_probabilities_of_non_domain_values(self):
|
||||
# arrange
|
||||
src_classes = {"schinken": 0, "eier": 0, "spam": 1}
|
||||
trg_classes = {"ham": 0, "eggs": 1, "spam": 2}
|
||||
corpus = [
|
||||
AlignedSent(["ham", "eggs"], ["schinken", "schinken", "eier"]),
|
||||
AlignedSent(["spam", "spam", "spam", "spam"], ["spam", "spam"]),
|
||||
]
|
||||
model5 = IBMModel5(corpus, 0, src_classes, trg_classes)
|
||||
|
||||
# act
|
||||
model5.set_uniform_probabilities(corpus)
|
||||
|
||||
# assert
|
||||
# examine dv and max_v values that are not in the training data domain
|
||||
self.assertEqual(model5.head_vacancy_table[5][4][0], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model5.head_vacancy_table[-4][1][2], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model5.head_vacancy_table[4][0][0], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model5.non_head_vacancy_table[5][4][0], IBMModel.MIN_PROB)
|
||||
self.assertEqual(model5.non_head_vacancy_table[-4][1][2], IBMModel.MIN_PROB)
|
||||
|
||||
def test_prob_t_a_given_s(self):
|
||||
# arrange
|
||||
src_sentence = ["ich", "esse", "ja", "gern", "räucherschinken"]
|
||||
trg_sentence = ["i", "love", "to", "eat", "smoked", "ham"]
|
||||
src_classes = {"räucherschinken": 0, "ja": 1, "ich": 2, "esse": 3, "gern": 4}
|
||||
trg_classes = {"ham": 0, "smoked": 1, "i": 3, "love": 4, "to": 2, "eat": 4}
|
||||
corpus = [AlignedSent(trg_sentence, src_sentence)]
|
||||
alignment_info = AlignmentInfo(
|
||||
(0, 1, 4, 0, 2, 5, 5),
|
||||
[None] + src_sentence,
|
||||
["UNUSED"] + trg_sentence,
|
||||
[[3], [1], [4], [], [2], [5, 6]],
|
||||
)
|
||||
|
||||
head_vacancy_table = defaultdict(
|
||||
lambda: defaultdict(lambda: defaultdict(float))
|
||||
)
|
||||
head_vacancy_table[1 - 0][6][3] = 0.97 # ich -> i
|
||||
head_vacancy_table[3 - 0][5][4] = 0.97 # esse -> eat
|
||||
head_vacancy_table[1 - 2][4][4] = 0.97 # gern -> love
|
||||
head_vacancy_table[2 - 0][2][1] = 0.97 # räucherschinken -> smoked
|
||||
|
||||
non_head_vacancy_table = defaultdict(
|
||||
lambda: defaultdict(lambda: defaultdict(float))
|
||||
)
|
||||
non_head_vacancy_table[1 - 0][1][0] = 0.96 # räucherschinken -> ham
|
||||
|
||||
translation_table = defaultdict(lambda: defaultdict(float))
|
||||
translation_table["i"]["ich"] = 0.98
|
||||
translation_table["love"]["gern"] = 0.98
|
||||
translation_table["to"][None] = 0.98
|
||||
translation_table["eat"]["esse"] = 0.98
|
||||
translation_table["smoked"]["räucherschinken"] = 0.98
|
||||
translation_table["ham"]["räucherschinken"] = 0.98
|
||||
|
||||
fertility_table = defaultdict(lambda: defaultdict(float))
|
||||
fertility_table[1]["ich"] = 0.99
|
||||
fertility_table[1]["esse"] = 0.99
|
||||
fertility_table[0]["ja"] = 0.99
|
||||
fertility_table[1]["gern"] = 0.99
|
||||
fertility_table[2]["räucherschinken"] = 0.999
|
||||
fertility_table[1][None] = 0.99
|
||||
|
||||
probabilities = {
|
||||
"p1": 0.167,
|
||||
"translation_table": translation_table,
|
||||
"fertility_table": fertility_table,
|
||||
"head_vacancy_table": head_vacancy_table,
|
||||
"non_head_vacancy_table": non_head_vacancy_table,
|
||||
"head_distortion_table": None,
|
||||
"non_head_distortion_table": None,
|
||||
"alignment_table": None,
|
||||
}
|
||||
|
||||
model5 = IBMModel5(corpus, 0, src_classes, trg_classes, probabilities)
|
||||
|
||||
# act
|
||||
probability = model5.prob_t_a_given_s(alignment_info)
|
||||
|
||||
# assert
|
||||
null_generation = 5 * pow(0.167, 1) * pow(0.833, 4)
|
||||
fertility = 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 2 * 0.999
|
||||
lexical_translation = 0.98 * 0.98 * 0.98 * 0.98 * 0.98 * 0.98
|
||||
vacancy = 0.97 * 0.97 * 1 * 0.97 * 0.97 * 0.96
|
||||
expected_probability = (
|
||||
null_generation * fertility * lexical_translation * vacancy
|
||||
)
|
||||
self.assertEqual(round(probability, 4), round(expected_probability, 4))
|
||||
|
||||
def test_prune(self):
|
||||
# arrange
|
||||
alignment_infos = [
|
||||
AlignmentInfo((1, 1), None, None, None),
|
||||
AlignmentInfo((1, 2), None, None, None),
|
||||
AlignmentInfo((2, 1), None, None, None),
|
||||
AlignmentInfo((2, 2), None, None, None),
|
||||
AlignmentInfo((0, 0), None, None, None),
|
||||
]
|
||||
min_factor = IBMModel5.MIN_SCORE_FACTOR
|
||||
best_score = 0.9
|
||||
scores = {
|
||||
(1, 1): min(min_factor * 1.5, 1) * best_score, # above threshold
|
||||
(1, 2): best_score,
|
||||
(2, 1): min_factor * best_score, # at threshold
|
||||
(2, 2): min_factor * best_score * 0.5, # low score
|
||||
(0, 0): min(min_factor * 1.1, 1) * 1.2, # above threshold
|
||||
}
|
||||
corpus = [AlignedSent(["a"], ["b"])]
|
||||
original_prob_function = IBMModel4.model4_prob_t_a_given_s
|
||||
# mock static method
|
||||
IBMModel4.model4_prob_t_a_given_s = staticmethod(
|
||||
lambda a, model: scores[a.alignment]
|
||||
)
|
||||
model5 = IBMModel5(corpus, 0, None, None)
|
||||
|
||||
# act
|
||||
pruned_alignments = model5.prune(alignment_infos)
|
||||
|
||||
# assert
|
||||
self.assertEqual(len(pruned_alignments), 3)
|
||||
|
||||
# restore static method
|
||||
IBMModel4.model4_prob_t_a_given_s = original_prob_function
|
||||
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Tests for common methods of IBM translation models
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
|
||||
from nltk.translate import AlignedSent, IBMModel
|
||||
from nltk.translate.ibm_model import AlignmentInfo
|
||||
|
||||
|
||||
class TestIBMModel(unittest.TestCase):
|
||||
__TEST_SRC_SENTENCE = ["j'", "aime", "bien", "jambon"]
|
||||
__TEST_TRG_SENTENCE = ["i", "love", "ham"]
|
||||
|
||||
def test_vocabularies_are_initialized(self):
|
||||
parallel_corpora = [
|
||||
AlignedSent(["one", "two", "three", "four"], ["un", "deux", "trois"]),
|
||||
AlignedSent(["five", "one", "six"], ["quatre", "cinq", "six"]),
|
||||
AlignedSent([], ["sept"]),
|
||||
]
|
||||
|
||||
ibm_model = IBMModel(parallel_corpora)
|
||||
self.assertEqual(len(ibm_model.src_vocab), 8)
|
||||
self.assertEqual(len(ibm_model.trg_vocab), 6)
|
||||
|
||||
def test_vocabularies_are_initialized_even_with_empty_corpora(self):
|
||||
parallel_corpora = []
|
||||
|
||||
ibm_model = IBMModel(parallel_corpora)
|
||||
self.assertEqual(len(ibm_model.src_vocab), 1) # addition of NULL token
|
||||
self.assertEqual(len(ibm_model.trg_vocab), 0)
|
||||
|
||||
def test_best_model2_alignment(self):
|
||||
# arrange
|
||||
sentence_pair = AlignedSent(
|
||||
TestIBMModel.__TEST_TRG_SENTENCE, TestIBMModel.__TEST_SRC_SENTENCE
|
||||
)
|
||||
# None and 'bien' have zero fertility
|
||||
translation_table = {
|
||||
"i": {"j'": 0.9, "aime": 0.05, "bien": 0.02, "jambon": 0.03, None: 0},
|
||||
"love": {"j'": 0.05, "aime": 0.9, "bien": 0.01, "jambon": 0.01, None: 0.03},
|
||||
"ham": {"j'": 0, "aime": 0.01, "bien": 0, "jambon": 0.99, None: 0},
|
||||
}
|
||||
alignment_table = defaultdict(
|
||||
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0.2)))
|
||||
)
|
||||
|
||||
ibm_model = IBMModel([])
|
||||
ibm_model.translation_table = translation_table
|
||||
ibm_model.alignment_table = alignment_table
|
||||
|
||||
# act
|
||||
a_info = ibm_model.best_model2_alignment(sentence_pair)
|
||||
|
||||
# assert
|
||||
self.assertEqual(a_info.alignment[1:], (1, 2, 4)) # 0th element unused
|
||||
self.assertEqual(a_info.cepts, [[], [1], [2], [], [3]])
|
||||
|
||||
def test_best_model2_alignment_does_not_change_pegged_alignment(self):
|
||||
# arrange
|
||||
sentence_pair = AlignedSent(
|
||||
TestIBMModel.__TEST_TRG_SENTENCE, TestIBMModel.__TEST_SRC_SENTENCE
|
||||
)
|
||||
translation_table = {
|
||||
"i": {"j'": 0.9, "aime": 0.05, "bien": 0.02, "jambon": 0.03, None: 0},
|
||||
"love": {"j'": 0.05, "aime": 0.9, "bien": 0.01, "jambon": 0.01, None: 0.03},
|
||||
"ham": {"j'": 0, "aime": 0.01, "bien": 0, "jambon": 0.99, None: 0},
|
||||
}
|
||||
alignment_table = defaultdict(
|
||||
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0.2)))
|
||||
)
|
||||
|
||||
ibm_model = IBMModel([])
|
||||
ibm_model.translation_table = translation_table
|
||||
ibm_model.alignment_table = alignment_table
|
||||
|
||||
# act: force 'love' to be pegged to 'jambon'
|
||||
a_info = ibm_model.best_model2_alignment(sentence_pair, 2, 4)
|
||||
# assert
|
||||
self.assertEqual(a_info.alignment[1:], (1, 4, 4))
|
||||
self.assertEqual(a_info.cepts, [[], [1], [], [], [2, 3]])
|
||||
|
||||
def test_best_model2_alignment_handles_fertile_words(self):
|
||||
# arrange
|
||||
sentence_pair = AlignedSent(
|
||||
["i", "really", ",", "really", "love", "ham"],
|
||||
TestIBMModel.__TEST_SRC_SENTENCE,
|
||||
)
|
||||
# 'bien' produces 2 target words: 'really' and another 'really'
|
||||
translation_table = {
|
||||
"i": {"j'": 0.9, "aime": 0.05, "bien": 0.02, "jambon": 0.03, None: 0},
|
||||
"really": {"j'": 0, "aime": 0, "bien": 0.9, "jambon": 0.01, None: 0.09},
|
||||
",": {"j'": 0, "aime": 0, "bien": 0.3, "jambon": 0, None: 0.7},
|
||||
"love": {"j'": 0.05, "aime": 0.9, "bien": 0.01, "jambon": 0.01, None: 0.03},
|
||||
"ham": {"j'": 0, "aime": 0.01, "bien": 0, "jambon": 0.99, None: 0},
|
||||
}
|
||||
alignment_table = defaultdict(
|
||||
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0.2)))
|
||||
)
|
||||
|
||||
ibm_model = IBMModel([])
|
||||
ibm_model.translation_table = translation_table
|
||||
ibm_model.alignment_table = alignment_table
|
||||
|
||||
# act
|
||||
a_info = ibm_model.best_model2_alignment(sentence_pair)
|
||||
|
||||
# assert
|
||||
self.assertEqual(a_info.alignment[1:], (1, 3, 0, 3, 2, 4))
|
||||
self.assertEqual(a_info.cepts, [[3], [1], [5], [2, 4], [6]])
|
||||
|
||||
def test_best_model2_alignment_handles_empty_src_sentence(self):
|
||||
# arrange
|
||||
sentence_pair = AlignedSent(TestIBMModel.__TEST_TRG_SENTENCE, [])
|
||||
ibm_model = IBMModel([])
|
||||
|
||||
# act
|
||||
a_info = ibm_model.best_model2_alignment(sentence_pair)
|
||||
|
||||
# assert
|
||||
self.assertEqual(a_info.alignment[1:], (0, 0, 0))
|
||||
self.assertEqual(a_info.cepts, [[1, 2, 3]])
|
||||
|
||||
def test_best_model2_alignment_handles_empty_trg_sentence(self):
|
||||
# arrange
|
||||
sentence_pair = AlignedSent([], TestIBMModel.__TEST_SRC_SENTENCE)
|
||||
ibm_model = IBMModel([])
|
||||
|
||||
# act
|
||||
a_info = ibm_model.best_model2_alignment(sentence_pair)
|
||||
|
||||
# assert
|
||||
self.assertEqual(a_info.alignment[1:], ())
|
||||
self.assertEqual(a_info.cepts, [[], [], [], [], []])
|
||||
|
||||
def test_neighboring_finds_neighbor_alignments(self):
|
||||
# arrange
|
||||
a_info = AlignmentInfo(
|
||||
(0, 3, 2),
|
||||
(None, "des", "œufs", "verts"),
|
||||
("UNUSED", "green", "eggs"),
|
||||
[[], [], [2], [1]],
|
||||
)
|
||||
ibm_model = IBMModel([])
|
||||
|
||||
# act
|
||||
neighbors = ibm_model.neighboring(a_info)
|
||||
|
||||
# assert
|
||||
neighbor_alignments = set()
|
||||
for neighbor in neighbors:
|
||||
neighbor_alignments.add(neighbor.alignment)
|
||||
expected_alignments = {
|
||||
# moves
|
||||
(0, 0, 2),
|
||||
(0, 1, 2),
|
||||
(0, 2, 2),
|
||||
(0, 3, 0),
|
||||
(0, 3, 1),
|
||||
(0, 3, 3),
|
||||
# swaps
|
||||
(0, 2, 3),
|
||||
# original alignment
|
||||
(0, 3, 2),
|
||||
}
|
||||
self.assertEqual(neighbor_alignments, expected_alignments)
|
||||
|
||||
def test_neighboring_sets_neighbor_alignment_info(self):
|
||||
# arrange
|
||||
a_info = AlignmentInfo(
|
||||
(0, 3, 2),
|
||||
(None, "des", "œufs", "verts"),
|
||||
("UNUSED", "green", "eggs"),
|
||||
[[], [], [2], [1]],
|
||||
)
|
||||
ibm_model = IBMModel([])
|
||||
|
||||
# act
|
||||
neighbors = ibm_model.neighboring(a_info)
|
||||
|
||||
# assert: select a few particular alignments
|
||||
for neighbor in neighbors:
|
||||
if neighbor.alignment == (0, 2, 2):
|
||||
moved_alignment = neighbor
|
||||
elif neighbor.alignment == (0, 3, 2):
|
||||
swapped_alignment = neighbor
|
||||
|
||||
self.assertEqual(moved_alignment.cepts, [[], [], [1, 2], []])
|
||||
self.assertEqual(swapped_alignment.cepts, [[], [], [2], [1]])
|
||||
|
||||
def test_neighboring_returns_neighbors_with_pegged_alignment(self):
|
||||
# arrange
|
||||
a_info = AlignmentInfo(
|
||||
(0, 3, 2),
|
||||
(None, "des", "œufs", "verts"),
|
||||
("UNUSED", "green", "eggs"),
|
||||
[[], [], [2], [1]],
|
||||
)
|
||||
ibm_model = IBMModel([])
|
||||
|
||||
# act: peg 'eggs' to align with 'œufs'
|
||||
neighbors = ibm_model.neighboring(a_info, 2)
|
||||
|
||||
# assert
|
||||
neighbor_alignments = set()
|
||||
for neighbor in neighbors:
|
||||
neighbor_alignments.add(neighbor.alignment)
|
||||
expected_alignments = {
|
||||
# moves
|
||||
(0, 0, 2),
|
||||
(0, 1, 2),
|
||||
(0, 2, 2),
|
||||
# no swaps
|
||||
# original alignment
|
||||
(0, 3, 2),
|
||||
}
|
||||
self.assertEqual(neighbor_alignments, expected_alignments)
|
||||
|
||||
def test_hillclimb(self):
|
||||
# arrange
|
||||
initial_alignment = AlignmentInfo((0, 3, 2), None, None, None)
|
||||
|
||||
def neighboring_mock(a, j):
|
||||
if a.alignment == (0, 3, 2):
|
||||
return {
|
||||
AlignmentInfo((0, 2, 2), None, None, None),
|
||||
AlignmentInfo((0, 1, 1), None, None, None),
|
||||
}
|
||||
elif a.alignment == (0, 2, 2):
|
||||
return {
|
||||
AlignmentInfo((0, 3, 3), None, None, None),
|
||||
AlignmentInfo((0, 4, 4), None, None, None),
|
||||
}
|
||||
return set()
|
||||
|
||||
def prob_t_a_given_s_mock(a):
|
||||
prob_values = {
|
||||
(0, 3, 2): 0.5,
|
||||
(0, 2, 2): 0.6,
|
||||
(0, 1, 1): 0.4,
|
||||
(0, 3, 3): 0.6,
|
||||
(0, 4, 4): 0.7,
|
||||
}
|
||||
return prob_values.get(a.alignment, 0.01)
|
||||
|
||||
ibm_model = IBMModel([])
|
||||
ibm_model.neighboring = neighboring_mock
|
||||
ibm_model.prob_t_a_given_s = prob_t_a_given_s_mock
|
||||
|
||||
# act
|
||||
best_alignment = ibm_model.hillclimb(initial_alignment)
|
||||
|
||||
# assert: hill climbing goes from (0, 3, 2) -> (0, 2, 2) -> (0, 4, 4)
|
||||
self.assertEqual(best_alignment.alignment, (0, 4, 4))
|
||||
|
||||
def test_sample(self):
|
||||
# arrange
|
||||
sentence_pair = AlignedSent(
|
||||
TestIBMModel.__TEST_TRG_SENTENCE, TestIBMModel.__TEST_SRC_SENTENCE
|
||||
)
|
||||
ibm_model = IBMModel([])
|
||||
ibm_model.prob_t_a_given_s = lambda x: 0.001
|
||||
|
||||
# act
|
||||
samples, best_alignment = ibm_model.sample(sentence_pair)
|
||||
|
||||
# assert
|
||||
self.assertEqual(len(samples), 61)
|
||||
@@ -0,0 +1,20 @@
|
||||
import unittest
|
||||
|
||||
from nltk.translate.meteor_score import meteor_score
|
||||
|
||||
|
||||
class TestMETEOR(unittest.TestCase):
|
||||
reference = [["this", "is", "a", "test"], ["this", "is" "test"]]
|
||||
candidate = ["THIS", "Is", "a", "tEST"]
|
||||
|
||||
def test_meteor(self):
|
||||
score = meteor_score(self.reference, self.candidate, preprocess=str.lower)
|
||||
assert score == 0.9921875
|
||||
|
||||
def test_reference_type_check(self):
|
||||
str_reference = [" ".join(ref) for ref in self.reference]
|
||||
self.assertRaises(TypeError, meteor_score, str_reference, self.candidate)
|
||||
|
||||
def test_candidate_type_check(self):
|
||||
str_candidate = " ".join(self.candidate)
|
||||
self.assertRaises(TypeError, meteor_score, self.reference, str_candidate)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Tests for NIST translation evaluation metric
|
||||
"""
|
||||
|
||||
import io
|
||||
import unittest
|
||||
|
||||
from nltk.data import find
|
||||
from nltk.translate.nist_score import corpus_nist
|
||||
|
||||
|
||||
class TestNIST(unittest.TestCase):
|
||||
def test_sentence_nist(self):
|
||||
ref_file = find("models/wmt15_eval/ref.ru")
|
||||
hyp_file = find("models/wmt15_eval/google.ru")
|
||||
mteval_output_file = find("models/wmt15_eval/mteval-13a.output")
|
||||
|
||||
# Reads the NIST scores from the `mteval-13a.output` file.
|
||||
# The order of the list corresponds to the order of the ngrams.
|
||||
with open(mteval_output_file) as mteval_fin:
|
||||
# The numbers are located in the last 4th line of the file.
|
||||
# The first and 2nd item in the list are the score and system names.
|
||||
mteval_nist_scores = map(float, mteval_fin.readlines()[-4].split()[1:-1])
|
||||
|
||||
with open(ref_file, encoding="utf8") as ref_fin:
|
||||
with open(hyp_file, encoding="utf8") as hyp_fin:
|
||||
# Whitespace tokenize the file.
|
||||
# Note: split() automatically strip().
|
||||
hypotheses = list(map(lambda x: x.split(), hyp_fin))
|
||||
# Note that the corpus_bleu input is list of list of references.
|
||||
references = list(map(lambda x: [x.split()], ref_fin))
|
||||
# Without smoothing.
|
||||
for i, mteval_nist in zip(range(1, 10), mteval_nist_scores):
|
||||
nltk_nist = corpus_nist(references, hypotheses, i)
|
||||
# Check that the NIST scores difference is less than 0.5
|
||||
assert abs(mteval_nist - nltk_nist) < 0.05
|
||||
@@ -0,0 +1,294 @@
|
||||
# Natural Language Toolkit: Stack decoder
|
||||
#
|
||||
# Copyright (C) 2001-2025 NLTK Project
|
||||
# Author: Tah Wei Hoon <hoon.tw@gmail.com>
|
||||
# URL: <https://www.nltk.org/>
|
||||
# For license information, see LICENSE.TXT
|
||||
|
||||
"""
|
||||
Tests for stack decoder
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from math import log
|
||||
|
||||
from nltk.translate import PhraseTable, StackDecoder
|
||||
from nltk.translate.stack_decoder import _Hypothesis, _Stack
|
||||
|
||||
|
||||
class TestStackDecoder(unittest.TestCase):
|
||||
def test_find_all_src_phrases(self):
|
||||
# arrange
|
||||
phrase_table = TestStackDecoder.create_fake_phrase_table()
|
||||
stack_decoder = StackDecoder(phrase_table, None)
|
||||
sentence = ("my", "hovercraft", "is", "full", "of", "eels")
|
||||
|
||||
# act
|
||||
src_phrase_spans = stack_decoder.find_all_src_phrases(sentence)
|
||||
|
||||
# assert
|
||||
self.assertEqual(src_phrase_spans[0], [2]) # 'my hovercraft'
|
||||
self.assertEqual(src_phrase_spans[1], [2]) # 'hovercraft'
|
||||
self.assertEqual(src_phrase_spans[2], [3]) # 'is'
|
||||
self.assertEqual(src_phrase_spans[3], [5, 6]) # 'full of', 'full of eels'
|
||||
self.assertFalse(src_phrase_spans[4]) # no entry starting with 'of'
|
||||
self.assertEqual(src_phrase_spans[5], [6]) # 'eels'
|
||||
|
||||
def test_distortion_score(self):
|
||||
# arrange
|
||||
stack_decoder = StackDecoder(None, None)
|
||||
stack_decoder.distortion_factor = 0.5
|
||||
hypothesis = _Hypothesis()
|
||||
hypothesis.src_phrase_span = (3, 5)
|
||||
|
||||
# act
|
||||
score = stack_decoder.distortion_score(hypothesis, (8, 10))
|
||||
|
||||
# assert
|
||||
expected_score = log(stack_decoder.distortion_factor) * (8 - 5)
|
||||
self.assertEqual(score, expected_score)
|
||||
|
||||
def test_distortion_score_of_first_expansion(self):
|
||||
# arrange
|
||||
stack_decoder = StackDecoder(None, None)
|
||||
stack_decoder.distortion_factor = 0.5
|
||||
hypothesis = _Hypothesis()
|
||||
|
||||
# act
|
||||
score = stack_decoder.distortion_score(hypothesis, (8, 10))
|
||||
|
||||
# assert
|
||||
# expansion from empty hypothesis always has zero distortion cost
|
||||
self.assertEqual(score, 0.0)
|
||||
|
||||
def test_compute_future_costs(self):
|
||||
# arrange
|
||||
phrase_table = TestStackDecoder.create_fake_phrase_table()
|
||||
language_model = TestStackDecoder.create_fake_language_model()
|
||||
stack_decoder = StackDecoder(phrase_table, language_model)
|
||||
sentence = ("my", "hovercraft", "is", "full", "of", "eels")
|
||||
|
||||
# act
|
||||
future_scores = stack_decoder.compute_future_scores(sentence)
|
||||
|
||||
# assert
|
||||
self.assertEqual(
|
||||
future_scores[1][2],
|
||||
(
|
||||
phrase_table.translations_for(("hovercraft",))[0].log_prob
|
||||
+ language_model.probability(("hovercraft",))
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
future_scores[0][2],
|
||||
(
|
||||
phrase_table.translations_for(("my", "hovercraft"))[0].log_prob
|
||||
+ language_model.probability(("my", "hovercraft"))
|
||||
),
|
||||
)
|
||||
|
||||
def test_compute_future_costs_for_phrases_not_in_phrase_table(self):
|
||||
# arrange
|
||||
phrase_table = TestStackDecoder.create_fake_phrase_table()
|
||||
language_model = TestStackDecoder.create_fake_language_model()
|
||||
stack_decoder = StackDecoder(phrase_table, language_model)
|
||||
sentence = ("my", "hovercraft", "is", "full", "of", "eels")
|
||||
|
||||
# act
|
||||
future_scores = stack_decoder.compute_future_scores(sentence)
|
||||
|
||||
# assert
|
||||
self.assertEqual(
|
||||
future_scores[1][3], # 'hovercraft is' is not in phrase table
|
||||
future_scores[1][2] + future_scores[2][3],
|
||||
) # backoff
|
||||
|
||||
def test_future_score(self):
|
||||
# arrange: sentence with 8 words; words 2, 3, 4 already translated
|
||||
hypothesis = _Hypothesis()
|
||||
hypothesis.untranslated_spans = lambda _: [(0, 2), (5, 8)] # mock
|
||||
future_score_table = defaultdict(lambda: defaultdict(float))
|
||||
future_score_table[0][2] = 0.4
|
||||
future_score_table[5][8] = 0.5
|
||||
stack_decoder = StackDecoder(None, None)
|
||||
|
||||
# act
|
||||
future_score = stack_decoder.future_score(hypothesis, future_score_table, 8)
|
||||
|
||||
# assert
|
||||
self.assertEqual(future_score, 0.4 + 0.5)
|
||||
|
||||
def test_valid_phrases(self):
|
||||
# arrange
|
||||
hypothesis = _Hypothesis()
|
||||
# mock untranslated_spans method
|
||||
hypothesis.untranslated_spans = lambda _: [(0, 2), (3, 6)]
|
||||
all_phrases_from = [[1, 4], [2], [], [5], [5, 6, 7], [], [7]]
|
||||
|
||||
# act
|
||||
phrase_spans = StackDecoder.valid_phrases(all_phrases_from, hypothesis)
|
||||
|
||||
# assert
|
||||
self.assertEqual(phrase_spans, [(0, 1), (1, 2), (3, 5), (4, 5), (4, 6)])
|
||||
|
||||
@staticmethod
|
||||
def create_fake_phrase_table():
|
||||
phrase_table = PhraseTable()
|
||||
phrase_table.add(("hovercraft",), ("",), 0.8)
|
||||
phrase_table.add(("my", "hovercraft"), ("", ""), 0.7)
|
||||
phrase_table.add(("my", "cheese"), ("", ""), 0.7)
|
||||
phrase_table.add(("is",), ("",), 0.8)
|
||||
phrase_table.add(("is",), ("",), 0.5)
|
||||
phrase_table.add(("full", "of"), ("", ""), 0.01)
|
||||
phrase_table.add(("full", "of", "eels"), ("", "", ""), 0.5)
|
||||
phrase_table.add(("full", "of", "spam"), ("", ""), 0.5)
|
||||
phrase_table.add(("eels",), ("",), 0.5)
|
||||
phrase_table.add(("spam",), ("",), 0.5)
|
||||
return phrase_table
|
||||
|
||||
@staticmethod
|
||||
def create_fake_language_model():
|
||||
# nltk.model should be used here once it is implemented
|
||||
language_prob = defaultdict(lambda: -999.0)
|
||||
language_prob[("my",)] = log(0.1)
|
||||
language_prob[("hovercraft",)] = log(0.1)
|
||||
language_prob[("is",)] = log(0.1)
|
||||
language_prob[("full",)] = log(0.1)
|
||||
language_prob[("of",)] = log(0.1)
|
||||
language_prob[("eels",)] = log(0.1)
|
||||
language_prob[("my", "hovercraft")] = log(0.3)
|
||||
language_model = type(
|
||||
"", (object,), {"probability": lambda _, phrase: language_prob[phrase]}
|
||||
)()
|
||||
return language_model
|
||||
|
||||
|
||||
class TestHypothesis(unittest.TestCase):
|
||||
def setUp(self):
|
||||
root = _Hypothesis()
|
||||
child = _Hypothesis(
|
||||
raw_score=0.5,
|
||||
src_phrase_span=(3, 7),
|
||||
trg_phrase=("hello", "world"),
|
||||
previous=root,
|
||||
)
|
||||
grandchild = _Hypothesis(
|
||||
raw_score=0.4,
|
||||
src_phrase_span=(1, 2),
|
||||
trg_phrase=("and", "goodbye"),
|
||||
previous=child,
|
||||
)
|
||||
self.hypothesis_chain = grandchild
|
||||
|
||||
def test_translation_so_far(self):
|
||||
# act
|
||||
translation = self.hypothesis_chain.translation_so_far()
|
||||
|
||||
# assert
|
||||
self.assertEqual(translation, ["hello", "world", "and", "goodbye"])
|
||||
|
||||
def test_translation_so_far_for_empty_hypothesis(self):
|
||||
# arrange
|
||||
hypothesis = _Hypothesis()
|
||||
|
||||
# act
|
||||
translation = hypothesis.translation_so_far()
|
||||
|
||||
# assert
|
||||
self.assertEqual(translation, [])
|
||||
|
||||
def test_total_translated_words(self):
|
||||
# act
|
||||
total_translated_words = self.hypothesis_chain.total_translated_words()
|
||||
|
||||
# assert
|
||||
self.assertEqual(total_translated_words, 5)
|
||||
|
||||
def test_translated_positions(self):
|
||||
# act
|
||||
translated_positions = self.hypothesis_chain.translated_positions()
|
||||
|
||||
# assert
|
||||
translated_positions.sort()
|
||||
self.assertEqual(translated_positions, [1, 3, 4, 5, 6])
|
||||
|
||||
def test_untranslated_spans(self):
|
||||
# act
|
||||
untranslated_spans = self.hypothesis_chain.untranslated_spans(10)
|
||||
|
||||
# assert
|
||||
self.assertEqual(untranslated_spans, [(0, 1), (2, 3), (7, 10)])
|
||||
|
||||
def test_untranslated_spans_for_empty_hypothesis(self):
|
||||
# arrange
|
||||
hypothesis = _Hypothesis()
|
||||
|
||||
# act
|
||||
untranslated_spans = hypothesis.untranslated_spans(10)
|
||||
|
||||
# assert
|
||||
self.assertEqual(untranslated_spans, [(0, 10)])
|
||||
|
||||
|
||||
class TestStack(unittest.TestCase):
|
||||
def test_push_bumps_off_worst_hypothesis_when_stack_is_full(self):
|
||||
# arrange
|
||||
stack = _Stack(3)
|
||||
poor_hypothesis = _Hypothesis(0.01)
|
||||
|
||||
# act
|
||||
stack.push(_Hypothesis(0.2))
|
||||
stack.push(poor_hypothesis)
|
||||
stack.push(_Hypothesis(0.1))
|
||||
stack.push(_Hypothesis(0.3))
|
||||
|
||||
# assert
|
||||
self.assertFalse(poor_hypothesis in stack)
|
||||
|
||||
def test_push_removes_hypotheses_that_fall_below_beam_threshold(self):
|
||||
# arrange
|
||||
stack = _Stack(3, 0.5)
|
||||
poor_hypothesis = _Hypothesis(0.01)
|
||||
worse_hypothesis = _Hypothesis(0.009)
|
||||
|
||||
# act
|
||||
stack.push(poor_hypothesis)
|
||||
stack.push(worse_hypothesis)
|
||||
stack.push(_Hypothesis(0.9)) # greatly superior hypothesis
|
||||
|
||||
# assert
|
||||
self.assertFalse(poor_hypothesis in stack)
|
||||
self.assertFalse(worse_hypothesis in stack)
|
||||
|
||||
def test_push_does_not_add_hypothesis_that_falls_below_beam_threshold(self):
|
||||
# arrange
|
||||
stack = _Stack(3, 0.5)
|
||||
poor_hypothesis = _Hypothesis(0.01)
|
||||
|
||||
# act
|
||||
stack.push(_Hypothesis(0.9)) # greatly superior hypothesis
|
||||
stack.push(poor_hypothesis)
|
||||
|
||||
# assert
|
||||
self.assertFalse(poor_hypothesis in stack)
|
||||
|
||||
def test_best_returns_the_best_hypothesis(self):
|
||||
# arrange
|
||||
stack = _Stack(3)
|
||||
best_hypothesis = _Hypothesis(0.99)
|
||||
|
||||
# act
|
||||
stack.push(_Hypothesis(0.0))
|
||||
stack.push(best_hypothesis)
|
||||
stack.push(_Hypothesis(0.5))
|
||||
|
||||
# assert
|
||||
self.assertEqual(stack.best(), best_hypothesis)
|
||||
|
||||
def test_best_returns_none_when_stack_is_empty(self):
|
||||
# arrange
|
||||
stack = _Stack(3)
|
||||
|
||||
# assert
|
||||
self.assertEqual(stack.best(), None)
|
||||
Reference in New Issue
Block a user