Package nltk :: Package tag :: Module brill
[hide private]
[frames] | no frames]

Source Code for Module nltk.tag.brill

   1  # Natural Language Toolkit: Brill Tagger 
   2  # 
   3  # Copyright (C) 2001-2010 NLTK Project 
   4  # Authors: Christopher Maloof <cjmaloof@gradient.cis.upenn.edu> 
   5  #          Edward Loper <edloper@gradient.cis.upenn.edu> 
   6  #          Steven Bird <sb@csse.unimelb.edu.au> 
   7  # URL: <http://www.nltk.org/> 
   8  # For license information, see LICENSE.TXT 
   9   
  10  """ 
  11  Brill's transformational rule-based tagger. 
  12  """ 
  13   
  14  import bisect        # for binary search through a subset of indices 
  15  import random        # for shuffling WSJ files 
  16  import yaml          # to save and load taggers in files 
  17  import textwrap 
  18   
  19  from nltk.compat import defaultdict 
  20   
  21  from util import untag 
  22  from api import * 
######################################################################
## The Brill Tagger
######################################################################

class BrillTagger(TaggerI, yaml.YAMLObject):
    """
    Brill's transformational rule-based tagger.  Brill taggers use an
    X{initial tagger} (such as L{tag.DefaultTagger}) to assign an initial
    tag sequence to a text; and then apply an ordered list of
    transformational rules to correct the tags of individual tokens.
    These transformation rules are specified by the L{BrillRule}
    interface.

    Brill taggers can be created directly, from an initial tagger and
    a list of transformational rules; but more often, Brill taggers
    are created by learning rules from a training corpus, using either
    L{BrillTaggerTrainer} or L{FastBrillTaggerTrainer}.
    """

    yaml_tag = '!nltk.BrillTagger'

    def __init__(self, initial_tagger, rules):
        """
        @param initial_tagger: The initial tagger
        @type initial_tagger: L{TaggerI}
        @param rules: An ordered list of transformation rules that
            should be used to correct the initial tagging.
        @type rules: C{list} of L{BrillRule}
        """
        self._initial_tagger = initial_tagger
        # Stored as a tuple so the rule sequence cannot be changed
        # after construction.
        self._rules = tuple(rules)

    def rules(self):
        """
        @return: The ordered transformation rules this tagger applies
            after running its initial tagger.
        @rtype: C{tuple} of L{BrillRule}
        """
        return self._rules

    def tag(self, tokens):
        # Inherit documentation from TaggerI

        # Run the initial tagger.  The rules below assign into the
        # result by index, which fails if the initial tagger returns a
        # tuple or an iterator, so copy it into a fresh list first.
        tagged_tokens = list(self._initial_tagger.tag(tokens))

        # Create a dictionary that maps each tag to the set of indices
        # of tokens that have that tag.
        tag_to_positions = defaultdict(set)
        for i, (token, tag) in enumerate(tagged_tokens):
            tag_to_positions[tag].add(i)

        # Apply each rule, in order.  Only try to apply rules at
        # positions that have the desired original tag.
        for rule in self._rules:
            # Find the positions where it might apply.
            positions = tag_to_positions.get(rule.original_tag, [])
            # Apply the rule at those positions.  apply() finishes
            # iterating `positions` before we mutate that set below.
            changed = rule.apply(tagged_tokens, positions)
            # Keep tag_to_positions in sync with the retagged tokens.
            for i in changed:
                tag_to_positions[rule.original_tag].remove(i)
                tag_to_positions[rule.replacement_tag].add(i)

        return tagged_tokens
######################################################################
## Brill Rules
######################################################################

class BrillRule(yaml.YAMLObject):
    """
    An interface for tag transformations on a tagged corpus, as
    performed by brill taggers.  Each transformation finds all tokens
    in the corpus that are tagged with a specific X{original tag} and
    satisfy a specific X{condition}, and replaces their tags with a
    X{replacement tag}.  For any given transformation, the original
    tag, replacement tag, and condition are fixed.  Conditions may
    depend on the token under consideration, as well as any other
    tokens in the corpus.

    Brill rules must be comparable and hashable.
    """
    def __init__(self, original_tag, replacement_tag):
        # Raise explicitly instead of using ``assert``: assertions are
        # stripped under ``python -O``, which would let the abstract
        # base class be instantiated silently.
        if self.__class__ is BrillRule:
            raise AssertionError("BrillRule is an abstract base class")

        # The tag which this BrillRule may cause to be replaced.
        self.original_tag = original_tag

        # The tag with which this BrillRule may replace another tag.
        self.replacement_tag = replacement_tag

    def apply(self, tokens, positions=None):
        """
        Apply this rule at every position in C{positions} where it
        applies to the given sentence.  I.e., for each position M{p}
        in C{positions}, if C{tokens[M{p}]} is tagged with this rule's
        original tag, and satisfies this rule's condition, then set
        its tag to be this rule's replacement tag.

        @param tokens: The tagged sentence; it is modified in place.
        @type tokens: C{list} of C{(word, tag)} tuples
        @type positions: C{list} of C{int}
        @param positions: The positions where the transformation is to
            be tried.  If not specified, try it at all positions.
        @return: The indices of tokens whose tags were changed by this
            rule.
        @rtype: C{list} of C{int}
        """
        if positions is None:
            positions = range(len(tokens))

        # Determine the indices at which this rule applies.
        change = [i for i in positions if self.applies(tokens, i)]

        # Make the changes.  Note: this must be done in a separate
        # step from finding applicable locations, since we don't want
        # the rule to interact with itself.
        for i in change:
            tokens[i] = (tokens[i][0], self.replacement_tag)

        return change

    def applies(self, tokens, index):
        """
        @return: True if the rule would change the tag of
            C{tokens[index]}, False otherwise
        @rtype: Boolean

        @param tokens: A tagged sentence
        @type tokens: C{list} of C{(word, tag)} tuples
        @param index: The index to check
        @type index: C{int}
        """
        raise AssertionError("Brill rules must define applies()")

    # Rules must be comparable and hashable for the algorithm to work.
    # Concrete rule classes override these stubs.  The stubs accept
    # ``other`` so that comparing two rules reaches the intended error
    # rather than a TypeError about the number of arguments.
    def __eq__(self, other):
        raise AssertionError("Brill rules must be comparable")

    def __ne__(self, other):
        raise AssertionError("Brill rules must be comparable")

    def __hash__(self):
        raise AssertionError("Brill rules must be hashable")
class ProximateTokensRule(BrillRule):
    """
    An abstract base class for brill rules whose condition checks for
    the presence of tokens with given properties at given ranges of
    positions, relative to the token.

    Each subclass of proximate tokens brill rule defines a method
    M{extract_property}, which extracts a specific property from the
    token, such as its text or tag.  Each instance is
    parameterized by a set of tuples, specifying ranges of positions
    and property values to check for in those ranges:

      - (M{start}, M{end}, M{value})

    The brill rule is then applicable to the M{n}th token iff:

      - The M{n}th token is tagged with the rule's original tag; and
      - For each (M{start}, M{end}, M{value}) triple:
        - The property value of at least one token between
          M{n+start} and M{n+end} (inclusive) is M{value}.

    For example, a proximate token brill template with M{start=end=-1}
    generates rules that check just the property of the preceding
    token.  Note that multiple properties may be included in a single
    rule; the rule applies if they all hold.
    """

    def __init__(self, original_tag, replacement_tag, *conditions):
        """
        Construct a new brill rule that changes a token's tag from
        C{original_tag} to C{replacement_tag} if all of the properties
        specified in C{conditions} hold.

        @type conditions: C{tuple} of C{(int, int, *)}
        @param conditions: A list of 3-tuples C{(start, end, value)},
            each of which specifies that the property of at least one
            token between M{n}+C{start} and M{n}+C{end} (inclusive) is
            C{value}.
        @raise ValueError: If C{start}>C{end} for any condition.
        """
        # Explicit raise so the guard survives ``python -O``.
        if self.__class__ is ProximateTokensRule:
            raise AssertionError(
                "ProximateTokensRule is an abstract base class")
        BrillRule.__init__(self, original_tag, replacement_tag)
        # Kept as the varargs tuple: it takes part in __eq__ and __hash__.
        self._conditions = conditions
        for (s, e, v) in conditions:
            if s > e:
                raise ValueError('Condition %s has an invalid range' %
                                 ((s, e, v),))

    # Make Brill rules look nice in YAML.
    @classmethod
    def to_yaml(cls, dumper, data):
        node = dumper.represent_mapping(cls.yaml_tag, dict(
            description=str(data),
            conditions=list(list(x) for x in data._conditions),
            original=data.original_tag,
            replacement=data.replacement_tag))
        return node

    @classmethod
    def from_yaml(cls, loader, node):
        mapping = loader.construct_mapping(node, deep=True)
        # Conditions are dumped as lists; convert back to tuples so the
        # rebuilt rule compares and hashes like the original.
        return cls(mapping['original'], mapping['replacement'],
                   *(tuple(x) for x in mapping['conditions']))

    @staticmethod
    def extract_property(token):
        """
        Returns some property characterizing this token, such as its
        base lexical item or its tag.

        Each implementation of this method should correspond to an
        implementation of the method with the same name in a subclass
        of L{ProximateTokensTemplate}.

        @param token: The token
        @type token: Token
        @return: The property
        @rtype: any
        """
        raise AssertionError(
            "ProximateTokenRules must define extract_property()")

    def applies(self, tokens, index):
        # Inherit docs from BrillRule

        # Does the given token have this rule's "original tag"?
        if tokens[index][1] != self.original_tag:
            return False

        # Check to make sure that every condition holds.
        for (start, end, val) in self._conditions:
            # Find the (absolute) start and end indices, clipped to the
            # sentence.
            s = max(0, index + start)
            e = min(index + end + 1, len(tokens))

            # The condition holds if *any* token in range matches.
            if not any(self.extract_property(tokens[i]) == val
                       for i in range(s, e)):
                return False

        # Every condition checked out, so the rule is applicable.
        return True

    def __eq__(self, other):
        return (self is other or
                (other is not None and
                 other.__class__ == self.__class__ and
                 self.original_tag == other.original_tag and
                 self.replacement_tag == other.replacement_tag and
                 self._conditions == other._conditions))

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Cache our hash value (justified by profiling.)  Only a missing
        # cache attribute is expected here; a bare ``except`` would also
        # swallow KeyboardInterrupt and genuine errors.
        try:
            return self.__hash
        except AttributeError:
            self.__hash = hash((self.original_tag, self.replacement_tag,
                                self._conditions, self.__class__.__name__))
            return self.__hash

    def __repr__(self):
        # Cache our repr (justified by profiling -- this is used as
        # a sort key when deterministic=True.)
        try:
            return self.__repr
        except AttributeError:
            conditions = ' and '.join(['%s in %d...%d' % (v, s, e)
                                       for (s, e, v) in self._conditions])
            self.__repr = ('<%s: %s->%s if %s>' %
                           (self.__class__.__name__, self.original_tag,
                            self.replacement_tag, conditions))
            return self.__repr

    def __str__(self):
        replacement = '%s -> %s' % (self.original_tag,
                                    self.replacement_tag)
        if len(self._conditions) == 0:
            conditions = ''
        else:
            conditions = ' if ' + ', and '.join(
                [self._condition_to_str(c) for c in self._conditions])
        return replacement + conditions

    def _condition_to_str(self, condition):
        """
        Return a string representation of the given condition.
        This helper method is used by L{__str__}.  Relies on the
        subclass attribute C{PROPERTY_NAME}.
        """
        (start, end, value) = condition
        return ('the %s of %s is %r' %
                (self.PROPERTY_NAME, self._range_to_str(start, end), value))

    def _range_to_str(self, start, end):
        """
        Return a string representation for the given range.  This
        helper method is used by L{__str__}.
        """
        if start == end == 0:
            return 'this word'
        elif start == end == -1:
            return 'the preceding word'
        elif start == end == 1:
            return 'the following word'
        elif start == end and start < 0:
            return 'word i-%d' % -start
        elif start == end and start > 0:
            return 'word i+%d' % start
        else:
            if start >= 0: start = '+%d' % start
            if end >= 0: end = '+%d' % end
            return 'words i%s...i%s' % (start, end)
class ProximateTagsRule(ProximateTokensRule):
    """
    A proximate-tokens rule whose conditions test the I{tags} of the
    tokens surrounding the one being retagged.

    @see: superclass L{ProximateTokensRule} for details.
    @see: L{SymmetricProximateTokensTemplate}, which generates these rules.
    """
    # Tag under which instances are written to / read from YAML.
    yaml_tag = '!ProximateTagsRule'

    # Property name used when rendering the rule with str().
    PROPERTY_NAME = 'tag'

    @staticmethod
    def extract_property(token):
        """
        @param token: A C{(word, tag)} pair.
        @return: The tag of C{token}, i.e. its second element.
        """
        tag = token[1]
        return tag
class ProximateWordsRule(ProximateTokensRule):
    """
    A proximate-tokens rule whose conditions test the I{text} of the
    tokens surrounding the one being retagged.

    @see: L{ProximateTokensRule} for details.
    @see: L{SymmetricProximateTokensTemplate}, which generates these rules.
    """
    # Tag under which instances are written to / read from YAML.
    yaml_tag = '!ProximateWordsRule'

    # Property name used when rendering the rule with str().
    PROPERTY_NAME = 'text'

    @staticmethod
    def extract_property(token):
        """
        @param token: A C{(word, tag)} pair.
        @return: The text of C{token}, i.e. its first element.
        """
        word = token[0]
        return word
