# RSA Preparation routines

import bisect
import json

# Small list of Known Primes to start with: every prime up to 571, in
# ascending order.  The order matters: binsearch() needs a sorted list,
# isprime() assumes index 0 holds 2 and starts trial division at index 1,
# and extendKP() appends ever-larger primes to the end of this same list.
KnownPrimes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571]

def binsearch(x, A):  # binary search for x in SORTED array A
    """Return True if x occurs in A, False otherwise.

    A must be sorted in ascending order; the lookup needs O(log len(A))
    comparisons.  Uses the standard bisect module rather than a hand-written
    loop (which also avoids shadowing the built-ins min/max).
    """
    pos = bisect.bisect_left(A, x)  # leftmost slot where x would be inserted
    # x is present exactly when that slot exists and already holds x
    return pos < len(A) and A[pos] == x
#binsearch

def isprime(n):  # is n a prime (true or false)
    """Decide whether the integer n is prime.

    Values up to the largest entry of KnownPrimes are answered by looking
    them up in that sorted list.  Larger values are trial-divided by the odd
    known primes and, once those run out, by every following odd number,
    stopping once the divisor exceeds int(sqrt(n) + 1).
    """
    # 2 is the only even prime; any other even value, or anything below 2,
    # is not prime
    if n == 2:
        return True
    if n < 2 or n % 2 == 0:
        return False
    last = len(KnownPrimes) - 1
    biggest = KnownPrimes[last]
    if n <= biggest:
        # inside the table's range: membership decides primality
        return n == biggest or binsearch(n, KnownPrimes)
    bound = int(n**0.5 + 1)  # a composite n has a factor no larger than this
    pos = 1  # index 0 holds 2, which was ruled out above
    divisor = KnownPrimes[pos]
    while divisor <= bound:
        if n % divisor == 0:
            return False
        if pos < last:
            pos += 1
            divisor = KnownPrimes[pos]
        else:
            divisor += 2  # past the table: try each following odd number
    return True
#isprime

# Extending KnownPrimes: nextprime() finds the prime after the last entry, extendKP(n) appends n such primes
def nextprime():
    """Return the smallest prime larger than the last entry of KnownPrimes.

    Only odd numbers are examined, starting at KnownPrimes[-1] + 2.  The
    list itself is not modified; extendKP() does the appending.
    """
    candidate = KnownPrimes[-1] + 2
    while True:
        if isprime(candidate):
            return candidate
        candidate += 2  # skip even numbers
#nextprime

def extendKP(n):
    """Append the next n primes to KnownPrimes, in place.

    Nothing happens when n <= 0.  Each new entry comes from nextprime(),
    which searches past the current last entry, so the list stays sorted.
    """
    added = 0
    while added < n:
        KnownPrimes.append(nextprime())
        added += 1
#extendKP


# writing primes to file, using json
def writeprimes(filename):
    """Write the current KnownPrimes list to *filename* as a JSON array.

    The file is (re)created in text mode.  The ``with`` block guarantees
    the file is closed even if json.dump raises part-way through.
    """
    with open(filename, "w", encoding="utf-8") as fd:
        json.dump(KnownPrimes, fd)
#writeprimes

def loadprimes(filename):
    """Read *filename* and return the JSON value stored in it.

    Counterpart of writeprimes(): for a file it produced, the result is the
    saved list of primes.  KnownPrimes is not touched; the caller decides
    whether to rebind it.  The ``with`` block closes the file even if
    decoding fails.
    """
    with open(filename, "r", encoding="utf-8") as fd:
        return json.load(fd)
#loadprimes


#extendKP(200000)
#print(KnownPrimes)
#writeprimes("knownprimes1.txt")

#KnownPrimes = loadprimes("knownprimes1.txt")
#print(len(KnownPrimes))
#print(KnownPrimes)
#exit()

### RSA - simplified example
# need primes p,q,e,d such that (e*d) % ((p-1)*(q-1)) == 1

### finding larger e,d, the larger and more randomly chosen, the better

