Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/python
- #-------------------------------------------------------------------------------
- # dbvsrsw.py
- # http://pastebin.com/F4JZtDTB
- #-------------------------------------------------------------------------------
- '''DBvsRSW, a Python script for testing the fairness of a resampling wheel.
- Description:
- A resampling wheel is a much more efficient method for taking weighted
- samples from a population according to their weights than a dartboard.
- The resampling wheel is not completely unbiased, though it's likely
- to be adequate for genetic algorithms like particle filtering where
- the problem of over-filtering is a more pressing concern.
- For a list of weights in increasing or decreasing order, at least,
- the resampling wheel often performs very poorly if the starting index
- is not randomised and the angle step factor is small.
- Included in this testing suite is the brilliant O(N) resampling wheel
- presented by Erik Colban on the Udacity CS373 forum.
- Author:
- Daniel Neville (Blancmange), creamygoat@gmail.com
- (Resampling Wheel adapted from Prof. Sebastian Thrun's CS373 lecture notes)
- (SmartWheel adapted from Erik Colban's O(N) weighted resampler)
- Copyright:
- None
- Licence:
- Public domain
- INDEX
- Imports
- Output formatting functions:
- GFListStr(L)
- Dartboard functions:
- DartboardFromWeights(Weights)
- ThrowDart(Dartboard)
- Weighted sampler test functions:
- TestDartboard(Weights, NumRounds, [Variations])
- TestRSWheel(Weights, NumRounds, [Variations])
- TestSmartWheel(Weights, NumRounds, [Variations])
- DisplayWeightsHistogram(Weights, Histogram)
- Main:
- Main()
- Command line trigger
- '''
- #-------------------------------------------------------------------------------
- # Imports
- #-------------------------------------------------------------------------------
- import math
- from math import (log, sqrt, frexp)
- import random
- #-------------------------------------------------------------------------------
- # Output formatting functions
- #-------------------------------------------------------------------------------
def GFListStr(L):
    '''Format a list (or tuple) of numbers as a bracketed string.

    Every element is rendered in general precision ('%g') format and the
    elements are separated by a comma and a space.
    '''
    Items = ['%g' % Value for Value in L]
    return '[%s]' % ', '.join(Items)
- #-------------------------------------------------------------------------------
- # Dartboard functions
- #-------------------------------------------------------------------------------
def DartboardFromWeights(Weights):
    '''Build the n+1 fence positions of a dartboard from n weights.

    The first fence sits at 0.0 and each following fence lies one weight
    beyond the previous one, so zone i spans Fences[i] to Fences[i + 1].
    '''
    Fences = [0.0]
    for Weight in Weights:
        # Each fence is the running total of all weights seen so far.
        Fences.append(Fences[-1] + Weight)
    return Fences
