Source code for alex.components.dm.dstc_tracker

#!/usr/bin/env python
# encoding: utf8

"""
.. module:: dstc_tracker
    :synopsis: Discriminative tracker that was used for DSTC2012.

.. moduleauthor:: Lukas Zilka <lukas@zilka.me>
"""
if __name__ == '__main__':
    import autopath

from collections import defaultdict

from alex.components.dm.pstate import PDDiscrete, PDDiscreteOther
from alex.components.dm.tracker import StateTracker
from alex.components.slu.da import DialogueActConfusionNetwork, DialogueActItem

NOTHING_DENIED = "@N@"
NO_VALUE = None

class DSTCState(object):
    """Tracker state: one probability distribution per slot."""

    def __init__(self, slots):
        """Create a state holding a fresh ``PDDiscrete`` for every slot.

        Arguments:
            - slots: list of slot names (strings)
        """
        self.slots = slots
        # Map slot name -> distribution over that slot's values.
        self.values = dict((name, PDDiscrete()) for name in slots)

    def __getitem__(self, item):
        """Return the distribution stored for slot ``item``."""
        return self.values[item]

    def __setitem__(self, item, value):
        """Store ``value`` as the distribution for slot ``item``."""
        self.values[item] = value

    def pprint(self):
        """Return a printable dump of the state, one line per slot.

        Slots are listed in the order of ``self.slots``; each line has the
        form `` |<slot>| <str(distribution)>``.  Despite the name, nothing
        is written to stdout -- the text is returned to the caller.
        """
        return "\n".join(
            (' |%s| ' % name) + str(self.values[name])
            for name in self.slots
        )

    def __str__(self):
        return self.pprint()
class ExtendedSlotUpdater(object):
    """Updater of state given observation and deny distributions."""

    @classmethod
    def update_slot(cls, curr_pd, observ_pd, deny_pd):
        """Compute the new distribution of one slot.

        Arguments:
            - curr_pd: distribution of the slot before this update
            - observ_pd: distribution of informed values in this turn; its
              ``NO_VALUE`` entry is the mass for "nothing was informed"
            - deny_pd: distribution of denied values in this turn; its
              ``NOTHING_DENIED`` entry is the mass for "nothing was denied".
              Must expose a ``space_size`` attribute.

        Returns:
            A new ``PDDiscrete`` with an entry for every item present in any
            of the three input distributions.  The result is not normalised
            here.

        For every item the value is computed as::

            p_t(item) = (p_{t-1}(item) * p_obs(NO_VALUE) + p_obs(item))
                        * (1 - p_deny(item))
                        + (1 - p_deny(item) * p_{t-1}(item)
                             - p_deny(NOTHING_DENIED)) / N

        where the ``p_obs(item)`` term is skipped for ``NO_VALUE``, the last
        (uniform) term is skipped for ``NOTHING_DENIED``, and
        ``N = max(number of represented items, deny_pd.space_size - 1)``.
        N stands for the total number of items in the distribution, which is
        usually larger than the number of items actually represented, as
        items that have not been seen are not represented explicitly.
        """
        new_pd = PDDiscrete()  # initialize result

        # Find out which items need to be computed.  Build new lists instead
        # of extending the list returned by ``get_items()`` in place, so that
        # the input distributions cannot be modified by accident.
        items = set(
            list(curr_pd.get_items())
            + list(observ_pd.get_items())
            + list(deny_pd.get_items())
        )

        for item in items:
            # Keep the old belief in proportion to "nothing was informed".
            new_pd[item] = curr_pd[item] * observ_pd[NO_VALUE]

            # Add the mass informed for this very value.
            if item is not NO_VALUE:
                new_pd[item] += observ_pd[item]

            # Scale down by the probability that the value was denied.
            new_pd[item] *= (1 - deny_pd[item])

            # Spread the denied mass uniformly over the value space.
            # Compare by value: NOTHING_DENIED is a plain string, so an
            # identity test would only work for the very same string object.
            if item != NOTHING_DENIED:
                new_pd[item] += (
                    (1 - deny_pd[item] * curr_pd[item]
                     - deny_pd[NOTHING_DENIED])
                    / (max(len(items), deny_pd.space_size - 1))
                )

        return new_pd