######################################################################
## Brill Templates
######################################################################

class BrillTemplateI(object):
    """
    An interface for generating lists of transformational rules that
    apply at given sentence positions.  C{BrillTemplateI} is used by
    C{Brill} training algorithms to generate candidate rules.
    """
    # The abstract stubs use call-style ``raise AssertionError(msg)``,
    # which parses under both Python 2 and Python 3 (the old
    # ``raise AssertionError, msg`` form is Python-2-only).

    def __init__(self):
        raise AssertionError("BrillTemplateI is an abstract interface")

    def applicable_rules(self, tokens, i, correctTag):
        """
        Return a list of the transformational rules that would correct
        the C{i}th token's tag in C{tokens}.  In particular, return a
        list of zero or more rules that would change C{tokens[i][1]}
        to C{correctTag}, if applied to C{tokens}.

        If the C{i}th token already has the correct tag (i.e., if
        C{tokens[i][1]} == C{correctTag}), then
        C{applicable_rules} should return the empty list.

        @param tokens: The tagged tokens being tagged.
        @type tokens: C{list} of C{tuple}
        @param i: The index of the token whose tag should be corrected.
        @type i: C{int}
        @param correctTag: The correct tag for the C{i}th token.
        @type correctTag: (any)
        @rtype: C{list} of L{BrillRule}
        """
        raise AssertionError("BrillTemplateI is an abstract interface")

    def get_neighborhood(self, token, index):
        """
        Returns the set of indices C{i} such that
        C{applicable_rules(token, i, ...)} depends on the value of
        the C{index}th token of C{token}.

        This method is used by the \"fast\" Brill tagger trainer.

        @param token: The tokens being tagged.
        @type token: C{list} of C{tuple}
        @param index: The index whose neighborhood should be returned.
        @type index: C{int}
        @rtype: C{set}
        """
        raise AssertionError("BrillTemplateI is an abstract interface")
class ProximateTokensTemplate(BrillTemplateI):
    """
    A brill template that generates a list of
    L{ProximateTokensRule}s that apply at a given sentence
    position.  In particular, each C{ProximateTokensTemplate} is
    parameterized by a proximate token brill rule class and a list of
    boundaries, and generates all rules that:

      - use the given brill rule class
      - use the given list of boundaries as the C{start} and C{end}
        points for their conditions
      - are applicable to the given token.
    """
    def __init__(self, rule_class, *boundaries):
        """
        Construct a template for generating proximate token brill
        rules.

        @type rule_class: C{class}
        @param rule_class: The proximate token brill rule class that
            should be used to generate new rules.  This class must be a
            subclass of L{ProximateTokensRule}.
        @type boundaries: C{tuple} of C{(int, int)}
        @param boundaries: A list of tuples C{(start, end)}, each of
            which specifies a range for which a condition should be
            created by each rule.
        @raise ValueError: If C{start}>C{end} for any boundary.
        """
        self._rule_class = rule_class
        self._boundaries = boundaries
        for (s, e) in boundaries:
            if s > e:
                raise ValueError('Boundary %s has an invalid range' %
                                 ((s, e),))

    def applicable_rules(self, tokens, index, correct_tag):
        # Inherit docs from BrillTemplateI
        if tokens[index][1] == correct_tag:
            return []

        # For each of this template's boundaries, find the conditions
        # that are applicable for the given token.
        applicable_conditions = [
            self._applicable_conditions(tokens, index, start, end)
            for (start, end) in self._boundaries]

        # Find all combinations of these applicable conditions.  E.g.,
        # if applicable_conditions=[[A,B], [C,D]], then this will
        # generate [[A,C], [A,D], [B,C], [B,D]].
        condition_combos = [[]]
        for conditions in applicable_conditions:
            condition_combos = [old_conditions + [new_condition]
                                for old_conditions in condition_combos
                                for new_condition in conditions]

        # Translate the condition sets into rules.
        return [self._rule_class(tokens[index][1], correct_tag, *conds)
                for conds in condition_combos]

    def _applicable_conditions(self, tokens, index, start, end):
        """
        @return: A list of all distinct conditions for proximate token
            rules that are applicable to C{tokens[index]}, given
            boundaries of C{(start, end)}, in order of first occurrence.
            I.e., return every tuple C{(start, end, M{value})} such that
            the property value of at least one token between
            M{index+start} and M{index+end} (inclusive) is M{value}.

        Repeated property values in the range (e.g. two tokens with the
        same tag) yield the condition only once.  Otherwise identical
        conditions multiply into identical candidate rules across every
        boundary combination in L{applicable_rules}.
        """
        conditions = []
        seen = set()
        s = max(0, index + start)
        e = min(index + end + 1, len(tokens))
        for i in range(s, e):
            value = self._rule_class.extract_property(tokens[i])
            if value not in seen:
                seen.add(value)
                conditions.append((start, end, value))
        return conditions

    def get_neighborhood(self, tokens, index):
        # inherit docs from BrillTemplateI

        # applicable_rules(tokens, index, ...) depends on index.
        neighborhood = set([index])

        # applicable_rules(tokens, i, ...) depends on index if
        # i+start <= index <= i+end, i.e. index-end <= i <= index-start
        # (clipped to the sentence).
        for (start, end) in self._boundaries:
            s = max(0, index - end)
            e = min(index - start + 1, len(tokens))
            neighborhood.update(range(s, e))

        return neighborhood
