Efficiently finding primitive roots modulo n in Python?

Erba Aitbayev · Oct 22, 2016

I'm using the following code for finding primitive roots modulo n in Python:

Code:

def gcd(a,b):
    while b != 0:
        a, b = b, a % b
    return a

def primRoots(modulo):
    roots = []
    required_set = set(num for num in range(1, modulo) if gcd(num, modulo) == 1)

    for g in range(1, modulo):
        actual_set = set(pow(g, powers) % modulo for powers in range(1, modulo))
        if required_set == actual_set:
            roots.append(g)
    return roots

if __name__ == "__main__":
    p = 17
    primitive_roots = primRoots(p)
    print(primitive_roots)

Output:

[3, 5, 6, 7, 10, 11, 12, 14]   

Code fragment extracted from: Diffie-Hellman (GitHub)
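To make the set comparison in primRoots concrete, here is a small illustration. The helper name distinct_powers is hypothetical, not part of the original code. A primitive root mod 17 produces all 16 nonzero residues among its powers, while 2 only cycles through 8 of them, which is why 2 is missing from the output.

```python
def distinct_powers(g, n):
    """Return {g^k mod n : 1 <= k < n}, the same set primRoots builds for g."""
    return {pow(g, k, n) for k in range(1, n)}

p = 17
print(len(distinct_powers(3, p)))  # 16: 3 generates every nonzero residue mod 17
print(len(distinct_powers(2, p)))  # 8: 2 has order 8, so it is not a primitive root
```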


Can the primRoots method be simplified or optimized in terms of memory usage and performance/efficiency?

Answer

kasravnd · Oct 22, 2016

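One quick change you can make here (still far from optimal) is to use list and set comprehensions, and to pass the modulus as the third argument of pow so that each power is reduced modulo n as it is computed instead of building a huge integer first: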

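from math import gcd  # Python 3.5+; otherwise use the gcd defined in the question

def primRoots(modulo):
    # All units mod `modulo`: the set a primitive root must generate
    coprime_set = {num for num in range(1, modulo) if gcd(num, modulo) == 1}
    return [g for g in range(1, modulo) if coprime_set == {pow(g, powers, modulo)
            for powers in range(1, modulo)}]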
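A lower-memory variant along the same lines is sketched below; it is not from the original answer, and the helper names multiplicative_order and prim_roots_by_order are hypothetical. It relies on the fact that g is a primitive root exactly when its multiplicative order equals φ(n), the number of residues coprime to n. Instead of building a set of n - 1 powers for every candidate, it steps through powers of g until it returns to 1, so each candidate uses constant extra memory. The worst-case running time is still quadratic in n.

```python
from math import gcd

def multiplicative_order(g, n):
    """Smallest k >= 1 with g^k == 1 (mod n); assumes gcd(g, n) == 1 and n > 1."""
    k, x = 1, g % n
    while x != 1:
        x = (x * g) % n  # next power of g, reduced mod n
        k += 1
    return k

def prim_roots_by_order(n):
    # phi(n): count of residues in 1..n-1 that are coprime to n
    phi = sum(1 for a in range(1, n) if gcd(a, n) == 1)
    return [g for g in range(1, n)
            if gcd(g, n) == 1 and multiplicative_order(g, n) == phi]

print(prim_roots_by_order(17))  # [3, 5, 6, 7, 10, 11, 12, 14]
```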

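Another change you can make is to memoize your gcd function. Note that this only pays off when the same (num, modulo) pairs are computed more than once, for example when primRoots is called repeatedly with the same modulus; within a single call each pair is computed only once. Better still, you can simply use the built-in gcd function from the math module in Python 3.5+, or from the fractions module in older versions: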

from functools import wraps

def cache_gcd(f):
    cache = {}  # maps (a, b) -> f(a, b)

    @wraps(f)
    def wrapped(a, b):
        key = (a, b)
        try:
            result = cache[key]
        except KeyError:
            # First time this pair is seen: compute it and store the result
            result = cache[key] = f(a, b)
        return result
    return wrapped

@cache_gcd
def gcd(a, b):
    # Euclid's algorithm
    while b != 0:
        a, b = b, a % b
    return a
# or just do the following (recommended)
# from math import gcd
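If you do want memoization, the standard library already provides it: functools.lru_cache (available since Python 3.2) does the same job as the hand-written cache_gcd decorator above. A minimal sketch; the name cached_gcd is hypothetical:

```python
from functools import lru_cache

@lru_cache(maxsize=None)  # unbounded cache keyed on the (a, b) arguments
def cached_gcd(a, b):
    # Euclid's algorithm; each distinct (a, b) pair is computed only once
    while b != 0:
        a, b = b, a % b
    return a

print(cached_gcd(12, 18))  # 6
print(cached_gcd.cache_info())  # hit/miss statistics for the cache
```

With maxsize=None the cache grows without bound, so for very large moduli a bounded maxsize, or simply math.gcd, is the safer choice.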

Whichever gcd you end up using, primRoots itself stays exactly as defined above.

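As mentioned in the comments, the more Pythonic (and faster) option is math.gcd in Python 3.5+. fractions.gcd served the same purpose in older versions, but it was deprecated in Python 3.5 and removed in Python 3.9.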
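For a substantially faster approach than either version above, a standard number-theory criterion can be used instead (this is a swapped-in technique, not what the answer proposes): an integer g coprime to n is a primitive root mod n exactly when g^(φ(n)/q) is not congruent to 1 (mod n) for every prime q dividing φ(n). Each candidate then needs only one modular exponentiation per distinct prime factor of φ(n), instead of n - 1 of them. The sketch below uses trial-division factoring, which is fine for moderate n; the helper names prime_factors, euler_phi and prim_roots_fast are hypothetical.

```python
from math import gcd

def prime_factors(m):
    """Distinct prime factors of m, found by trial division."""
    factors = set()
    d = 2
    while d * d <= m:
        while m % d == 0:
            factors.add(d)
            m //= d
        d += 1
    if m > 1:
        factors.add(m)  # whatever remains is itself prime
    return factors

def euler_phi(n):
    """Euler's totient via phi(n) = n * prod(1 - 1/q) over primes q | n."""
    result = n
    for q in prime_factors(n):
        result -= result // q
    return result

def prim_roots_fast(n):
    phi = euler_phi(n)
    qs = prime_factors(phi)
    # g is a primitive root iff g is a unit and no g^(phi/q) collapses to 1
    return [g for g in range(1, n)
            if gcd(g, n) == 1
            and all(pow(g, phi // q, n) != 1 for q in qs)]

print(prim_roots_fast(17))  # [3, 5, 6, 7, 10, 11, 12, 14]
```

Once one primitive root g is known, the rest can also be generated as g^k mod n for k coprime to φ(n), which avoids testing every candidate.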