Source code for alex.components.slu.test_da

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from copy import deepcopy
import unittest
from unittest import TestCase
import math

if __name__ == "__main__":
    import autopath
import __init__

from alex.components.slu.da import DialogueAct, DialogueActItem, DialogueActNBList, \
    DialogueActConfusionNetwork, DialogueActConfusionNetworkException, merge_slu_nblists, merge_slu_confnets
from alex.ml.hypothesis import ConfusionNetworkException


class TestDA(unittest.TestCase):
    """Tests for merging and normalising SLU n-best lists and confusion networks."""

    # On Python 2.7 unittest *replaces* the default failure message with ``msg``
    # unless longMessage is True; with it, the custom text is appended instead.
    longMessage = True

    @staticmethod
    def _side_by_side(title_a, obj_a, title_b, obj_b):
        """Return a failure message that prints two objects under their titles."""
        return "\n".join(["", title_a, unicode(obj_a), "", title_b, unicode(obj_b), ""])

    def test_swapping_merge_normalise(self):
        """merge() followed by normalise() must equal normalise() followed by merge()."""
        nblist1 = DialogueActNBList()
        nblist1.add(0.7, DialogueAct("hello()"))
        nblist1.add(0.2, DialogueAct("bye()"))
        nblist2 = deepcopy(nblist1)

        nblist1.merge().normalise()
        nblist2.normalise().merge()

        self.assertEqual(nblist1, nblist2,
                         self._side_by_side("Using merge().normalise():", nblist1,
                                            "Using normalise().merge():", nblist2))

    def test_merge_slu_nblists_full_nbest_lists(self):
        """merge_slu_nblists weights each n-best list by its outer probability."""
        nblist1 = DialogueActNBList()
        nblist1.add(0.7, DialogueAct("hello()"))
        nblist1.add(0.2, DialogueAct("bye()"))
        nblist1.merge().normalise()

        nblist2 = DialogueActNBList()
        nblist2.add(0.6, DialogueAct("hello()"))
        nblist2.add(0.3, DialogueAct("restart()"))
        nblist2.merge().normalise()

        nblists = [[0.7, nblist1], [0.3, nblist2]]
        merged_nblists = merge_slu_nblists(nblists)

        # Expected result: every hypothesis scaled by its list's weight. The
        # other() entries assume normalise() gives each list an other()
        # hypothesis carrying the missing mass (1 - 0.9 = 0.1 in both lists).
        correct_merged_nblists = DialogueActNBList()
        correct_merged_nblists.add(0.7 * 0.7, DialogueAct("hello()"))
        correct_merged_nblists.add(0.7 * 0.2, DialogueAct("bye()"))
        correct_merged_nblists.add(0.7 * 0.1, DialogueAct("other()"))
        correct_merged_nblists.add(0.3 * 0.6, DialogueAct("hello()"))
        correct_merged_nblists.add(0.3 * 0.3, DialogueAct("restart()"))
        correct_merged_nblists.add(0.3 * 0.1, DialogueAct("other()"))
        correct_merged_nblists.merge().normalise()

        self.assertEqual(unicode(merged_nblists), unicode(correct_merged_nblists),
                         self._side_by_side("Merged nblists:", merged_nblists,
                                            "Correct merged results:", correct_merged_nblists))

    def test_merge_slu_confnets(self):
        """merge_slu_confnets weights each confnet by its outer probability and sums shared items."""
        confnet1 = DialogueActConfusionNetwork()
        confnet1.add(0.7, DialogueActItem('hello'))
        confnet1.add(0.2, DialogueActItem('bye'))

        confnet2 = DialogueActConfusionNetwork()
        confnet2.add(0.6, DialogueActItem('hello'))
        confnet2.add(0.3, DialogueActItem('restart'))

        confnets = [[0.7, confnet1], [0.3, confnet2]]
        merged_confnets = merge_slu_confnets(confnets)

        # Expected result: weighted probabilities, with 'hello' from both
        # confnets combined additively into a single item.
        correct_merged_confnet = DialogueActConfusionNetwork()
        correct_merged_confnet.add_merge(0.7 * 0.7, DialogueActItem('hello'), combine='add')
        correct_merged_confnet.add_merge(0.7 * 0.2, DialogueActItem('bye'), combine='add')
        correct_merged_confnet.add_merge(0.3 * 0.6, DialogueActItem('hello'), combine='add')
        correct_merged_confnet.add_merge(0.3 * 0.3, DialogueActItem('restart'), combine='add')

        self.assertEqual(unicode(merged_confnets), unicode(correct_merged_confnet),
                         self._side_by_side("Merged confnets:", merged_confnets,
                                            "Correct merged results:", correct_merged_confnet))
