## Sorting Networks

### Odd-Even Merge, Pairwise and Bitonic Sorting Networks

In [1]:
from math   import ceil, log2
from random import seed, shuffle

### Utility Functions for Sorting Networks

In [2]:
def netcheck (net):
    '''Test whether every layer of a network can be executed in parallel.
    net     sorting or selection network (list of layers of lane pairs)
    returns True if no wire/lane is used by more than one comparator
            within the same layer, False otherwise'''
    for layer in net:
        used = [lane for pair in layer for lane in pair]
        if len(used) != len(set(used)):
            return False        # some lane occurs twice in this layer
    return True

def netreduce (net, n):
    '''Reduce wires/lanes and layers of a sorting or selection network;
    also try to combine consecutive layers of the given network.
    Comparators that touch a wire/lane >= n are dropped; a (reduced) layer
    is merged into the preceding output layer if the two share no lane.
    net     sorting or selection network to reduce
    n       number of wires/lanes to reduce the network to
    returns the reduced sorting or selection network
            (an empty list if no comparator is left)'''
    out = [[]]; lanes = set()   # output layers, lanes used by the last one
    for lyr in net:
        lyr = [(a,b) for a,b in lyr if a < n and b < n]
        l   = set(x for s in lyr for x in s)
        if len(lanes & l) > 0: out.append(lyr);     lanes  = l
        else:                  out[-1].extend(lyr); lanes |= l
    # only the initial seed layer can be empty (if no comparator is left);
    # drop it so that an empty network is [] and not [[]]
    return [lyr for lyr in out if lyr]

def apply (X, net, up=True):
    '''Run a sorting or selection network on a data vector (in place).
    X     data vector (modified in place)
    net   sorting or selection network, a list of layers of lane pairs
    up    if True, a comparator (i,k) leaves X[i] <= X[k] (swap on X[i] > X[k]);
          if False, it leaves X[i] >= X[k] (swap on X[i] < X[k])'''
    for layer in net:
        for i, k in layer:
            out_of_order = X[i] > X[k] if up else X[i] < X[k]
            if out_of_order:
                X[i], X[k] = X[k], X[i]

### Functions Common to Odd-Even Merge and Pairwise Sorting Networks

In [3]:
def lyrjoin (lft, rgt):
    '''Join two lists of parallel network layers into one.
    The i-th joined layer holds the comparators of the i-th layer of both
    lists; if one list is deeper, its extra layers are kept unchanged
    (plain zip would silently drop them).
    The joined layers are only executable in parallel if the two lists
    use disjoint lanes (as the recursive constructions here ensure).
    lft     left  (first)  list of layers
    rgt     right (second) list of layers
    returns the joined list of layers'''
    from itertools import zip_longest
    # l+r always builds a new list, so the shared fillvalue is never mutated
    return [l+r for l,r in zip_longest(lft, rgt, fillvalue=[])]

def splitter (lanes):
    '''Build the single-layer splitter for a list of lanes/wires:
    the lane at position k is compared with the lane at position k + len//2.
    Intended for an even number of lanes; for an odd number the middle
    lane ends up in two comparators of the same layer.
    lanes  list of lane/wire identifiers
    return a list holding one layer of comparators (empty for < 2 lanes)'''
    cnt = len(lanes)
    if cnt >= 2:
        half = cnt // 2
        return [[(lanes[k], lanes[k+half]) for k in range(cnt - half)]]
    return []

def balmerger (lanes):
    '''Build the last layer of a pairwise merger: each lane at an odd
    position is compared with its upper neighbor, i.e. the pairs
    (lanes[1],lanes[2]), (lanes[3],lanes[4]), ... (the last lane excluded
    as a first element).
    lanes  list of lane/wire identifiers
    return a list holding one layer of comparators (empty for < 4 lanes)'''
    if len(lanes) < 4:
        return []
    pairs = list(zip(lanes[1:-1:2], lanes[2::2]))
    return [pairs]

def pwmerger (lanes):
    '''Build a pairwise merger: recursively merge the lanes at even and
    at odd positions in parallel, then finish with a balanced merger layer.
    lanes  list of lane/wire identifiers
    return the list of comparator layers of the merger'''
    if len(lanes) <= 1:
        return []
    merged = lyrjoin(pwmerger(lanes[::2]), pwmerger(lanes[1::2]))
    return merged + balmerger(lanes)

