Package nltk :: Package metrics :: Module confusionmatrix
[hide private]
[frames] | [no frames]

Source Code for Module nltk.metrics.confusionmatrix

  1  # Natural Language Toolkit: Confusion Matrices 
  2  # 
  3  # Copyright (C) 2001-2011 NLTK Project 
  4  # Author: Edward Loper <edloper@gradient.cis.upenn.edu> 
  5  #         Steven Bird <sb@csse.unimelb.edu.au> 
  6  # URL: <http://www.nltk.org/> 
  7  # For license information, see LICENSE.TXT 
  8   
class ConfusionMatrix(object):
    """
    The confusion matrix between a list of reference values and a
    corresponding list of test values.  Entry [M{r},M{t}] of this
    matrix is a count of the number of times that the reference value
    M{r} corresponds to the test value M{t}.  E.g.:

        >>> ref = 'DET NN VB DET JJ NN NN IN DET NN'.split()
        >>> test = 'DET VB VB DET NN NN NN IN DET NN'.split()
        >>> cm = ConfusionMatrix(ref, test)
        >>> cm['NN', 'NN']
        3

    Note that the diagonal entries (M{Ri}=M{Tj}) of this matrix
    correspond to correct values, and the off-diagonal entries
    correspond to incorrect values.
    """

    def __init__(self, reference, test, sort_by_count=False):
        """
        Construct a new confusion matrix from a list of reference
        values and a corresponding list of test values.

        @type reference: C{list}
        @param reference: An ordered list of reference values.
        @type test: C{list}
        @param test: A list of values to compare against the
            corresponding reference values.
        @type sort_by_count: C{bool}
        @param sort_by_count: If true, order the matrix's values by
            their total number of occurrences in C{reference} and
            C{test} combined (most frequent first).  Otherwise order
            them by their natural sort order.
        @raise ValueError: If C{reference} and C{test} do not have
            the same length.
        """
        if len(reference) != len(test):
            raise ValueError('Lists must have the same length.')

        # Get a list of all values.  A set union (rather than
        # ``reference + test``) also works when the two sequences are
        # of different types, e.g. a tuple and a list.
        all_values = set(reference) | set(test)
        if sort_by_count:
            # Count occurrences across both sequences with a plain dict;
            # this module has no frequency-distribution import.
            counts = {}
            for seq in (reference, test):
                for v in seq:
                    counts[v] = counts.get(v, 0) + 1
            values = sorted(all_values, key=lambda v: -counts[v])
        else:
            values = sorted(all_values)

        # Construct a value->index dictionary
        indices = dict((val, i) for (i, val) in enumerate(values))

        # Make a confusion matrix table (rows = reference, cols = test).
        confusion = [[0] * len(values) for _ in values]
        max_conf = 0  # Maximum confusion
        for ref_val, test_val in zip(reference, test):
            i = indices[ref_val]
            j = indices[test_val]
            confusion[i][j] += 1
            max_conf = max(max_conf, confusion[i][j])

        #: A list of all values in C{reference} or C{test}.
        self._values = values
        #: A dictionary mapping values in L{self._values} to their indices.
        self._indices = indices
        #: The confusion matrix itself (as a list of lists of counts).
        self._confusion = confusion
        #: The greatest count in L{self._confusion} (used for printing).
        self._max_conf = max_conf
        #: The total number of values in the confusion matrix.
        self._total = len(reference)
        #: The number of correct (on-diagonal) values in the matrix.
        self._correct = sum(row[i] for i, row in enumerate(confusion))

    def __getitem__(self, key):
        """
        @param key: A C{(li, lj)} pair of a reference value and a test
            value, as in C{cm[li, lj]}.
        @return: The number of times that value C{li} was expected and
            value C{lj} was given.
        @rtype: C{int}
        @raise KeyError: If either value does not occur in the matrix.
        """
        li, lj = key
        return self._confusion[self._indices[li]][self._indices[lj]]

    def __repr__(self):
        return '<ConfusionMatrix: %s/%s correct>' % (self._correct,
                                                    self._total)

    def __str__(self):
        return self.pp()

    def pp(self, show_percents=False, values_in_chart=True,
           truncate=None, sort_by_count=False):
        """
        @return: A multi-line string representation of this confusion
            matrix.  Rows are reference values, columns are test values,
            and diagonal (correct) entries are wrapped in C{<...>}.
        @type show_percents: C{bool}
        @param show_percents: If true, show each entry as a percentage
            of the total number of values instead of a raw count.
        @type values_in_chart: C{bool}
        @param values_in_chart: If true, label rows and columns with the
            values themselves; otherwise label them C{1..n} and append a
            value key.
        @type truncate: int
        @param truncate: If specified, then only show the specified
            number of values.  Any sorting (e.g., sort_by_count)
            will be performed before truncation.
        @param sort_by_count: If true, then sort by the count of each
            label in the reference data.  I.e., labels that occur more
            frequently in the reference label will be towards the left
            edge of the matrix, and labels that occur less frequently
            will be towards the right edge.
        @todo: add marginals?
        """
        confusion = self._confusion

        values = self._values
        if sort_by_count:
            # A row sum is the number of times the value occurs in the
            # reference data.
            values = sorted(values, key=lambda v:
                            -sum(confusion[self._indices[v]]))

        if truncate:
            values = values[:truncate]

        if values_in_chart:
            value_strings = [str(val) for val in values]
        else:
            value_strings = [str(n + 1) for n in range(len(values))]

        # Construct a format string for row values.  The leading [0]
        # keeps max() defined for an empty matrix.
        valuelen = max([0] + [len(val) for val in value_strings])
        value_format = '%' + repr(valuelen) + 's | '
        # Construct a format string for matrix entries
        if show_percents:
            entrylen = 6
            entry_format = '%5.1f%%'
        else:
            entrylen = len(repr(self._max_conf))
            entry_format = '%' + repr(entrylen) + 'd'
        # Zero entries print as a right-aligned '.' exactly one entry wide,
        # so columns stay aligned in both modes.
        zerostr = ' ' * (entrylen - 1) + '.'

        lines = []

        # Write the column values, one character per header line,
        # bottom-aligned so that the labels end just above the divider.
        for i in range(valuelen):
            line = (' ' * valuelen) + ' |'
            for val in value_strings:
                if i >= valuelen - len(val):
                    line += val[i - valuelen + len(val)].rjust(entrylen + 1)
                else:
                    line += ' ' * (entrylen + 1)
            lines.append(line + ' |\n')

        divider = '%s-+-%s+\n' % ('-' * valuelen,
                                  '-' * ((entrylen + 1) * len(values)))

        # Write a dividing line
        lines.append(divider)

        # Write the entries.
        for val, li in zip(value_strings, values):
            i = self._indices[li]
            row = value_format % val
            for lj in values:
                j = self._indices[lj]
                if confusion[i][j] == 0:
                    row += zerostr
                elif show_percents:
                    row += entry_format % (100.0 * confusion[i][j] /
                                           self._total)
                else:
                    row += entry_format % confusion[i][j]
                if i == j:
                    # Mark a diagonal entry: replace the space just before
                    # it with '<' and use '>' in place of the trailing
                    # separator space, keeping the cell width unchanged.
                    prevspace = row.rfind(' ')
                    row = row[:prevspace] + '<' + row[prevspace + 1:] + '>'
                else:
                    row += ' '
            lines.append(row + '|\n')

        # Write a dividing line
        lines.append(divider)

        # Write a key
        lines.append('(row = reference; col = test)\n')
        if not values_in_chart:
            lines.append('Value key:\n')
            for i, value in enumerate(values):
                lines.append('%6d: %s\n' % (i + 1, value))

        return ''.join(lines)

    def key(self):
        """
        @return: A listing of every value in the matrix, in the matrix's
            own value order, each preceded by its 0-based index.
        @rtype: C{str}
        """
        values = self._values
        indexlen = len(repr(len(values) - 1))
        key_format = ' %' + repr(indexlen) + 'd: %s\n'
        parts = ['Value key:\n']
        for i, value in enumerate(values):
            parts.append(key_format % (i, value))
        return ''.join(parts)

def demo():
    """
    Print a small part-of-speech confusion matrix, first in natural
    value order and then with values sorted by reference count.
    """
    reference = 'DET NN VB DET JJ NN NN IN DET NN'.split()
    test = 'DET VB VB DET NN NN NN IN DET NN'.split()
    # Single-argument print(...) calls produce the same output under the
    # Python 2 print statement and the Python 3 print function.
    print('Reference = %s' % (reference,))
    print('Test = %s' % (test,))
    print('Confusion matrix:')
    cm = ConfusionMatrix(reference, test)
    print(cm)
    print(cm.pp(sort_by_count=True))


if __name__ == '__main__':
    demo()