class TestDialogueActConfusionNetwork(TestCase):
    """Tests for building, querying and post-processing DialogueActConfusionNetwork.

    The expected probabilities treat each item independently: an item added
    with probability p contributes p when included in a hypothesis and
    (1 - p) when excluded (see test_get_prob and test_get_best_da_hyp).
    """

    @staticmethod
    def _make_food_confnet(p_chinese, p_czech, p_russian):
        """Return a confnet with inform(food=chinese|czech|russian) added in that order.

        The insertion order matters: _get_prob() indices and sort() results
        below refer to item positions.
        """
        dacn = DialogueActConfusionNetwork()
        dacn.add(p_chinese, DialogueActItem(dai='inform(food=chinese)'))
        dacn.add(p_czech, DialogueActItem(dai='inform(food=czech)'))
        dacn.add(p_russian, DialogueActItem(dai='inform(food=russian)'))
        return dacn

    def test_add_merge(self):
        """add_merge(combine='add') sums the probability of a repeated item."""
        dai = DialogueActItem(dai='inform(food=chinese)')
        dacn = DialogueActConfusionNetwork()
        dacn.add_merge(0.5, dai, combine='add')
        self.assertEqual(dacn._get_prob([0]), 0.5)
        # Adding the identical item again is expected to give 0.5 + 0.5.
        dacn.add_merge(0.5, dai, combine='add')
        self.assertEqual(dacn._get_prob([0]), 1.0)

    def test_get_best_da(self):
        """get_best_da() returns only the dominant item."""
        dacn = self._make_food_confnet(0.2, 0.7, 0.1)
        da = dacn.get_best_da()
        self.assertEqual(len(da), 1)
        self.assertEqual(da.dais[0], DialogueActItem(dai='inform(food=czech)'))

        dacn = DialogueActConfusionNetwork()
        dacn.add(0.2, DialogueActItem(dai='inform(food=chinese)'))
        dacn.add(0.3, DialogueActItem(dai='inform(food=czech)'))
        dacn.add(0.7, DialogueActItem(dai='inform(food=null)'))
        dacn.add(0.1, DialogueActItem(dai='inform(food=russian)'))

        # The method under test must handle an item whose slot value is 'null'.
        da = dacn.get_best_da()
        self.assertEqual(len(da), 1)
        self.assertEqual(da.dais[0], DialogueActItem(dai='inform(food=null)'))

        # A slot value of 'null' is not the null() act, so the non-null
        # variant must return the same item.
        da_nn = dacn.get_best_nonnull_da()
        self.assertEqual(len(da_nn), 1)
        self.assertEqual(da_nn.dais[0], DialogueActItem(dai='inform(food=null)'))

    def test_get_best_nonnull_da(self):
        """get_best_nonnull_da() skips a dominant null() act."""
        dacn = self._make_food_confnet(0.2, 0.7, 0.1)
        da_nn = dacn.get_best_nonnull_da()
        self.assertEqual(len(da_nn), 1)
        self.assertEqual(da_nn.dais[0], DialogueActItem(dai='inform(food=czech)'))

        # null() has the highest probability, but czech is the best non-null item.
        dacn = DialogueActConfusionNetwork()
        dacn.add(0.075, DialogueActItem(dai='inform(food=chinese)'))
        dacn.add(0.7, DialogueActItem(dai='null()'))
        dacn.add(0.15, DialogueActItem(dai='inform(food=czech)'))
        dacn.add(0.075, DialogueActItem(dai='inform(food=russian)'))
        da_nn = dacn.get_best_nonnull_da()
        self.assertEqual(len(da_nn), 1)
        self.assertEqual(da_nn.dais[0], DialogueActItem(dai='inform(food=czech)'))

    def test_get_best_da_hyp(self):
        """get_best_da_hyp() picks the best subset of items and reports its probability."""
        # Only one dai (czech) should be included in the hyp.
        dacn = self._make_food_confnet(0.2, 0.7, 0.1)
        best_hyp = dacn.get_best_da_hyp(use_log=False)
        self.assertAlmostEqual(best_hyp.prob, 0.8 * 0.7 * 0.9)
        self.assertEqual(len(best_hyp.da), 1)

        # Two dais (czech and russian) should be included in the hyp.
        dacn = self._make_food_confnet(0.1, 0.7, 0.9)
        best_hyp = dacn.get_best_da_hyp(use_log=False)
        self.assertAlmostEqual(best_hyp.prob, 0.9 * 0.7 * 0.9)
        self.assertEqual(len(best_hyp.da), 2)

        # Same network, with the probability reported as a logarithm.
        dacn = self._make_food_confnet(0.1, 0.7, 0.9)
        best_hyp = dacn.get_best_da_hyp(use_log=True)
        self.assertAlmostEqual(best_hyp.prob, math.log(0.9 * 0.7 * 0.9))
        self.assertEqual(len(best_hyp.da), 2)

        # Manual per-item thresholds: czech's threshold (0.9) exceeds its
        # probability (0.7), so food=czech should NOT be included.
        dacn = self._make_food_confnet(0.1, 0.7, 0.9)
        best_hyp = dacn.get_best_da_hyp(
            use_log=True,
            threshold=0.1,
            thresholds={
                DialogueActItem(dai='inform(food=chinese)'): 0.5,
                DialogueActItem(dai='inform(food=czech)'): 0.9,
                DialogueActItem(dai='inform(food=russian)'): 0.5,
            })
        self.assertAlmostEqual(best_hyp.prob, math.log(0.9 * 0.3 * 0.9))
        self.assertEqual(len(best_hyp.da), 1)
        self.assertNotIn(DialogueActItem(dai='inform(food=czech)'), best_hyp.da)

        # With czech's threshold at 0.5, food=czech should be included.
        dacn = self._make_food_confnet(0.1, 0.7, 0.9)
        best_hyp = dacn.get_best_da_hyp(
            use_log=True,
            threshold=0.1,
            thresholds={
                DialogueActItem(dai='inform(food=chinese)'): 0.5,
                DialogueActItem(dai='inform(food=czech)'): 0.5,
                DialogueActItem(dai='inform(food=russian)'): 0.5,
            })
        self.assertAlmostEqual(best_hyp.prob, math.log(0.9 * 0.7 * 0.9))
        self.assertEqual(len(best_hyp.da), 2)
        self.assertIn(DialogueActItem(dai='inform(food=czech)'), best_hyp.da)

    def test_get_prob(self):
        """_get_prob() multiplies p (index 0) or 1 - p (index 1) per item."""
        dacn = self._make_food_confnet(0.2, 0.7, 0.1)
        self.assertAlmostEqual(dacn._get_prob([0, 1, 1]), 0.2 * 0.3 * 0.9)
        self.assertAlmostEqual(dacn._get_prob([0, 0, 0]), 0.2 * 0.7 * 0.1)

    def test_get_da_nblist(self):
        """The best entry of get_da_nblist() contains every likely item."""
        # Simple case with one good hypothesis.
        dacn = self._make_food_confnet(0.05, 0.9, 0.05)
        nblist = dacn.get_da_nblist()
        best_da = nblist.get_best_da()
        expected_da = DialogueAct(da_str='inform(food=czech)')
        self.assertEqual(best_da, expected_da)

        # More good hypotheses.
        dacn = self._make_food_confnet(0.05, 0.9, 0.9)
        nblist = dacn.get_da_nblist()
        best_da = nblist.get_best_da()
        expected_da = DialogueAct(da_str='inform(food=czech)&inform(food=russian)')
        self.assertEqual(best_da, expected_da)

    def test_prune(self):
        """prune() drops items with negligible probability."""
        dacn = self._make_food_confnet(0.05, 0.9, 0.00005)
        self.assertEqual(len(dacn), 3)
        dacn.prune()
        # Russian food (0.00005) should be pruned.
        self.assertEqual(len(dacn), 2)
        self.assertNotIn(DialogueActItem(dai='inform(food=russian)'), dacn)

    def test_normalise(self):
        """normalise() rejects a network holding a probability above 1."""
        dacn = self._make_food_confnet(0.05, 1.9, 0.00005)
        self.assertRaises(ConfusionNetworkException, dacn.normalise)

    def test_sort(self):
        """sort() orders the (prob, dai) entries by decreasing probability."""
        dacn = self._make_food_confnet(0.05, 1.0, 0.00005)
        dacn.sort()
        cn = list(dacn)
        self.assertEqual(cn[0][1], DialogueActItem(dai='inform(food=czech)'))
        self.assertEqual(cn[1][1], DialogueActItem(dai='inform(food=chinese)'))
        self.assertEqual(cn[2][1], DialogueActItem(dai='inform(food=russian)'))

    def test_make_from_da(self):
        """A confnet built from a DA yields that DA back as its best DA."""
        da = DialogueAct('inform(food=czech)&inform(area=north)')
        dacn = DialogueActConfusionNetwork.make_from_da(da)
        self.assertEqual(dacn.get_best_da(), da)

    def test_merge(self):
        """A confnet merged with itself under 'max' can still be sorted and pruned."""
        dacn = self._make_food_confnet(0.05, 0.9, 0.00005)
        dacn.merge(dacn, combine='max')
        dacn.sort().prune()
        # Russian food should be pruned.
        self.assertNotIn(DialogueActItem(dai='inform(food=russian)'), dacn)
if __name__ == '__main__':
    # Executed directly (not imported): hand control to unittest's CLI runner.
    unittest.main()