### Odd-Even Merge Sorting Network

[Ken Batcher 1968]

<b>Sorting Networks and Their Application</b><br>
Ken E. Batcher<br>
<i>Proc. AFIPS Spring Joint Computer Conference (Atlantic City, NJ)</i>, 307-314<br>
https://doi.org/10.1145/1468075.1468121

In [4]:
def oerecurse (lanes):
    '''Recursive part of constructing an odd-even merge sorting network:
    sort both halves in parallel, then apply a splitter and a pairwise merger.
    lanes   lanes/wires for which to construct the sorting network
            (halved recursively; the caller passes a power of 2 many lanes)
    returns odd-even merge sorting network for the given lanes'''
    size = len(lanes)
    if size <= 1:
        return []
    if size == 2:
        return [[tuple(lanes)]]
    mid = size // 2
    halves = lyrjoin(oerecurse(lanes[:mid]), oerecurse(lanes[mid:]))
    return halves + splitter(lanes) + pwmerger(lanes)

def oddeven (n):
    '''Build an odd-even merge sorting network for n wires/lanes.
    The network is constructed for the next power of 2 >= n and then
    reduced to n lanes (and compacted) with netreduce.
    n       number of wires/lanes of the network to construct
    returns the network as a list of layers, each of which is a list
            of wire/lane pairs that are to be connected by comparators'''
    size = 2**int(ceil(log2(n)))
    full = oerecurse(list(range(size)))
    return netreduce(full, n)

### Pairwise Sorting Network

[Ian Parberry 1992]

<b>The Pairwise Sorting Network</b><br>
Ian Parberry<br>
<i>Parallel Processing Letters</i> 2(2,3):205-211<br>
https://doi.org/10.1142/S0129626492000337

In [5]:
def pwrecurse (lanes):
    '''Recursive part of constructing a pairwise sorting network:
    a splitter first, then both halves sorted in parallel,
    then a pairwise merger.
    lanes   lanes/wires for which to construct the sorting network
    returns pairwise sorting network for the given lanes'''
    size = len(lanes)
    if size <= 1:
        return []
    if size == 2:
        return [[tuple(lanes)]]
    mid   = size // 2
    lower = pwrecurse(lanes[:mid])
    upper = pwrecurse(lanes[mid:])
    return splitter(lanes) + lyrjoin(lower, upper) + pwmerger(lanes)

def pairwise (n):
    '''Construct a pairwise sorting network.
    The network is constructed for the next power of 2 >= n and then
    reduced to n lanes (and compacted) with netreduce.
    n       number of wires/lanes of the network to construct
    returns a pairwise sorting network as a list of layers,
            each of which is a list of wire/lane pairs
            that are to be connected by comparators'''
    m = 2**int(ceil(log2(n)))   # next power of 2 >= n
    return netreduce(pwrecurse(list(range(m))), n)

### Bitonic Sorting Network

[Ken Batcher 1968]

<b>Sorting Networks and Their Application</b><br>
Ken E. Batcher<br>
<i>Proc. AFIPS Spring Joint Computer Conference (Atlantic City, NJ)</i>, 307-314<br>
https://doi.org/10.1145/1468075.1468121

In [6]:
def bitmerge (lanes, up):
    '''Build a bitonic merger: one layer comparing each lane of the lower
    half with the corresponding lane of the upper half, followed by
    bitonic mergers of both halves in parallel.
    lanes   the lanes for which to construct a bitonic merger
    up      whether sorting is upwards (lower to higher lanes);
            if not, each comparator is listed as (upper lane, lower lane)
    returns the list of comparator layers of the merger'''
    size = len(lanes)
    if size < 2:
        return []
    half  = size // 2
    first = [(lanes[j], lanes[j+half]) if up else (lanes[j+half], lanes[j])
             for j in range(half)]
    if half == 1:
        return [first]
    rest = lyrjoin(bitmerge(lanes[:half], up), bitmerge(lanes[half:], up))
    return [first] + rest