class SymmetricProximateTokensTemplate(BrillTemplateI):
    """
    Simulates two L{ProximateTokensTemplate}s which are symmetric
    across the location of the token.  For rules of the form \"If the
    M{n}th token is tagged C{A}, and any tag preceding B{or} following
    the M{n}th token by a distance between M{x} and M{y} is C{B}, and
    ... , then change the tag of the nth token from C{A} to C{C}.\"

    One C{ProximateTokensTemplate} is formed by passing in the
    same arguments given to this class's constructor: tuples
    representing intervals in which a tag may be found.  The other
    C{ProximateTokensTemplate} is built from the mirror image of each
    interval: C{(start, end)} becomes C{(-end, -start)}.  For example,
    a C{SymmetricProximateTokensTemplate} constructed with the pair
    (-2,-1) generates the same rules as a C{ProximateTokensTemplate}
    using (-2,-1) plus a second C{ProximateTokensTemplate} using (1,2).

    This is useful because we typically don't want templates to
    specify only \"following\" or only \"preceding\"; we'd like our
    rules to be able to look in either direction.
    """
    def __init__(self, rule_class, *boundaries):
        """
        Construct a template for generating proximate token brill
        rules.

        @type rule_class: C{class}
        @param rule_class: The proximate token brill rule class that
            should be used to generate new rules.  This class must be a
            subclass of L{ProximateTokensRule}.
        @type boundaries: C{tuple} of C{(int, int)}
        @param boundaries: A list of tuples C{(start, end)}, each of
            which specifies a range for which a condition should be
            created by each rule.
        @raise ValueError: If C{start}>C{end} for any boundary.
        """
        self._ptt1 = ProximateTokensTemplate(rule_class, *boundaries)
        # Mirror each interval across the token (named so as not to
        # shadow the builtin ``reversed``).
        mirrored = [(-e, -s) for (s, e) in boundaries]
        self._ptt2 = ProximateTokensTemplate(rule_class, *mirrored)

    # Generates lists of a subtype of ProximateTokensRule.
    def applicable_rules(self, tokens, index, correctTag):
        """
        See L{BrillTemplateI} for full specifications.

        @return: The rules from the original-direction template
            followed by those from the mirrored template.
        @rtype: list of ProximateTokensRule
        """
        return (self._ptt1.applicable_rules(tokens, index, correctTag) +
                self._ptt2.applicable_rules(tokens, index, correctTag))

    def get_neighborhood(self, tokens, index):
        # inherit docs from BrillTemplateI
        n1 = self._ptt1.get_neighborhood(tokens, index)
        n2 = self._ptt2.get_neighborhood(tokens, index)
        return n1.union(n2)
######################################################################
## Brill Tagger Trainer
######################################################################

class BrillTaggerTrainer(object):
    """
    A trainer for brill taggers.
    """
    def __init__(self, initial_tagger, templates, trace=0,
                 deterministic=None):
        """
        @param initial_tagger: The tagger that produces the initial
            tagging which the learned rules correct.
        @param templates: The templates used to generate candidate rules.
        @type templates: C{list} of L{BrillTemplateI}
        @param trace: Verbosity.  0 is silent; 1 reports progress; 2
            also prints each learned rule; 3 or more prints a score
            table for each learned rule.
        @type trace: C{int}
        @param deterministic: If true, then choose between rules that
            have the same score by picking the one whose __repr__
            is lexicographically smaller.  If false, then just pick the
            first rule we find with a given score -- this will depend
            on the order in which keys are returned from dictionaries,
            and so may not be the same from one run to the next.  If
            not specified, treat as true iff trace > 0.
        """
        if deterministic is None:
            deterministic = (trace > 0)
        self._initial_tagger = initial_tagger
        self._templates = templates
        self._trace = trace
        self._deterministic = deterministic

    #////////////////////////////////////////////////////////////
    # Training
    #////////////////////////////////////////////////////////////

    def train(self, train_sents, max_rules=200, min_score=2):
        """
        Trains the Brill tagger on the corpus C{train_sents},
        producing at most C{max_rules} transformations, each of which
        reduces the net number of errors in the corpus by at least
        C{min_score}.

        @type train_sents: C{list} of C{list} of L{tuple}
        @param train_sents: The corpus of tagged tokens
        @type max_rules: C{int}
        @param max_rules: The maximum number of transformations to be created
        @type min_score: C{int}
        @param min_score: The minimum acceptable net error reduction
            that each transformation must produce in the corpus.
        @return: A tagger applying the learned rules after the initial
            tagger.
        @rtype: L{BrillTagger}
        """
        # Single-argument print calls behave the same under Python 2's
        # print statement and Python 3's print function.
        if self._trace > 0:
            print("Training Brill tagger on %d sentences..." %
                  len(train_sents))

        # Create a new copy of the training corpus, and run the
        # initial tagger on it.  We will progressively update this
        # test corpus to look more like the training corpus.  Each
        # sentence is copied into a list because rule.apply() assigns
        # into it by index.
        test_sents = [list(self._initial_tagger.tag(untag(sent)))
                      for sent in train_sents]

        if self._trace > 2:
            self._trace_header()

        # Look for useful rules.
        rules = []
        try:
            while len(rules) < max_rules:
                (rule, score, fixscore) = self._best_rule(test_sents,
                                                          train_sents)
                if rule is None or score < min_score:
                    if self._trace > 1:
                        print('Insufficient improvement; stopping')
                    break

                # Add the rule to our list of rules.
                rules.append(rule)
                # Use the rule to update the test corpus.  Keep
                # track of how many times the rule applied (k).
                k = 0
                for sent in test_sents:
                    k += len(rule.apply(sent))
                # Display trace output.
                if self._trace > 1:
                    self._trace_rule(rule, score, fixscore, k)
        # The user can also cancel training manually:
        except KeyboardInterrupt:
            print("Training stopped manually -- %d rules found" % len(rules))

        # Create and return a tagger from the rules we found.
        return BrillTagger(self._initial_tagger, rules)

    #////////////////////////////////////////////////////////////
    # Finding the best rule
    #////////////////////////////////////////////////////////////

    def _best_rule(self, test_sents, train_sents):
        """
        Find the rule that makes the biggest net improvement in the
        corpus.

        @return: A C{(rule, score, fixscore)} triple: C{score} is the
            net number of errors the rule removes, and C{fixscore} the
            number of tags it corrects (ignoring the ones it breaks).
            C{rule} is C{None} (with both scores 0) when no candidate
            was selected.
        """
        # Create a dictionary mapping from each tag to a list of the
        # indices that have that tag in both test_sents and
        # train_sents (i.e., where it is correctly tagged).
        correct_indices = defaultdict(list)
        for sentnum, sent in enumerate(test_sents):
            for wordnum, tagged_word in enumerate(sent):
                if tagged_word[1] == train_sents[sentnum][wordnum][1]:
                    tag = tagged_word[1]
                    correct_indices[tag].append((sentnum, wordnum))

        # Find all the rules that correct at least one token's tag,
        # and the number of tags that each rule corrects (in
        # descending order of number of tags corrected).
        rules = self._find_rules(test_sents, train_sents)

        # Keep track of the current best rule, and its score.
        best_rule, best_score, best_fixscore = None, 0, 0

        # Consider each rule, in descending order of fixscore (the
        # number of tags that the rule corrects, not including the
        # number that it breaks).
        for (rule, fixscore) in rules:
            # The actual score must be <= fixscore; so if best_score
            # is bigger than fixscore, then we already have the best
            # rule.
            if best_score > fixscore or (best_score == fixscore and
                                         not self._deterministic):
                return best_rule, best_score, best_fixscore

            # Calculate the actual score, by decrementing fixscore
            # once for each tag that the rule changes to an incorrect
            # value.
            score = fixscore
            if rule.original_tag in correct_indices:
                for (sentnum, wordnum) in correct_indices[rule.original_tag]:
                    if rule.applies(test_sents[sentnum], wordnum):
                        score -= 1
                        # If the score goes below best_score, then we know
                        # that this isn't the best rule; so move on:
                        if score < best_score or (score == best_score and
                                                  not self._deterministic):
                            break

            # If the actual score is better than the best score, then
            # update best_score and best_rule.
            if score > best_score or (score == best_score and
                                      self._deterministic and
                                      repr(rule) < repr(best_rule)):
                best_rule, best_score, best_fixscore = rule, score, fixscore

        # Return the best rule, and its scores.
        return best_rule, best_score, best_fixscore

    def _find_rules(self, test_sents, train_sents):
        """
        Find all rules that correct at least one token's tag in
        C{test_sents}.

        @return: A list of tuples C{(rule, fixscore)}, where C{rule}
            is a brill rule and C{fixscore} is the number of tokens
            whose tag the rule corrects, sorted by descending
            C{fixscore}.  Note that C{fixscore} does I{not} include the
            number of tokens whose tags are changed to incorrect values.
        """
        # Create a list of all indices that are incorrectly tagged.
        error_indices = []
        for sentnum, sent in enumerate(test_sents):
            for wordnum, tagged_word in enumerate(sent):
                if tagged_word[1] != train_sents[sentnum][wordnum][1]:
                    error_indices.append((sentnum, wordnum))

        # Create a dictionary mapping from rules to their positive-only
        # scores.
        rule_score_dict = defaultdict(int)
        for (sentnum, wordnum) in error_indices:
            test_sent = test_sents[sentnum]
            train_sent = train_sents[sentnum]
            for rule in self._find_rules_at(test_sent, train_sent, wordnum):
                rule_score_dict[rule] += 1

        # Convert the dictionary into a list of (rule, score) tuples,
        # sorted in descending order of score.  A plain key function is
        # used instead of the Python-2-only tuple-parameter lambda.
        return sorted(rule_score_dict.items(), key=lambda item: -item[1])

    def _find_rules_at(self, test_sent, train_sent, i):
        """
        @rtype: C{set}
        @return: the set of all rules (based on the templates) that
            correct token C{i}'s tag in C{test_sent}.
        """
        applicable_rules = set()
        if test_sent[i][1] != train_sent[i][1]:
            correct_tag = train_sent[i][1]
            for template in self._templates:
                new_rules = template.applicable_rules(test_sent, i,
                                                      correct_tag)
                applicable_rules.update(new_rules)

        return applicable_rules

    #////////////////////////////////////////////////////////////
    # Tracing
    #////////////////////////////////////////////////////////////

    def _trace_header(self):
        # Column letters sit over the right edge of each %4d field
        # printed by _trace_rule; the '+' in the rule line is column 18.
        print("""
           B      |
   S   F   r   O  |        Score = Fixed - Broken
   c   i   o   t  |  R     Fixed = num tags changed incorrect -> correct
   o   x   k   h  |  u     Broken = num tags changed correct -> incorrect
   r   e   e   e  |  l     Other = num tags changed incorrect -> incorrect
   e   d   n   r  |  e
------------------+-------------------------------------------------------
        """.rstrip())

    def _trace_rule(self, rule, score, fixscore, numchanges):
        if self._trace > 2:
            # Score, Fixed, Broken (= fixed - score) and Other
            # (= changes - fixed - broken), then the rule text.  The
            # '|' lands in column 18, under the header's '+', matching
            # the old Python 2 print-with-trailing-comma output.
            rule_text = textwrap.fill(str(rule), initial_indent=' '*20,
                                      width=79,
                                      subsequent_indent=' '*18+'| ').strip()
            print('%4d%4d%4d%4d  | %s' % (score, fixscore, fixscore-score,
                                          numchanges-fixscore*2+score,
                                          rule_text))
        else:
            print(rule)