- #-------------------------------------------------------------------------------
def ThrowDart(Dartboard):
    '''Return the index of the zone hit by a random dart.

    The board is defined by the lower (inclusive) bound of each zone
    followed by the (exclusive) upper bound of the last zone, which is
    the layout produced by DartboardFromWeights.

    Returned is the zone index, or -1 if the board has fewer than two
    fences and therefore no zones at all.
    '''
    HighIx = len(Dartboard) - 2
    if HighIx < 0:
        return -1
    InclusiveLowerLimit = Dartboard[0]
    ExclusiveUpperLimit = Dartboard[HighIx + 1]
    r = random.random()
    Dart = (1.0 - r) * InclusiveLowerLimit + r * ExclusiveUpperLimit
    # Binary search, choosing last among equals.
    LowIx = 0
    while LowIx < HighIx:
        # Floor division keeps the midpoint an integer on Python 3 as well
        # as Python 2. The midpoint is rounded up so that the "LowIx = MidIx"
        # branch always makes progress.
        MidIx = HighIx - ((HighIx - LowIx) // 2)
        if Dart < Dartboard[MidIx]:
            HighIx = MidIx - 1
        else:
            LowIx = MidIx
    # Whatever happens, LowIx = HighIx
    return LowIx
- #-------------------------------------------------------------------------------
- # Weighted sampler test functions
- #-------------------------------------------------------------------------------
def TestDartboard(Weights, NumRounds, Variations=None):
    '''Test the Dartboard method for selecting items according to their weights.

    This method is not especially efficient but is straightforward and robust.
    It serves as a standard by which the other resamplers may be compared.

    Arguments:
      Weights: Sequence of item weights.
      NumRounds: Number of rounds; len(Weights) darts are thrown per round.
      Variations: Accepted for a uniform test interface but ignored, since
        no variations are offered.

    Returned is a histogram with one entry for each item in Weights.
    '''
    Dartboard = DartboardFromWeights(Weights)
    Histogram = [0] * len(Weights)
    NumDartsToThrow = len(Weights) * NumRounds
    # A single parenthesised argument prints identically on Python 2 and 3.
    print("Throwing %d darts..." % (NumDartsToThrow))
    while NumDartsToThrow > 0:
        Ix = ThrowDart(Dartboard)
        Histogram[Ix] += 1
        NumDartsToThrow -= 1
    return Histogram
- #-------------------------------------------------------------------------------
def TestRSWheel(Weights, NumRounds, Variations=None):
    '''Test the Resampling Wheel method for weighted resampling.

    The resampling wheel, presented by Prof. Sebastian Thrun in the Udacity
    course CS373: Programming a Robotic Car, is quick and dirty and prone to
    bias but adequate for applications such as particle filtering.

    Arguments:
      Weights: Sequence of item weights.
      NumRounds: Number of wheels to run; each wheel takes len(Weights)
        samples.
      Variations: Optional dictionary of overrides.

    Defaults for Variations:
      MaxStepFactor: 2.0
        Each step is a uniform random fraction of MaxStepFactor times
        the largest weight.
      RandomStartIndex: True
        When true, every wheel starts at a random index; otherwise every
        wheel starts at index 0.

    Returned is a histogram with one entry for each item in Weights.

    Raises ValueError if samples are requested but no weight is positive,
    because the wheel could then never come to rest on an item.
    '''
    #---------------------------------------------------------------------------
    def Variation(Key, Default):
        '''Return the override supplied for Key, or Default if there is none.'''
        if Variations is not None and Key in Variations:
            return Variations[Key]
        return Default
    #---------------------------------------------------------------------------
    MaxStepFactor = Variation('MaxStepFactor', 2.0)
    DoRandomiseStartIx = Variation('RandomStartIndex', True)
    N = len(Weights)
    Histogram = [0] * N
    # Single-argument print calls behave identically on Python 2 and 3.
    print('Resampling %d items from %d successive wheels...' % (N, NumRounds))
    print('Max step factor: %s' % (MaxStepFactor,))
    print('Start index: %s' % ('Random' if DoRandomiseStartIx else '0'))
    if N == 0 or NumRounds <= 0:
        # Nothing to sample; mirrors TestDartboard, which also returns an
        # untouched histogram in this situation.
        return Histogram
    MaxWeight = max(Weights)
    if MaxWeight <= 0:
        # With no positive weight the seek loop below would never terminate.
        raise ValueError('at least one weight must be positive')
    # The largest possible step is the same for every wheel.
    MaxStep = MaxStepFactor * MaxWeight
    NumRoundsToGo = NumRounds
    while NumRoundsToGo > 0:
        Phase = 0.0
        WIx = random.randint(0, N - 1) if DoRandomiseStartIx else 0
        NumSamplesToGo = N
        while NumSamplesToGo > 0:
            Phase += MaxStep * random.random()
            # Advance around the wheel until the phase falls inside an item.
            while Weights[WIx] <= Phase:
                Phase -= Weights[WIx]
                WIx = (WIx + 1) % N
            Histogram[WIx] += 1
            NumSamplesToGo -= 1
        NumRoundsToGo -= 1
    return Histogram
- #-------------------------------------------------------------------------------
def TestSmartWheel(Weights, NumRounds, Variations=None):
    '''Test Erik Colban's weighted resampler.

    This resampler is like a resampling wheel except that it only goes around
    once and has a complexity of O(N).

    Arguments:
      Weights: Sequence of item weights.
      NumRounds: Number of wheels to run; each wheel takes len(Weights)
        samples.
      Variations: Optional dictionary of overrides.

    Defaults for Variations:
      ApproxLog2: False
        When true, an approximate base-2 logarithm replaces the natural
        logarithm when generating the step sizes.

    Returned is a histogram with one entry for each item in Weights.
    '''
    # From Erik's notes:
    #
    # The algorithm is based on the fact that, after sorting N uniformly
    # distributed samples, the distances between two consecutive samples
    # is exponentially distributed. The algorithm is similar to the resampling
    # wheel algorithm, except that it makes exactly one revolution around the
    # resampling wheel. This resampling algorithm is O(N)
    #
    # http://forums.udacity.com/cs373-april2012/questions/1328/an-on-unbiased-resampler
    #---------------------------------------------------------------------------
    def Variation(Key, Default):
        '''Return the override supplied for Key, or Default if there is none.'''
        if Variations is not None and Key in Variations:
            return Variations[Key]
        return Default
    #---------------------------------------------------------------------------
    UseApproxLog2 = Variation('ApproxLog2', False)
    N = len(Weights)
    Histogram = [0] * N
    # Single-argument print calls behave identically on Python 2 and 3.
    print('Resampling %d items from %d successive wheels...' % (N, NumRounds))
    print('Logarithm: %s' % ('Approx. base 2' if UseApproxLog2 else 'Natural'))
    if N == 0:
        # No items means no samples to take.
        return Histogram
    LastIx = N - 1
    TotalWeight = sum(Weights)
    # Constants for the approximate base-2 logarithm; they do not change
    # from round to round.
    a = 1.0 / sqrt(2.0)
    b = 1.0 / (1.0 + a)
    c = 1.0 / (1.0 / (0.5 + a) - b)
    NumRoundsToGo = NumRounds
    while NumRoundsToGo > 0:
        # Select N + 1 numbers exponentially distributed with parameter
        # lambda = 1.
        if UseApproxLog2:
            # Use an approximate base-2 logarithm to avoid repeatedly
            # calling a transcendental function. The trick relies on the
            # mantissa/exponent split of a float, as returned by frexp.
            # (Original author's figures: the average error is 0.001276
            # and the maximum error is 0.001915.)
            Diffs = []
            for _ in range(N + 1):
                x = 1.0 - random.random()
                m, e = frexp(x)
                al2 = e - c * (1.0 / (m + a) - b)
                Diffs.append(-al2)
        else:
            # Use a real logarithm function. The natural logarithm is fine.
            Diffs = [-log(1.0 - random.random()) for _ in range(N + 1)]
        # Stretch to fit the circumference of the resampling wheel.
        Scale = TotalWeight / sum(Diffs)
        Diffs = [Scale * x for x in Diffs]
        WIx = 0
        Phase = 0.0
        # Go around the resampling wheel exactly once. Only the first N of
        # the N + 1 gaps are walked.
        for i in range(N):
            Phase += Diffs[i]
            # The wheel index should never need to pass the last item; if
            # accumulated rounding error would push it further, the sample
            # is credited to the last item instead of being dropped.
            while WIx < LastIx and Phase > Weights[WIx]:
                Phase -= Weights[WIx]
                WIx += 1
            Histogram[WIx] += 1
        NumRoundsToGo -= 1
    return Histogram
- #-------------------------------------------------------------------------------
def DisplayWeightsHistogram(Weights, Histogram):
    '''Display a histogram using ASCII art, each bar labelled with weights.

    The width of the histogram is automatically scaled to the largest bar.

    Arguments:
      Weights: Sequence of weights parallel to Histogram, or None to omit
        the weight from each bar's label.
      Histogram: Sequence of hit counts, one per item.

    One line is printed per histogram entry, showing the item index, the
    weight (if given), the bar and the entry's percentage of all hits.
    Nothing is printed for an empty histogram.
    '''
    #---------------------------------------------------------------------------
    MAX_BAR_WIDTH = 40
    #---------------------------------------------------------------------------
    if len(Histogram) == 0:
        # max() below cannot take an empty sequence.
        return
    IxWidth = len(str(len(Histogram)))
    # Both divisors are clamped to at least 1 so an all-zero histogram
    # does not divide by zero.
    MaxH = max(1, max(Histogram))
    TotalH = sum(Histogram)
    SafeHPCMult = 100.0 / max(1, TotalH)
    for Ix, NumHits in enumerate(Histogram):
        BarLength = int(round(MAX_BAR_WIDTH * NumHits / float(MaxH)))
        ItemStr = "W[%*d]" % (IxWidth, Ix)
        if Weights is not None:
            ValueStr = " = %6.3f " % (Weights[Ix])
        else:
            # Without weights the label is just the item index.
            ValueStr = ""
        BarStr = "#" * BarLength
        PercentStr = "(%.2f%%)" % (SafeHPCMult * NumHits)
        # A single parenthesised argument prints identically on Python 2 and 3.
        print("%-*s%s: %s %s" % (
            IxWidth + 2, ItemStr, ValueStr, BarStr, PercentStr))
- #-------------------------------------------------------------------------------
- # Main
- #-------------------------------------------------------------------------------
def Main():
    '''Run the test suite of weighted resamplers.

    Each sampler variant in the table below is run against the same list of
    weights and its resulting histogram is displayed.
    '''
    Weights = [15 + x for x in range(1, 21)]
    NumRounds = 1000
    # Each record: (test function, test name, variant description or None,
    # Variations dictionary or None).
    WeightedSamplers = [
        (TestDartboard, 'Dartboard', None, None),
        (TestRSWheel, 'Resampling Wheel', '(Random start index as standard)',
            None),
        (TestRSWheel, 'Resampling Wheel', '(Starts at first index)',
            {'RandomStartIndex': False}),
        (TestRSWheel, 'Resampling Wheel', '(SFI, Step factor of 5)',
            {'RandomStartIndex': False, 'MaxStepFactor': 5.0}),
        (TestRSWheel, 'Resampling Wheel', '(RSI, step factor of 5)',
            {'RandomStartIndex': True, 'MaxStepFactor': 5.0}),
        (TestSmartWheel, 'Erik Colban\'s Smart Wheel', None, None),
        (TestSmartWheel, 'Erik Colban\'s Smart Wheel', '(Approx. log2)',
            {'ApproxLog2': True})
    ]
    # Every print call takes exactly one argument so that the output is the
    # same on Python 2 and 3; print('') emits a blank line on both.
    print('Test of Weighted Resampling Functions\n')
    print("Weights:")
    print(GFListStr(Weights))
    print('')
    for TIx, TestRec in enumerate(WeightedSamplers):
        TestFn, TestName, VariantName, Variations = TestRec
        print('Test %d/%d: %s' % (TIx + 1, len(WeightedSamplers), TestName))
        if VariantName is not None:
            print(VariantName)
        print('')
        Histogram = TestFn(Weights, NumRounds, Variations)
        DisplayWeightsHistogram(Weights, Histogram)
        print('')
- #-------------------------------------------------------------------------------
- # Command line trigger
- #-------------------------------------------------------------------------------
# Run the test suite only when executed as a script, not when imported.
if __name__ == "__main__":
    Main()
- #-------------------------------------------------------------------------------
- # End
- #-------------------------------------------------------------------------------
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement