1
2
3
4
5
6
7
8
10 """
11 The confusion matrix between a list of reference values and a
12 corresponding list of test values. Entry [M{r},M{t}] of this
13 matrix is a count of the number of times that the reference value
14 M{r} corresponds to the test value M{t}. E.g.:
15
16 >>> ref = 'DET NN VB DET JJ NN NN IN DET NN'.split()
17 >>> test = 'DET VB VB DET NN NN NN IN DET NN'.split()
18 >>> cm = ConfusionMatrix(ref, test)
19 >>> print cm['NN', 'NN']
20 3
21
22 Note that the diagonal entries (M{r}=M{t}) of this matrix
23 correspond to correct values, and the off-diagonal entries
24 correspond to incorrect values.
25 """
26
def __init__(self, reference, test, sort_by_count=False):
    """
    Construct a new confusion matrix from a list of reference
    values and a corresponding list of test values.

    @type reference: C{list}
    @param reference: An ordered list of reference values.  Any
        sequence of hashable values is accepted.
    @type test: C{list}
    @param test: A list of values to compare against the
        corresponding reference values.  Any sequence of hashable
        values is accepted.
    @type sort_by_count: C{bool}
    @param sort_by_count: If true, order the matrix's values by their
        combined number of occurrences in C{reference} and C{test}
        (most frequent first), breaking ties by order of first
        appearance.  Otherwise, order the values with C{sorted()}.
    @raise ValueError: If C{reference} and C{test} do not have
        the same length.
    """
    if len(reference) != len(test):
        raise ValueError('Lists must have the same length.')

    # Map each distinct value to the position of its first appearance
    # (reference first, then test).  Unlike ``set(reference+test)``,
    # this works for any pair of sequences (e.g. a tuple and a list),
    # and it gives sort_by_count a deterministic tie-breaker.
    first_seen = {}
    for seq in (reference, test):
        for val in seq:
            first_seen.setdefault(val, len(first_seen))

    if sort_by_count:
        from collections import Counter
        # Combined frequency of each value across both sequences.
        counts = Counter(reference)
        counts.update(test)
        values = sorted(first_seen,
                        key=lambda v: (-counts[v], first_seen[v]))
    else:
        values = sorted(first_seen)

    # Row/column index of each value in the matrix.
    indices = dict((val, i) for (i, val) in enumerate(values))

    # confusion[i][j] counts the pairs whose reference value is
    # values[i] and whose test value is values[j].
    size = len(values)
    confusion = [[0] * size for _ in range(size)]
    max_conf = 0
    for ref_val, test_val in zip(reference, test):
        row = confusion[indices[ref_val]]
        j = indices[test_val]
        row[j] += 1
        if row[j] > max_conf:
            max_conf = row[j]

    self._values = values
    self._indices = indices
    self._confusion = confusion
    # Largest single cell; pp() uses it to size count columns.
    self._max_conf = max_conf
    self._total = len(reference)
    # Sum of the diagonal, i.e. pairs where reference == test.
    self._correct = sum(confusion[i][i] for i in range(size))
74
76 """
77 @return: The number of times that value C{li} was expected and
78 value C{lj} was given.
79 @rtype: C{int}
80 """
81 i = self._indices[li]
82 j = self._indices[lj]
83 return self._confusion[i][j]
84
86 return '<ConfusionMatrix: %s/%s correct>' % (self._correct,
87 self._total)
88
91
def pp(self, show_percents=False, values_in_chart=True,
       truncate=None, sort_by_count=False):
    """
    @return: A multi-line string representation of this confusion
        matrix.  Rows are reference values and columns are test
        values.  Column labels are written vertically, zero cells are
        shown as C{.}, and diagonal cells are wrapped in C{<...>}.
    @type show_percents: C{bool}
    @param show_percents: If true, show each cell as a percentage of
        the total number of (reference, test) pairs rather than as a
        raw count.
    @type values_in_chart: C{bool}
    @param values_in_chart: If true, label rows and columns with the
        values themselves.  Otherwise label them C{1..N} and append a
        key that maps each number to its value.
    @type truncate: int
    @param truncate: If specified, then only show the specified
        number of values.  Any sorting (e.g., sort_by_count)
        will be performed before truncation.
    @param sort_by_count: If true, then sort by the count of each
        label in the reference data.  I.e., labels that occur more
        frequently in the reference label will be towards the left
        edge of the matrix, and labels that occur less frequently
        will be towards the right edge.
    @todo: add marginals?
    """
    confusion = self._confusion
    indices = self._indices

    values = self._values
    if sort_by_count:
        # A row sum is the number of times the value occurs in the
        # reference data.
        values = sorted(values, key=lambda v: -sum(confusion[indices[v]]))

    if truncate:
        values = values[:truncate]

    if values_in_chart:
        value_strings = [str(val) for val in values]
    else:
        value_strings = [str(n + 1) for n in range(len(values))]

    # ``or [0]`` keeps an empty matrix from crashing max().
    valuelen = max([len(val) for val in value_strings] or [0])
    value_format = '%' + repr(valuelen) + 's | '

    if show_percents:
        entrylen = 6
        entry_format = '%5.1f%%'
    else:
        entrylen = len(repr(self._max_conf))
        entry_format = '%' + repr(entrylen) + 'd'
    # A zero cell must be exactly as wide as any other cell, or the
    # columns drift out of alignment.
    zerostr = ' ' * (entrylen - 1) + '.'

    lines = []

    # Column headers, one character per line, bottom-aligned.
    for i in range(valuelen):
        header = [(' ' * valuelen) + ' |']
        for val in value_strings:
            if i >= valuelen - len(val):
                header.append(val[i - valuelen + len(val)].rjust(entrylen + 1))
            else:
                header.append(' ' * (entrylen + 1))
        header.append(' |\n')
        lines.append(''.join(header))

    rule = '%s-+-%s+\n' % ('-' * valuelen,
                           '-' * ((entrylen + 1) * len(values)))
    lines.append(rule)

    # Body: one row per reference value.
    for val, li in zip(value_strings, values):
        i = indices[li]
        row = value_format % val
        for lj in values:
            j = indices[lj]
            count = confusion[i][j]
            if count == 0:
                row += zerostr
            elif show_percents:
                row += entry_format % (100.0 * count / self._total)
            else:
                row += entry_format % count
            if i == j:
                # Mark the diagonal: the nearest preceding space becomes
                # '<' and the following separator becomes '>'.  The row
                # always contains a space (from value_format), so this
                # row-local search matches a search over the whole output.
                prevspace = row.rfind(' ')
                row = row[:prevspace] + '<' + row[prevspace + 1:] + '>'
            else:
                row += ' '
        lines.append(row + '|\n')

    lines.append(rule)

    lines.append('(row = reference; col = test)\n')
    if not values_in_chart:
        lines.append('Value key:\n')
        for n, value in enumerate(values):
            lines.append('%6d: %s\n' % (n + 1, value))

    return ''.join(lines)
179
189
191 reference = 'DET NN VB DET JJ NN NN IN DET NN'.split()
192 test = 'DET VB VB DET NN NN NN IN DET NN'.split()
193 print 'Reference =', reference
194 print 'Test =', test
195 print 'Confusion matrix:'
196 print ConfusionMatrix(reference, test)
197 print ConfusionMatrix(reference, test).pp(sort_by_count=True)
198
if __name__ == "__main__":
    # Executed directly as a script: run the demonstration.
    demo()
201