# Better prime factorization of n, returns array of primes
def primefactors(n, primes=None):  # return the prime factors of n, recursive algorithm:
    """Return the prime factors of n in ascending order, with multiplicity.

    Recursive trial division: find the smallest prime p in *primes* that
    divides n, then factor n // p.  If no prime below int(sqrt(n)) + 1
    divides n, n itself is prime.  Recursion depth is at most log2(n).

    Args:
        n: integer to factor; values below 2 yield [].
        primes: all primes from 2 upward, in ascending order, used for
            trial division.  Defaults to the module-level KnownPrimes
            (looked up at call time, so a rebound KnownPrimes is honoured).

    Raises:
        Exception("Known primes exhausted"): every prime in *primes* is
            below int(sqrt(n)) + 1 and none divides n, so the list is too
            short to finish the factorization.
    """
    if n < 2:
        return []
    if primes is None:
        primes = KnownPrimes
    limit = int(n**0.5) + 1
    for p in primes:
        if p >= limit:
            return [n]  # no prime factor <= sqrt(n): n is itself prime
        if n % p == 0:
            return [p] + primefactors(n // p, primes)
    # ran off the end of the list without reaching sqrt(n)
    raise Exception("Known primes exhausted")
#primefactors

### non-recursive version of prime factorization
def primefactorization(n, primes=None):
    """Return the prime factors of n in ascending order, with multiplicity.

    Iterative trial division.  After a factor p is divided out, the search
    resumes at p rather than restarting at 2, because no smaller prime can
    divide what remains.  Once the candidate prime reaches int(sqrt(n)) + 1
    without dividing n, the remaining n is prime.

    Args:
        n: integer to factor; values below 2 yield [].
        primes: all primes from 2 upward, in ascending order, used for
            trial division.  Defaults to the module-level KnownPrimes
            (looked up at call time).

    Raises:
        Exception("Known primes exhausted"): the list ran out before any
            prime reached int(sqrt(n)) + 1, so primality of the remaining
            cofactor cannot be decided.
    """
    if primes is None:
        primes = KnownPrimes
    PF = []
    i = 0  # index into primes; never moves backwards
    while n > 1:  # for every prime factor p of n found, set n = n // p
        limit = int(n**0.5) + 1  # square root of n + 1
        while i < len(primes) and n % primes[i] != 0 and primes[i] < limit:
            i += 1
        if i >= len(primes):
            raise Exception("Known primes exhausted")
        p = primes[i]
        # divisibility is checked first: when p == n (n prime) we divide
        # it out instead of taking the "n is prime" branch; same result
        if n % p == 0:
            PF.append(p)
            n //= p  # factorize the rest of n; p may divide it again
        else:  # p >= limit: no prime factor <= sqrt(n), so n is prime
            PF.append(n)
            n = 1  # == n // n, stops the outer loop
    return PF
#primefactorization, non-recursive
### This algorithm can still run for a long time if n is large enough


#print(primefactors(3*7*13*19*23*577*19))
#print(primefactors(4021)) # itself prime

# encrypt m: c=m**e % n, decrypt: m=c**d % n

def fastexp(x, p, n):  # compute x**p % n in log(p) steps, avoids large ints
    """Return x**p % n without ever building the full power x**p.

    For p > 0 this delegates to the built-in three-argument pow(), which runs
    square-and-multiply in O(log p) steps and keeps every intermediate value
    reduced modulo n.

    For p <= 0 no multiplication is performed and the result is 1 % n.
    pow() is deliberately not used there, because for negative p it would
    compute a modular inverse (or raise) instead.

    Raises:
        ZeroDivisionError: if n == 0.
    """
    base = x % n  # reduce first; also raises ZeroDivisionError for n == 0
    if p <= 0:
        return 1 % n
    return pow(base, p, n)
#fastexp

# encrypt1(m, e, n) computes m**e % n, i.e. encryption with the key (e, n);
# decryption has the same form, c**d % n, with the exponent d instead.
encrypt1 = fastexp  # alias of fastexp; (exponent p, modulus n) is the key

# Disabled round-trip check: the string literal below is never executed.
# NOTE(review): it relies on e, d, n and a decrypt1 function, none of which
# are defined in the code shown here -- define them before re-enabling it.
"""
i = 0
while i<256:
  c = encrypt1(i,e,n)
  c2 = decrypt1(c,d,n)
  print(i,":",c,":",c2)  # should print i,c,i
  if (i!=c2): print("OOOOPS")
  i += 1
#while
exit()   # passed test
"""
