# Using your Head is Permitted

## February 2012 solution

Consider the following method for calculating $A^i$:

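```python
def exponentiate(A, i):
    """Calculate A**i."""
    if i == 0:
        return 1
    if i == 1:
        return A
    # Split i as floor(i/2) + ceil(i/2); // keeps the exponents integral.
    return exponentiate(A, i // 2) * exponentiate(A, i // 2 + i % 2)
```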

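This method needs only a recursion whose depth is logarithmic in *i*. Each
recursive call receives an exponent of at most $\lceil i/2 \rceil$, so the depth
is $\lceil \log_2 i \rceil$. (As written, the two halves are computed separately.
Computing the common factor once would also make the number of multiplications
logarithmic in *i*, but for memory only the depth matters.) We will show how to
implement each multiplication in polynomial space by use of recursion. Each level
of the computation will use polynomial memory, and the depth of the recursion will
be polynomial in $\log i$, so the total memory requirement is polynomial.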
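The depth claim can be checked on small cases with the sketch below. `exponentiate_traced` is a hypothetical instrumented copy of `exponentiate` that uses the same floor/ceiling split of the exponent and also returns the height of its recursion tree. For $i \ge 1$ this height should equal $\lceil \log_2 i \rceil$, which in Python is `(i - 1).bit_length()`.

```python
def exponentiate_traced(A, i):
    """Hypothetical instrumented copy of exponentiate.

    Returns (A**i, depth), where depth is the height of the recursion tree
    produced by the same floor/ceiling split of the exponent.
    """
    if i == 0:
        return 1, 0
    if i == 1:
        return A, 0
    low, d_low = exponentiate_traced(A, i // 2)
    high, d_high = exponentiate_traced(A, i // 2 + i % 2)
    return low * high, 1 + max(d_low, d_high)


for i in (1, 2, 5, 16, 100):
    value, depth = exponentiate_traced(3, i)
    # measured depth, ceil(log2(i)), and a correctness check
    print(i, depth, (i - 1).bit_length(), value == 3 ** i)
```

Each printed line shows the measured depth next to $\lceil \log_2 i \rceil$. The two agree because every call passes on an exponent of at most $\lceil i/2 \rceil$.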

Consider how multiplication is usually performed. The following code
multiplies *A* by *B*, writing the result into *C*.
*A*, *B* and *C* are stored as arrays of bits, where *A*[*x*]
is the bit of *A* whose value is $2^x$. (By convention, if
*x* is at least the bit-length of *A*, we take *A*[*x*]
to be 0.) As we are only interested in bits up to bit *n*,
this code computes *C* only up to bit *n*. In essence, the code performs
multiplication modulo $2^{n+1}$.

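```python
def mult_mod(A, B, C, n):
    """Write into C[0..n] the bits of (A*B) % 2**(n+1); C needs n+1 slots."""
    acc = 0
    for counter in range(n + 1):
        acc //= 2  # discard the bit written last round, keeping the carry
        k = counter
        for j in range(counter + 1):  # all pairs with j + k == counter
            # Bits past the end of A or B are 0 by convention.
            if j < len(A) and k < len(B):
                acc += A[j] * B[k]
            k -= 1
        C[counter] = acc % 2
```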

In this code, *acc* holds both the sum for the current bit and the carry into
the next one. Its value never exceeds $2(n+1)$, so it fits in $O(\log n)$ bits
and can be stored.
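To see why the carry stays small, argue by induction over the outer loop. Write $a_c$ for the value of *acc* at the end of iteration $c$ (our notation, with $a_{-1} = 0$). The inner loop adds $c + 1 \le n + 1$ products of single bits:

$$
a_c = \left\lfloor \frac{a_{c-1}}{2} \right\rfloor + \sum_{j=0}^{c} A[j]\,B[c-j]
\le \left\lfloor \frac{a_{c-1}}{2} \right\rfloor + (n+1).
$$

So $a_{c-1} \le 2(n+1)$ implies $a_c \le (n+1) + (n+1) = 2(n+1)$, and $C[c] = a_c \bmod 2$. Both *acc* and the carry $\lfloor a_c / 2 \rfloor \le n + 1$ therefore fit in $O(\log n)$ bits.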

For our application, *A*, *B* and *C* may be too large to store,
so we will not store them. In particular, only bit *n* of the product is needed.
Instead of executing `C[counter]=acc % 2` on every iteration, we only need its
value when *counter* equals *n*.

The actual code will therefore look more like this:

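```python
def mult_digit(A, B, n):
    """Return the n'th digit (the bit of weight 2**n) of A*B."""
    acc = 0
    for counter in range(n + 1):
        acc //= 2  # keep only the carry from the previous bit
        k = counter
        for j in range(counter + 1):  # all pairs with j + k == counter
            # Bits past the end of A or B are 0 by convention.
            if j < len(A) and k < len(B):
                acc += A[j] * B[k]
            k -= 1
    return acc % 2  # only bit n is produced; C is never stored
```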
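In `mult_digit`, *A* and *B* are read only through the lookups `A[j]` and `B[k]`. They therefore need not be stored at all: any procedure that produces a requested bit would do. The sketch below makes this explicit. `mult_digit_from` and `bits_of` are hypothetical names used only for this illustration. The only change from `mult_digit` is that the two arrays are replaced by functions that return single bits.

```python
from typing import Callable

BitFn = Callable[[int], int]  # maps a position x to the bit of weight 2**x


def mult_digit_from(bit_a: BitFn, bit_b: BitFn, n: int) -> int:
    """Hypothetical variant of mult_digit: bit n of A*B, with A and B given
    as functions that return individual bits on demand."""
    acc = 0
    for counter in range(n + 1):
        acc //= 2  # keep only the carry from the previous bit
        for j in range(counter + 1):
            acc += bit_a(j) * bit_b(counter - j)
    return acc % 2


def bits_of(value: int) -> BitFn:
    """Bit function for an ordinary non-negative integer (0 past its length)."""
    return lambda x: (value >> x) & 1


a, b = 1234, 5678
print([mult_digit_from(bits_of(a), bits_of(b), n) for n in range(8)])
print([((a * b) >> n) & 1 for n in range(8)])  # the same list
```

Because the factors are reached only through these functions, their digits can be computed on demand instead of being stored.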

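Combining the two ideas, we get our final code. It is the recursive
exponentiation above, with each multiplication computed one digit at a time as
in `mult_digit`. Every lookup *A*[*j*] or *B*[*k*] becomes a recursive call that
computes the needed digit of $A^{\lfloor i/2 \rfloor}$ or $A^{\lceil i/2 \rceil}$.
Here, *A* is treated as an array, but, for better readability,
*n* and *i* are treated as integers.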

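```python
def exp_digit(n, A, i):
    """Return the n'th digit of A**i (A is a least-significant-bit-first list)."""
    if i == 0:
        return int(n == 0)  # A**0 == 1
    if i == 1:
        return A[n] if n < len(A) else 0  # zero if n is past the bit-length of A
    acc = 0
    for counter in range(n + 1):
        acc //= 2  # keep only the carry from the previous bit
        k = counter
        for j in range(counter + 1):
            # Digits j and k of the two half-powers, recomputed on demand.
            acc += exp_digit(j, A, i // 2) * exp_digit(k, A, i // 2 + i % 2)
            k -= 1
    return acc % 2
```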
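The memory used by `exp_digit` can be tallied directly from the code. Let $|n|$ and $|i|$ denote the bit-lengths of *n* and *i*. Each active call holds its arguments, since recursive calls receive digit positions $j, k \le n$ and exponents of at most $\lceil i/2 \rceil$. It also holds its loop variables *counter*, *j*, *k* $\le n$ and *acc* $\le 2(n+1)$, so one call needs $O(|n| + |i|)$ bits. At most $\lceil \log_2 i \rceil + 1 \le |i| + 1$ calls are active at once, one per level of the recursion. Besides the input *A*, which is shared by all calls, the total is therefore

$$
S_{\text{total}} \le \bigl(|i| + 1\bigr)\cdot O\bigl(|n| + |i|\bigr)\ \text{bits},
$$

which is polynomial in the length of the input.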