class DSTCTracker(StateTracker):
    """Represents simple deterministic DSTC state tracker."""

    state_class = DSTCState

    def __init__(self, slots, default_space_size=None):
        """Initialise the tracker for the given slots.

        Arguments:
            - slots: list of slot names (strings)
            - default_space_size: mapping from slot name to the size of the
              slot's value space, used when building the deny distributions.
              When omitted, every slot gets a size of 100.
        """
        super(DSTCTracker, self).__init__()

        # Create the default per instance: a default argument object is
        # shared by all calls, and defaultdict look-ups insert keys into it.
        if default_space_size is None:
            default_space_size = defaultdict(lambda: 100)

        self.slots = slots
        self.default_space_size = default_space_size
        self.values = {}
        for slot in slots:
            self.values[slot] = PDDiscrete()

    def update_state(self, state, cn):
        """Update ``state`` in place with the evidence from ``cn``.

        Arguments:
            - state: state object indexable by slot name and exposing
              ``slots``; each of its slots is replaced by the updated
              distribution
            - cn: iterable of ``(probability, dialogue act item)`` pairs;
              only items whose ``dat`` is ``"inform"`` or ``"deny"`` are
              used, all others are ignored

        Returns nothing; the result is written into ``state``.
        """
        # Initialize distributions used for computing distributions from the
        # confusion network.  Slots not listed in self.slots fall back to the
        # defaultdict factories (argument-less constructors).
        inform_slot_distr = defaultdict(PDDiscrete)
        deny_slot_distr = defaultdict(PDDiscreteOther)
        for slot in self.slots:
            inform_slot_distr[slot] = PDDiscrete()
            deny_slot_distr[slot] = PDDiscreteOther(
                space_size=self.default_space_size[slot])
            deny_slot_distr[slot][NO_VALUE] = 0.0
            deny_slot_distr[slot][NOTHING_DENIED] = 1.0

        # Go through the confusion network and record the probabilities in
        # the inform and deny distributions (they act as scores for the
        # particular items; normalization happens afterwards).  The per-slot
        # totals are accumulated separately.
        sum_inform = defaultdict(float)
        sum_deny = defaultdict(float)
        for p, dai in cn:
            if dai.dat == "inform":
                inform_slot_distr[dai.name][dai.value] = p
                sum_inform[dai.name] += p
            elif dai.dat == "deny":
                deny_slot_distr[dai.name][dai.value] = p
                sum_deny[dai.name] += p

        # Update each slot independently according to the distributions
        # obtained from the confusion network.
        for slot in state.slots:
            # Whatever mass was not informed / denied goes to the
            # "no value" / "nothing denied" entries (clamped at zero).
            inform_slot_distr[slot][NO_VALUE] = max(0.0, 1 - sum_inform[slot])
            inform_slot_distr[slot].normalize()

            deny_slot_distr[slot][NOTHING_DENIED] = max(0.0, 1 - sum_deny[slot])
            deny_slot_distr[slot].normalize()

            state[slot] = ExtendedSlotUpdater.update_slot(
                state[slot], inform_slot_distr[slot], deny_slot_distr[slot])
def main():
    """Run a small demo: build a tracker, inform it, then deny a value.

    The state is printed before the first update and after each update,
    separated by ``---`` lines.
    """
    # Initialize tracker and state.
    slots = ["food", "location"]
    tr = DSTCTracker(slots)
    state = DSTCState(slots)
    # pprint() only returns the text, so it has to be printed explicitly.
    print(state.pprint())

    # Try to update state with some information.
    print('---')
    cn = DialogueActConfusionNetwork()
    cn.add(0.3, DialogueActItem("inform", "food", "chinese"))
    cn.add(0.1, DialogueActItem("inform", "food", "indian"))
    tr.update_state(state, cn)
    print(state.pprint())

    # Try to deny some information.
    # NOTE(review): the same confusion network is reused, so the inform
    # items added above are still in it for this update -- confirm this
    # is intended.
    print('---')
    cn.add(0.9, DialogueActItem("deny", "food", "chinese"))
    cn.add(0.1, DialogueActItem("deny", "food", "indian"))
    tr.update_state(state, cn)
    print(state.pprint())


if __name__ == '__main__':
    main()