def bitrecurse (lanes, up):
    '''Recursive part of constructing a bitonic sorting network:
    the lower half is sorted in direction up and the upper half in the
    opposite direction (in parallel), then both are merged by bitmerge.
    lanes   lanes/wires for which to construct the sorting network
    up      whether sorting is upwards (lower to higher lanes)
    returns bitonic sorting network for the given lanes'''
    n = len(lanes)
    if n < 2: return []
    n //= 2
    lft = bitrecurse(lanes[:n],     up)
    rgt = bitrecurse(lanes[n:], not up)
    return lyrjoin(lft,rgt) +bitmerge(lanes, up)

def bitreduce (net, m, n):
    '''Shrink a bitonic sorting network from m to n wires/lanes by removing
    the m-n lowest lanes: comparators touching them are dropped and all
    other lane indices are shifted down by m-n (layers are kept, even if
    they become empty).
    net     sorting or selection network to reduce
    m       number of wires/lanes of the current network
    n       number of wires/lanes to reduce the network to
    returns the reduced sorting or selection network'''
    shift   = m - n
    reduced = []
    for layer in net:
        reduced.append([(a - shift, b - shift) for a, b in layer
                        if min(a, b) >= shift])
    return reduced

def bitonic (n):
    '''Build a bitonic sorting network (type 2) for n wires/lanes.
    A network sorting upwards is built for the next power of 2 >= n,
    then its lowest lanes are removed with bitreduce.
    n       number of wires/lanes of the network to construct
    returns the network as a list of layers, each of which is a list
            of wire/lane pairs that are to be connected by comparators'''
    size = 2**int(ceil(log2(n)))
    full = bitrecurse(list(range(size)), True)
    return bitreduce(full, size, n)

### Wrapper for Sorting Network Functions

In [7]:
def sortnet (n, mode='pairwise'):
    '''Wrapper function for creating a sorting network.
    n       number of wires/lanes of the network to construct
    mode    type of the sorting network to construct,
            one of 'pairwise' (alias 'pw'), 'oddeven' (alias 'oe')
            or 'bitonic'; any other value also yields a bitonic network
    returns a sorting network of the desired type'''
    if mode in ('pairwise', 'pw'): return pairwise(n)
    if mode in ('oddeven',  'oe'): return oddeven(n)
    return bitonic(n)           # 'bitonic' and any unrecognized mode

### Test Sorting Networks

In [8]:
# Check all three network types for 2..35 lanes: each layer must be
# executable in parallel, and each network must sort random permutations.
seed(0)                         # make the random test reproducible
for n in range(2,36):
    X = list(range(n))
    for mode in ['oddeven', 'pairwise', 'bitonic']:
        net = sortnet(n, mode)  # the network does not depend on the trial
        if not netcheck(net): print('error!', mode, n)
        for i in range(100):    # apply it to 100 random permutations
            shuffle(X)
            apply(X, net)
            if X != list(range(n)): print('panic!', mode, n)
print('done')

done


In [9]:
net = sortnet(8, "oddeven")
print('\n'.join(str(lyr) for lyr in net))   # one layer per line

[(0, 1), (2, 3), (4, 5), (6, 7)]
[(0, 2), (1, 3), (4, 6), (5, 7)]
[(1, 2), (5, 6)]
[(0, 4), (1, 5), (2, 6), (3, 7)]
[(2, 4), (3, 5)]
[(1, 2), (3, 4), (5, 6)]


In [10]:
net = sortnet(8, "pairwise")
print('\n'.join(str(lyr) for lyr in net))   # one layer per line

[(0, 4), (1, 5), (2, 6), (3, 7)]
[(0, 2), (1, 3), (4, 6), (5, 7)]
[(0, 1), (2, 3), (4, 5), (6, 7)]
[(1, 2), (5, 6)]
[(2, 4), (3, 5)]
[(1, 2), (3, 4), (5, 6)]


In [11]:
net = sortnet(8, "bitonic")
print('\n'.join(str(lyr) for lyr in net))   # one layer per line

[(0, 1), (3, 2), (5, 4), (6, 7)]
[(0, 2), (1, 3), (6, 4), (7, 5)]
[(0, 1), (2, 3), (5, 4), (7, 6)]
[(0, 4), (1, 5), (2, 6), (3, 7)]
[(0, 2), (1, 3), (4, 6), (5, 7)]
[(0, 1), (2, 3), (4, 5), (6, 7)]


## Majority Gate Networks

Logic gate networks that determine whether the input contains a majority of zeros or a majority of ones (for an even number of inputs, a tie counts as a majority of ones).

### Majority Gate Network Construction from Sorting Network

Comparators (conditional swap devices) for binary values can be implemented by an or-gate (for the maximum side) and an and-gate (for the minimum side). Using these logic gates, a binary majority network can be constructed from a sorting network: replace each comparator by its two gates and keep only those gates whose results can reach the median wire/lane.

In [12]:
def majority (n, mode='bitonic'):
    '''Construct a majority gate network based on a sorting network.
    Each comparator (a,b) becomes the gates (a,b,'&') (minimum to lane a)
    and (b,a,'|') (maximum to lane b); afterwards only the gates whose
    results can reach the median lane n//2 are kept.
    n       number of wires/lanes (inputs) of the network to construct
    mode    type of the sorting network to start from (see sortnet)
    returns a majority gate network of the desired type: a list of gate
            layers followed by the index of the output lane, where a gate
            (a,b,o) computes a = a o b with o in '&', '|' and all gates
            of a layer read the outputs of the previous layer;
            None if n <= 0'''
    if n <= 0: return None      # handle up to 2 lanes
    if n <= 1: return [0]       # directly
    if n <= 2: return [[(0,1,'|')], 0]
    m = n//2                    # median wire/lane index
    net = [[g for a,b in lyr for g in [(a,b,'&'),(b,a,'|')]]
           for lyr in sortnet(n, mode)] +[m]
    src = {m}                   # lanes whose values are still needed
    # traverse the gate layers backwards, net[-2] down to net[0]
    # (len(net)+1 so that the first gate layer is pruned as well)
    for k in range(2,len(net)+1):
        net[-k] = [g for g in net[-k] if g[0] in src]
        src |= set(x for a,b,o in net[-k] for x in [a,b])
    return net                  # return the created network

def mapply (X, net):
    '''Evaluate a majority gate network on a 0/1 (or False/True) vector.
    X     input vector (not modified; each layer works on a copy)
    net   majority gate network: gate layers followed by the output lane;
          gate (a,b,o) sets lane a to X[a] & X[b] (o == '&') or
          X[a] | X[b] (otherwise), reading the previous layer's values
    returns the value of the output lane after the last layer'''
    layers, out = net[:-1], net[-1]
    for gates in layers:
        nxt = X.copy()          # gates of one layer work in parallel
        for dst, src, op in gates:
            if op == '&':
                nxt[dst] = X[dst] & X[src]
            else:
                nxt[dst] = X[dst] | X[src]
        X = nxt
    return X[out]

### Majority Gate Network Reduction

In [13]:
def mreduce (net):
    '''Reduce a majority gate network, that is,
    try to apply all gates as early as possible,
    in the hope to reduce the number of layers.
    A gate is moved into the preceding (output) layer only if
    (1) neither of its two input lanes is written by that layer and
    (2) no gate that stays behind in its own layer reads its output lane.
    Condition (2) keeps the two gates (a,b,'&') and (b,a,'|') of a
    comparator together: moving only one of them would make the other
    one read the already updated value of its partner lane.
    net     majority gate network to reduce (gate layers followed by
            the index of the output lane); it is not modified
    returns the reduced majority gate network'''
    if len(net) <= 2:           # need at least two layers
        return net              # to reduce something
    out = [net[0].copy()]       # copy the first gate layer
    for lyr in net[1:-1]:       # traverse remaining layers
        dst = set(a for a,b,o in out[-1])    # lanes written by prev. layer
        mov = [a not in dst and b not in dst for a,b,o in lyr]
        changed = True          # demote candidates whose output lane is
        while changed:          # read by a gate that stays (to fixpoint)
            rd = set(x for g,m in zip(lyr,mov) if not m for x in g[:2])
            changed = False
            for i,g in enumerate(lyr):
                if mov[i] and g[0] in rd:
                    mov[i] = False; changed = True
        out[-1].extend(g for g,m in zip(lyr,mov) if m)
        rem = [g for g,m in zip(lyr,mov) if not m]
        if rem:                 # if gates left in current layer
            out.append(rem)     # add them as a new layer
    return out +net[-1:]        # return the reduced network

### Test Majority Gate Networks

In [14]:
net = majority(7, 'bitonic')
print(*net, sep='\n')                   # gate layers, then the output lane
print(sum(map(len, net[:-1])))          # total number of gates
print(mapply([1,1,0,1,0,1,0], net))

[(2, 1, '&'), (1, 2, '|'), (4, 3, '&'), (3, 4, '|'), (5, 6, '&'), (6, 5, '|')]
[(0, 2, '&'), (2, 0, '|'), (5, 3, '&'), (3, 5, '|'), (6, 4, '&'), (4, 6, '|')]
[(1, 2, '&'), (2, 1, '|'), (4, 3, '&'), (3, 4, '|'), (6, 5, '&'), (5, 6, '|')]
[(4, 0, '|'), (5, 1, '|'), (6, 2, '|')]
[(3, 5, '&'), (4, 6, '&')]
[(3, 4, '&')]
3
24
1


In [15]:
net = majority(7, 'oddeven')
print(*net, sep='\n')                   # gate layers, then the output lane
print(sum(map(len, net[:-1])))          # total number of gates
print(mapply([1,1,0,1,0,1,0], net))

[(0, 1, '&'), (1, 0, '|'), (2, 3, '&'), (3, 2, '|'), (4, 5, '&'), (5, 4, '|')]
[(0, 2, '&'), (2, 0, '|'), (1, 3, '&'), (3, 1, '|'), (4, 6, '&'), (6, 4, '|')]
[(1, 2, '&'), (2, 1, '|'), (5, 6, '&'), (6, 5, '|')]
[(4, 0, '|'), (5, 1, '|'), (2, 6, '&')]
[(4, 2, '|'), (3, 5, '&')]
[(3, 4, '&')]
3
22
1


In [16]:
net = majority(7, 'pairwise')
print(*net, sep='\n')                   # gate layers, then the output lane
print(sum(map(len, net[:-1])))          # total number of gates
print(mapply([1,1,0,1,0,1,0], net))

[(0, 4, '&'), (4, 0, '|'), (1, 5, '&'), (5, 1, '|'), (2, 6, '&'), (6, 2, '|')]
[(0, 2, '&'), (2, 0, '|'), (1, 3, '&'), (3, 1, '|'), (4, 6, '&'), (6, 4, '|')]
[(1, 0, '|'), (2, 3, '&'), (3, 2, '|'), (4, 5, '&'), (5, 4, '|')]
[(2, 1, '|'), (5, 6, '&')]
[(4, 2, '|'), (3, 5, '&')]
[(3, 4, '&')]
3
22
1


In [17]:
# Tabulate the size of the majority networks and test both the original
# and the reduced network exhaustively on all 0/1 inputs.
for mode in 'bitonic','oddeven','pairwise':
    if mode != 'bitonic': print()
    print(mode)
    print('#inputs #layers #reduce #gates')
    for n in range(1,21):
        net = majority(n, mode)
        cnt = sum(len(lyr) for lyr in net[:-1])
        red = mreduce(net)
        same = red == net       # skip the second check if nothing changed
        print('%4d   %5d   %5d    %5d' % (n, len(net)-1, len(red)-1, cnt))
        for i in range(2**n):   # all 0/1 input vectors of length n
            a    = [(i >> k) & 1 for k in range(n)]
            want = sum(a) >= (n+1)//2   # majority (a tie counts as 1)
            if mapply(a, net) != want: print('net', mode, n, a)
            if not same and mapply(a, red) != want: print('red', mode, n, a)

bitonic
#inputs #layers #reduce #gates
   1       0       0        0
   2       1       1        1
   3       3       3        4
   4       3       3        7
   5       6       6       16
   6       6       6       21
   7       6       6       24
   8       6       6       31
   9      10      10       56
  10      10      10       63
  11      10      10       68
  12      10      10       79
  13      10      10       82
  14      10      10       91
  15      10      10       98
  16      10      10      111
  17      15      15      176
  18      15      15      185
  19      15      15      192
  20      15      15      207

oddeven
#inputs #layers #reduce #gates
   1       0       0        0
   2       1       1        1
   3       3       3        4
   4       3       3        7
   5       5       5       12
   6       6       6       17
   7       6       6       22
   8       6       6       27
   9       9       9       40
  10      10      10       45
  11      10      10 

In [18]:
# number of gates of a 1024-input majority network per construction
for mode in ('bitonic', 'oddeven', 'pairwise'):
    net = majority(1024, mode)
    print(sum(len(lyr) for lyr in net[:-1]))

47103
39931
34767