784 ###################################################################### 785 ## Fast Brill Tagger Trainer 786 ###################################################################### 787 788 -class FastBrillTaggerTrainer(object):
789 """ 790 A faster trainer for brill taggers. 791 """
792 - def __init__(self, initial_tagger, templates, trace=0, 793 deterministic=False):
794 if not deterministic: 795 deterministic = (trace > 0) 796 self._initial_tagger = initial_tagger 797 self._templates = templates 798 self._trace = trace 799 self._deterministic = deterministic 800 801 self._tag_positions = None 802 """Mapping from tags to lists of positions that use that tag.""" 803 804 self._rules_by_position = None 805 """Mapping from positions to the set of rules that are known 806 to occur at that position. Position is (sentnum, wordnum). 807 Initially, this will only contain positions where each rule 808 applies in a helpful way; but when we examine a rule, we'll 809 extend this list to also include positions where each rule 810 applies in a harmful or neutral way.""" 811 812 self._positions_by_rule = None 813 """Mapping from rule to position to effect, specifying the 814 effect that each rule has on the overall score, at each 815 position. Position is (sentnum, wordnum); and effect is 816 -1, 0, or 1. As with _rules_by_position, this mapping starts 817 out only containing rules with positive effects; but when 818 we examine a rule, we'll extend this mapping to include 819 the positions where the rule is harmful or neutral.""" 820 821 self._rules_by_score = None 822 """Mapping from scores to the set of rules whose effect on the 823 overall score is upper bounded by that score. Invariant: 824 rulesByScore[s] will contain r iff the sum of 825 _positions_by_rule[r] is s.""" 826 827 self._rule_scores = None 828 """Mapping from rules to upper bounds on their effects on the 829 overall score. This is the inverse mapping to _rules_by_score. 830 Invariant: ruleScores[r] = sum(_positions_by_rule[r])""" 831 832 self._first_unknown_position = None 833 """Mapping from rules to the first position where we're unsure 834 if the rule applies. This records the next position we 835 need to check to see if the rule messed anything up."""
836 837 #//////////////////////////////////////////////////////////// 838 # Training 839 #//////////////////////////////////////////////////////////// 840
def train(self, train_sents, max_rules=200, min_score=2):
    """
    Learn up to C{max_rules} transformation rules from C{train_sents}
    and return a L{BrillTagger} that applies them after
    C{self._initial_tagger}.

    Basic idea: keep track of the rules that apply at each position,
    and of the positions to which each rule applies, so that after a
    rule is selected only the affected bookkeeping needs updating.

    Training can be interrupted with Ctrl-C (C{KeyboardInterrupt}); the
    rules found so far are then used.  The internal mappings are
    discarded however training ends.

    @param train_sents: The correctly tagged training corpus.
    @type train_sents: C{list} of C{list} of C{(word, tag)} tuples
    @param max_rules: The maximum number of rules to learn.
    @type max_rules: C{int}
    @param min_score: Stop once no remaining rule scores at least this.
    @type min_score: C{int}
    @rtype: L{BrillTagger}
    """
    if self._trace > 0:
        print("Training Brill tagger on %d sentences..." % len(train_sents))

    # Create a new copy of the training corpus, and run the initial
    # tagger on it.  This test corpus is progressively updated to look
    # more like the training corpus.
    test_sents = [self._initial_tagger.tag(untag(sent))
                  for sent in train_sents]

    # Initialize our mappings.  This finds the errors made by the
    # initial tagger and uses them to generate repair rules, which are
    # added to the rule mappings.
    if self._trace > 0:
        print("Finding initial useful rules...")
    self._init_mappings(test_sents, train_sents)
    if self._trace > 0:
        print("    Found %d useful rules." % len(self._rule_scores))

    # Let the user know what we're up to.
    if self._trace > 2:
        self._trace_header()
    elif self._trace == 1:
        print("Selecting rules...")

    # Repeatedly select the best rule, and add it to `rules`.
    rules = []
    try:
        try:
            while len(rules) < max_rules:
                # Find the best rule; None means nothing reaches min_score.
                rule = self._best_rule(train_sents, test_sents, min_score)
                if rule is None:
                    break  # No more good rules left!
                rules.append(rule)

                # Report the rule that we found.
                if self._trace > 1:
                    self._trace_rule(rule)

                # Apply the new rule at the relevant sites.
                self._apply_rule(rule, test_sents)

                # Update _tag_positions[rule.original_tag] and
                # _tag_positions[rule.replacement_tag] for the affected
                # positions (i.e., self._positions_by_rule[rule]).
                self._update_tag_positions(rule)

                # Update rules that were affected by the change.
                self._update_rules(rule, train_sents, test_sents)

        # The user can cancel training manually; keep what was found.
        except KeyboardInterrupt:
            print("Training stopped manually -- %d rules found" % len(rules))
    finally:
        # Discard our tag position mapping & rule mappings even when
        # training is aborted by an unexpected error.
        self._clean()

    # Create and return a tagger from the rules we found.
    return BrillTagger(self._initial_tagger, rules)
900
def _init_mappings(self, test_sents, train_sents):
    """
    Build the tag-position index and all rule bookkeeping tables from
    scratch.

    Every token of C{test_sents} is indexed under its current tag.  When
    that tag differs from the gold tag in C{train_sents}, the templates
    are asked (via C{_find_rules}) for rules producing the gold tag
    there, and each proposed rule is registered at that position.
    """
    # tag -> positions carrying it (filled in corpus order, so sorted)
    self._tag_positions = defaultdict(list)
    # position -> rules known to apply there
    self._rules_by_position = defaultdict(set)
    # rule -> {position: score contribution}
    self._positions_by_rule = defaultdict(dict)
    # score -> rules currently holding that score
    self._rules_by_score = defaultdict(set)
    # rule -> total score
    self._rule_scores = defaultdict(int)
    # rule -> position from which its corpus scan should resume
    self._first_unknown_position = defaultdict(int)

    for s_idx, tagged in enumerate(test_sents):
        for w_idx, (_, guessed) in enumerate(tagged):
            self._tag_positions[guessed].append((s_idx, w_idx))

            wanted = train_sents[s_idx][w_idx][1]
            if guessed == wanted:
                continue  # correctly tagged: nothing to repair here

            for candidate in self._find_rules(tagged, w_idx, wanted):
                self._update_rule_applies(candidate, s_idx, w_idx,
                                          train_sents)
928
def _clean(self):
    """
    Reset the tag-position index and every rule bookkeeping table built
    during training to C{None}, so the memory they hold can be reclaimed.
    """
    for attr in ('_tag_positions', '_rules_by_position',
                 '_positions_by_rule', '_rules_by_score',
                 '_rule_scores', '_first_unknown_position'):
        setattr(self, attr, None)
936
def _find_rules(self, sent, wordnum, new_tag):
    """
    Lazily generate every rule, over all templates, that applies at
    index C{wordnum} of C{sent} and would assign the tag C{new_tag}.

    Rules are yielded template by template, in the order of
    C{self._templates}.
    """
    templates = self._templates
    for tmpl in templates:
        candidates = tmpl.applicable_rules(sent, wordnum, new_tag)
        for candidate in candidates:
            yield candidate
945
def _update_rule_applies(self, rule, sentnum, wordnum, train_sents):
    """
    Record that C{rule} applies at position C{(sentnum, wordnum)}.

    The position contributes +1 to the rule's score if the rule's
    replacement tag is the gold tag, -1 if the rule's original tag is the
    gold tag, and 0 otherwise.  The rule's total score and its
    C{_rules_by_score} bucket are adjusted to match.  A position already
    recorded for C{rule} is left untouched.
    """
    where = (sentnum, wordnum)
    seen = self._positions_by_rule[rule]

    # Already recorded: this only happens if the position's tag
    # hasn't changed.
    if where in seen:
        return

    gold = train_sents[sentnum][wordnum][1]
    if rule.replacement_tag == gold:
        delta = 1    # was wrong, becomes right
    elif rule.original_tag == gold:
        delta = -1   # was right, becomes wrong
    else:
        delta = 0    # was wrong, remains wrong
    seen[where] = delta

    self._rules_by_position[where].add(rule)

    # Move the rule from its old score bucket to its new one.
    before = self._rule_scores[rule]
    after = before + delta
    self._rule_scores[rule] = after
    self._rules_by_score[before].discard(rule)
    self._rules_by_score[after].add(rule)
977
def _update_rule_not_applies(self, rule, sentnum, wordnum):
    """
    Record that C{rule} no longer applies at position C{(sentnum, wordnum)}.

    The score contribution stored for that position is subtracted from
    the rule's score, the rule moves to its new C{_rules_by_score}
    bucket, and the position/rule link is removed in both directions.
    Assumes the position is currently recorded for C{rule}.
    """
    where = (sentnum, wordnum)

    before = self._rule_scores[rule]
    after = before - self._positions_by_rule[rule][where]
    self._rule_scores[rule] = after

    # Move the rule between score buckets.
    self._rules_by_score[before].discard(rule)
    self._rules_by_score[after].add(rule)

    # Forget the link in both directions.
    del self._positions_by_rule[rule][where]
    self._rules_by_position[where].remove(rule)
996 997 # Optional addition: if the rule now applies nowhere, delete 998 # all its dictionary entries. 999
def _best_rule(self, train_sents, test_sents, min_score):
    """
    Return the next best rule, or C{None} if no rule scores at least
    C{min_score}.

    A rule's recorded score may not yet reflect every position carrying
    its original tag.  Rules at the current top score are therefore
    scanned, resuming from C{_first_unknown_position}.  Whenever a newly
    found application drops a rule below the top score, the rule is
    demoted (and its resume point saved) and the next candidate is
    tried.  A rule that survives a scan to the end of its positions
    while still holding the top score is the next best rule.
    """
    scores = self._rules_by_score
    if scores == {}:
        return None
    top = max(scores)

    while top >= min_score:
        ranked = list(scores[top])
        if self._deterministic:
            ranked.sort(key=repr)

        for rule in ranked:
            # Sorted positions currently carrying the tag this rule
            # rewrites; resume where the previous scan stopped.
            candidates = self._tag_positions[rule.original_tag]
            resume_at = self._first_unknown_position.get(rule, (0, -1))
            idx = bisect.bisect_left(candidates, resume_at)
            end = len(candidates)

            while idx < end:
                sentnum, wordnum = candidates[idx]
                idx += 1
                if not rule.applies(test_sents[sentnum], wordnum):
                    continue
                self._update_rule_applies(rule, sentnum, wordnum,
                                          train_sents)
                if self._rule_scores[rule] < top:
                    # Demoted: remember where the next scan should resume.
                    self._first_unknown_position[rule] = (sentnum,
                                                          wordnum + 1)
                    break

            if self._rule_scores[rule] == top:
                # Fully scanned and still at the top: mark it as checked
                # past the end of the corpus and return it.
                self._first_unknown_position[rule] = (len(train_sents) + 1, 0)
                return rule

        # Every rule with score == top has been demoted.
        assert not scores[top]
        del scores[top]
        if len(scores) == 0:
            return None
        top = max(scores)

    # We reached the min-score threshold.
    return None
1046
def _apply_rule(self, rule, test_sents):
    """
    Update C{test_sents} in place by applying C{rule} everywhere its
    conditions are met, i.e. at every position recorded in
    C{self._positions_by_rule[rule]}.  Each such token keeps its word
    and is retagged with C{rule.replacement_tag}.
    """
    # Copy the positions into a set before touching test_sents.
    update_positions = set(self._positions_by_rule[rule])
    new_tag = rule.replacement_tag

    if self._trace > 3:
        self._trace_apply(len(update_positions))

    # Update test_sents.
    for sentnum, wordnum in update_positions:
        word = test_sents[sentnum][wordnum][0]
        test_sents[sentnum][wordnum] = (word, new_tag)
1062
def _update_tag_positions(self, rule):
    """
    Update C{self._tag_positions} to reflect the tag changes made by
    C{rule}.  Each position in C{self._positions_by_rule[rule]} is moved
    from the sorted position list of C{rule.original_tag} into the sorted
    position list of C{rule.replacement_tag}.

    @raise ValueError: if a position recorded for C{rule} is not in the
        position list of C{rule.original_tag} (the index is inconsistent).
    """
    # Both lists are the same for every position; look them up once.
    old_tag_positions = self._tag_positions[rule.original_tag]
    new_tag_positions = self._tag_positions[rule.replacement_tag]

    for pos in self._positions_by_rule[rule]:
        # Delete the old tag.  bisect_left returns an insertion point even
        # when pos is absent, so verify the hit instead of silently
        # deleting some other position.
        old_index = bisect.bisect_left(old_tag_positions, pos)
        if (old_index == len(old_tag_positions)
                or old_tag_positions[old_index] != pos):
            raise ValueError('position %r is not indexed under tag %r'
                             % (pos, rule.original_tag))
        del old_tag_positions[old_index]
        # Insert the new tag, keeping the list sorted.
        bisect.insort_left(new_tag_positions, pos)
1077
def _update_rules(self, rule, train_sents, test_sents):
    """
    Check if we should add or remove any rules from consideration,
    given the changes made by C{rule}.

    Every position in the neighborhood (per each template's
    C{get_neighborhood}) of a position changed by C{rule} is re-examined:
      - rules recorded there that no longer apply are withdrawn;
      - rules the templates now propose there are recorded;
      - rules whose C{_first_unknown_position} lies past this position,
        and which now apply here, are recorded as well.

    When C{self._trace > 3}, the numbers of removed, added and novel
    (never seen before) rule applications are reported.
    """
    # Collect a list of all positions that might be affected.
    neighbors = set()
    for sentnum, wordnum in self._positions_by_rule[rule]:
        for template in self._templates:
            n = template.get_neighborhood(test_sents[sentnum], wordnum)
            neighbors.update([(sentnum, i) for i in n])

    # Update the rules at each position.
    num_obsolete = num_new = num_unseen = 0
    for sentnum, wordnum in neighbors:
        test_sent = test_sents[sentnum]
        correct_tag = train_sents[sentnum][wordnum][1]

        # Check if the change causes any rule at this position to stop
        # matching; if so, update our rule mappings accordingly.
        old_rules = set(self._rules_by_position[sentnum, wordnum])
        for old_rule in old_rules:
            if not old_rule.applies(test_sent, wordnum):
                num_obsolete += 1
                self._update_rule_not_applies(old_rule, sentnum, wordnum)

        # Check if the change causes our templates to propose any new
        # rules for this position.
        for template in self._templates:
            for new_rule in template.applicable_rules(test_sent, wordnum,
                                                      correct_tag):
                if new_rule not in old_rules:
                    num_new += 1
                    if new_rule not in self._rule_scores:
                        num_unseen += 1
                    old_rules.add(new_rule)
                    self._update_rule_applies(new_rule, sentnum,
                                              wordnum, train_sents)

        # We may have caused other rules to match here that are not
        # proposed by our templates -- in particular, rules that are
        # harmful or neutral.  We therefore need to update any rule whose
        # first_unknown_position is past this position.  Only rules that
        # actually apply are counted as new applications.
        for new_rule, pos in self._first_unknown_position.items():
            if pos > (sentnum, wordnum):
                if (new_rule not in old_rules
                        and new_rule.applies(test_sent, wordnum)):
                    num_new += 1
                    self._update_rule_applies(new_rule, sentnum,
                                              wordnum, train_sents)

    if self._trace > 3:
        self._trace_update_rules(num_obsolete, num_new, num_unseen)
1134 1135 #//////////////////////////////////////////////////////////// 1136 # Tracing 1137 #//////////////////////////////////////////////////////////// 1138
def _trace_header(self):
    """
    Print the column legend shown above the per-rule trace lines: score,
    fixed, broken and other counts on the left, the rule on the right.
    """
    legend = """
           B      |
   S   F   r   O  |        Score = Fixed - Broken
   c   i   o   t  |  R     Fixed = num tags changed incorrect -> correct
   o   x   k   h  |  u     Broken = num tags changed correct -> incorrect
   r   e   e   e  |  l     Other = num tags changed incorrect -> incorrect
   e   d   n   r  |  e
------------------+-------------------------------------------------------
        """
    print(legend.rstrip())
1149
def _trace_rule(self, rule):
    """
    Print a trace line for C{rule}, which has just been selected.

    When C{self._trace > 2} the line shows the rule's score and its counts
    of fixed (+1), broken (-1) and other (0) position changes in 4-wide
    columns, followed by the rule text wrapped by C{textwrap.fill};
    otherwise only the rule is printed.
    """
    # Per-position score contributions: +1 fixed, -1 broken, 0 other.
    changes = list(self._positions_by_rule[rule].values())
    # Sanity check: the cached score must agree with its positions.
    assert self._rule_scores[rule] == sum(changes)

    if self._trace > 2:
        score = self._rule_scores[rule]
        counts = '%4d%4d%4d%4d |' % (score, changes.count(1),
                                     changes.count(-1), changes.count(0))
        text = textwrap.fill(str(rule), initial_indent=' '*20,
                             subsequent_indent=' '*18+'| ').strip()
        # One print, joined by the single space Python 2's "print x,"
        # form used to insert between the two pieces.
        print(counts + ' ' + text)
    else:
        print(rule)
1167
def _trace_apply(self, num_updates):
    """
    Print a trace note saying to how many positions the selected rule is
    being applied, preceded by an otherwise empty separator row.
    """
    bar = ' ' * 18 + '|'
    print(bar)
    print('%s Applying rule to %d positions.' % (bar, num_updates))
1172
def _trace_update_rules(self, num_obsolete, num_new, num_unseen):
    """
    Print a trace summary of how the rule tables changed: applications
    removed, applications added, and how many of the added rules were
    novel.  Ends with an otherwise empty separator row.
    """
    bar = ' ' * 18 + '|'
    messages = [
        'Updated rule tables:',
        '    - %d rule applications removed' % num_obsolete,
        '    - %d rule applications added (%d novel)' % (num_new, num_unseen),
    ]
    for msg in messages:
        print(bar + ' ' + msg)
    print(bar)
1180
######################################################################
## Testing
######################################################################

def error_list(train_sents, test_sents, radius=2):
    """
    Return a list of human-readable strings indicating the errors in the
    given tagging of the corpus.  The first element is a two-line column
    header; each following element describes one mistagged token.

    Each error string shows the left context, the mistagged word as
    C{word/test_tag->gold_tag}, and the right context.  Context tokens
    are taken from the gold tagging, and each context column is clipped
    to 25 characters so the columns stay aligned.

    @param train_sents: The correct tagging of the corpus
    @type train_sents: C{list} of C{list} of C{(word, tag)} pairs
    @param test_sents: The tagged corpus, aligned sentence-by-sentence
        and token-by-token with C{train_sents}
    @type test_sents: C{list} of C{list} of C{(word, tag)} pairs
    @param radius: How many tokens on either side of a wrongly-tagged token
        to include in the error string.  For example, if C{radius}=2,
        each error string will show the incorrect token plus (at most)
        two tokens on either side.  Negative values are treated as 0.
    @type radius: int
    @rtype: C{list} of C{str}
    """
    hdr = (('%25s | %s | %s\n' + '-'*26+'+'+'-'*24+'+'+'-'*26) %
           ('left context', 'word/test->gold'.center(22), 'right context'))
    errors = [hdr]
    radius = max(0, radius)
    for train_sent, test_sent in zip(train_sents, test_sents):
        for wordnum, (word, train_pos) in enumerate(train_sent):
            test_pos = test_sent[wordnum][1]
            if train_pos == test_pos:
                continue
            # Only `radius` tokens of context on each side.
            left_tokens = train_sent[max(0, wordnum - radius):wordnum]
            right_tokens = train_sent[wordnum + 1:wordnum + 1 + radius]
            left = ' '.join('%s/%s' % tuple(w) for w in left_tokens)
            right = ' '.join('%s/%s' % tuple(w) for w in right_tokens)
            mid = '%s/%s->%s' % (word, test_pos, train_pos)
            errors.append('%25s | %s | %s' %
                          (left[-25:], mid.center(22), right[:25]))

    return errors
1217
1218 ###################################################################### 1219 # Demonstration 1220 ###################################################################### 1221 1222 -def demo(num_sents=2000, max_rules=200, min_score=3, 1223 error_output="errors.out", rule_output="rules.yaml", 1224 randomize=False, train=.8, trace=3):
1225 """ 1226 Brill Tagger Demonstration 1227 1228 @param num_sents: how many sentences of training and testing data to use 1229 @type num_sents: L{int} 1230 @param max_rules: maximum number of rule instances to create 1231 @type max_rules: L{int} 1232 @param min_score: the minimum score for a rule in order for it to 1233 be considered 1234 @type min_score: L{int} 1235 @param error_output: the file where errors will be saved 1236 @type error_output: C{string} 1237 @param rule_output: the file where rules will be saved 1238 @type rule_output: C{string} 1239 @param randomize: whether the training data should be a random subset 1240 of the corpus 1241 @type randomize: boolean 1242 @param train: the fraction of the the corpus to be used for training 1243 (1=all) 1244 @type train: L{float} 1245 @param trace: the level of diagnostic tracing output to produce (0-4) 1246 @type trace: L{int} 1247 """ 1248 1249 from nltk.corpus import treebank 1250 from nltk import tag 1251 from nltk.tag import brill 1252 1253 nn_cd_tagger = tag.RegexpTagger([(r'^-?[0-9]+(.[0-9]+)?$', 'CD'), 1254 (r'.*', 'NN')]) 1255 1256 # train is the proportion of data used in training; the rest is reserved 1257 # for testing. 1258 print "Loading tagged data... " 1259 tagged_data = treebank.tagged_sents() 1260 if randomize: 1261 random.seed(len(sents)) 1262 random.shuffle(sents) 1263 cutoff = int(num_sents*train) 1264 training_data = tagged_data[:cutoff] 1265 gold_data = tagged_data[cutoff:num_sents] 1266 testing_data = [[t[0] for t in sent] for sent in gold_data] 1267 print "Done loading." 
1268 1269 # Unigram tagger 1270 print "Training unigram tagger:" 1271 unigram_tagger = tag.UnigramTagger(training_data, 1272 backoff=nn_cd_tagger) 1273 if gold_data: 1274 print " [accuracy: %f]" % unigram_tagger.evaluate(gold_data) 1275 1276 # Bigram tagger 1277 print "Training bigram tagger:" 1278 bigram_tagger = tag.BigramTagger(training_data, 1279 backoff=unigram_tagger) 1280 if gold_data: 1281 print " [accuracy: %f]" % bigram_tagger.evaluate(gold_data) 1282 1283 # Brill tagger 1284 templates = [ 1285 brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,1)), 1286 brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (2,2)), 1287 brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,2)), 1288 brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,3)), 1289 brill.SymmetricProximateTokensTemplate(brill.ProximateWordsRule, (1,1)), 1290 brill.SymmetricProximateTokensTemplate(brill.