# Emulation of Java-style streams in Python.

class Stream:
    """Lazy, pull-based emulation of Java-style streams.

    A stream wraps ``generator``, a zero-argument callable that returns the
    next item on each call and ``None`` once the stream is exhausted.  Because
    ``None`` is the end marker, a stream cannot carry ``None`` as an item.

    Intermediate operations (limit, map, concat, flatten, filter, take_while,
    until) wrap ``self.generator`` in a new closure, store it back on
    ``self`` and return ``self``, so calls can be chained.  Terminal
    operations (foreach, for_all, there_exists, nth, find_first, last,
    reduce, fold) pull items by calling ``self.generator``.  Pulled items are
    consumed, so build a fresh stream for each pipeline.
    """

    def __init__(self, generator):
        """Wrap ``generator``: a zero-argument callable returning the next
        item, or ``None`` at end of stream."""
        self.generator = generator

    def __iter__(self):
        """Yield the remaining items, so a stream works with ``for``/``list``."""
        while True:
            item = self.generator()
            if item is None:
                return
            yield item

    ### stream qualification (intermediate) operations

    def limit(self, n):
        """Keep at most the next ``n`` items."""
        source = self.generator
        served = 0

        def limited():
            nonlocal served
            if served >= n:
                return None
            served += 1
            return source()

        self.generator = limited
        return self

    def map(self, mapfun):
        """Replace each item x with ``mapfun(x)``.

        If ``mapfun`` returns ``None``, consumers read it as end of stream.
        """
        source = self.generator

        def mapped():
            item = source()
            return None if item is None else mapfun(item)

        self.generator = mapped
        return self

    def concat(self, other_stream):
        """Append ``other_stream`` after this (finite) stream.

        Once this stream has returned ``None`` it is not pulled again; every
        later item comes from ``other_stream``.
        """
        # Keep the generator function itself.  Calling it here would consume
        # the first item and leave ``source`` bound to a non-callable value.
        source = self.generator
        first_done = False

        def merged():
            nonlocal first_done
            if not first_done:
                item = source()
                if item is not None:
                    return item
                first_done = True
            return other_stream.generator()

        self.generator = merged
        return self

    def flatten(self):
        """Concatenate the items of a stream whose items are themselves streams.

        The outer stream is not pulled until the first item is requested.
        ``.map(f).flatten()`` is the stream equivalent of "bind".
        """
        source = self.generator
        current = None
        started = False

        def flattened():
            nonlocal current, started
            if not started:
                started = True
                current = source()
            while current is not None:
                item = current.generator()
                if item is not None:
                    return item
                current = source()  # inner stream exhausted: move to the next one
            return None

        self.generator = flattened
        return self

    def filter(self, predicate):
        """Keep only the items that satisfy ``predicate``."""
        source = self.generator

        def filtered():
            item = source()
            while item is not None and not predicate(item):
                item = source()
            return item

        self.generator = filtered
        return self

    def take_while(self, predicate):
        """Pass items while ``predicate`` holds; end at the first item that fails it.

        The failing item is dropped (compare ``until``, which keeps its
        stopping item).  Once ended, the source is never pulled again, so
        later items that happen to satisfy ``predicate`` cannot resurface.
        """
        source = self.generator
        done = False

        def limited():
            nonlocal done
            if done:
                return None
            item = source()
            if item is None or not predicate(item):
                done = True
                return None
            return item

        self.generator = limited
        return self

    def until(self, predicate):
        """Pass items up to and including the first one satisfying ``predicate``."""
        source = self.generator
        stop = False

        def limited():
            nonlocal stop
            if stop:
                return None
            item = source()
            if item is not None and predicate(item):
                stop = True  # emit this item, then end the stream
            return item

        self.generator = limited
        return self

    #### terminal operations (they run the stream)

    def foreach(self, action):
        """Call ``action`` on every remaining item."""
        item = self.generator()
        while item is not None:
            action(item)
            item = self.generator()

    def for_all(self, predicate):
        """Return True if every item satisfies ``predicate`` (True when empty).

        Stops pulling at the first counterexample.
        """
        item = self.generator()
        while item is not None:
            if not predicate(item):
                return False
            item = self.generator()
        return True

    def there_exists(self, predicate):
        """Return True if some item satisfies ``predicate``; stops at the first one."""
        return not self.for_all(lambda x: not predicate(x))

    def nth(self, n):
        """Return the ``n``-th item, counting the first item as the 0th.

        Returns ``None`` if the stream has fewer than n + 1 items or if n is
        negative.
        """
        item = None  # result when n < 0 and the loop never runs
        remaining = n
        while remaining + 1 > 0:
            item = self.generator()
            if item is None:
                return None
            remaining -= 1
        return item

    def find_first(self, predicate):
        """Return the first item satisfying ``predicate``, or ``None`` if none does."""
        item = self.generator()
        while item is not None:
            if predicate(item):
                return item
            item = self.generator()
        return None

    def last(self):
        """Return the final item of the stream, or ``None`` if it is empty."""
        result = None
        item = self.generator()
        while item is not None:
            result = item
            item = self.generator()
        return result

    def reduce(self, id, operator):
        """Left-fold the stream with ``operator``, starting from ``id``.

        ``id`` should be a left identity of ``operator`` (e.g. 1 for
        multiplication); it is returned unchanged for an empty stream.  The
        parameter name shadows the builtin ``id`` but is kept for keyword callers.
        """
        acc = id
        item = self.generator()
        while item is not None:
            acc = operator(acc, item)
            item = self.generator()
        return acc

    def fold(self, operator):
        """Like ``reduce``, but use the first item as the initial accumulator.

        Returns ``None`` for an empty stream.
        """
        first = self.generator()
        if first is None:
            return None
        return self.reduce(first, operator)

    def yield_seeds(self):
        """Return the ``seeds`` list attached to this stream.

        In this file only ``seeded_coinduction`` sets ``seeds``; on other
        streams this raises ``AttributeError``.
        """
        return self.seeds
#Stream

# global functions:

def finite_stream(seeds):
    """Return a Stream over the elements of ``seeds``, in order.

    Each pull reads ``seeds`` by position and checks its current length, so
    the stream ends (returns ``None``) once every element has been served.
    """
    cursor = 0

    def pull():
        nonlocal cursor
        if cursor >= len(seeds):
            return None
        cursor += 1
        return seeds[cursor - 1]

    return Stream(pull)

def coinduction(base, inducer):
    """Return the Stream base, inducer(base), inducer(inducer(base)), ...

    ``inducer`` is applied lazily, one step per pull.  The stream is infinite
    unless a produced value is ``None``, which consumers treat as the end.
    """
    state = None
    started = False

    def pull():
        nonlocal state, started
        state = inducer(state) if started else base
        started = True
        return state

    return Stream(pull)

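# "strong"/"seeded" is not meant mathematically: it is just a distinguishing name.
# The inducer receives the whole list of items produced so far (the seeds), not
# only the previous item.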
def seeded_coinduction(seeds, inducer):
    """Return a Stream that replays ``seeds`` and then extends it with ``inducer``.

    Once the existing seeds have been served, each pull calls
    ``inducer(seeds)`` with the full list produced so far.  A non-None result
    is appended to ``seeds`` (the caller's list is mutated in place) and
    returned; a ``None`` result signals end of stream.  The same list is also
    attached to the returned stream as its ``seeds`` attribute.
    """
    cursor = 0

    def pull():
        nonlocal cursor
        if cursor < len(seeds):
            cursor += 1
            return seeds[cursor - 1]
        grown = inducer(seeds)
        if grown is not None:
            seeds.append(grown)
            cursor += 1
        return grown

    stream = Stream(pull)
    stream.seeds = seeds
    return stream

## inductive_definition: the stream of pairs (n, m) of an inductively defined
## binary relation, e.g. (n, n!); inductive_case receives both components:
def inductive_definition(base_case, inductive_case):
    """Return the stream of pairs base_case, step(base_case), ...

    Each step calls ``inductive_case(a, b)`` on the two components of the
    previous pair (the induction hypothesis).
    """
    def step(hypothesis):
        return inductive_case(hypothesis[0], hypothesis[1])

    return coinduction(base_case, step)
        

######################## testing

# streams are consumed as they are read, so build a fresh stream (e.g. by calling a function such as odds()) before each use:

def odds():
    """Return a fresh infinite stream of the odd numbers 1, 3, 5, ..."""
    def add_two(k):
        return k + 2

    return coinduction(1, add_two)

# first ten odd numbers on one line: 1 3 5 7 9 11 13 15 17 19
odds().limit(10).foreach(lambda x:print(x,end=" "))
print()  # end the line started above

# fold seeds the accumulator with the first item; reduce takes an explicit
# identity (1 for multiplication). Each call builds its own fresh stream.
largest = finite_stream([5,8,2,1,9,4,3]).fold(max)
product = finite_stream([5,8,2,1,9,4,3]).reduce(1,lambda a,b:a*b)

print("largest:", largest)  # prints 9
print("product:", product)  # prints 8640
    
def Fibonacci():
    """Return a fresh infinite stream of Fibonacci numbers 1, 1, 2, 3, 5, 8, ...

    Internally it iterates pairs (fib(k-1), fib(k)), starting from (0, 1),
    and projects the second component.
    """
    def advance(pair):
        older, newer = pair
        return newer, older + newer

    def second(pair):
        return pair[1]

    return coinduction((0, 1), advance).map(second)

# one per line: 1 1 2 3 5 8 13 21 34 55 89 144 (until keeps the first value > 100)
Fibonacci().until(lambda n:n>100).foreach(print)

# n! is revealing: must inductively define pair (n, n!)
def factorial(n):
    """Return n!, read as the n-th element of the stream of pairs (k, k!).

    The pair is needed because k! alone does not determine (k + 1)!.
    """
    def step(pair):
        k, k_fact = pair
        # if k_fact == k! then (k + 1) * k_fact == (k + 1)!
        return k + 1, (k + 1) * k_fact

    pairs = coinduction((0, 1), step)
    return pairs.map(lambda pair: pair[1]).nth(n)

print("6! is", factorial(6))  # print() inserts the separating space itself; prints "6! is 720"


# Primes

def nextprime(knownprimes):
    """Return the smallest prime greater than the last element of ``knownprimes``.

    Returns 2 for an empty list and 3 after 2.  Otherwise it tests the odd
    numbers after the last known prime by trial division, using the known
    primes below 1 + sqrt(candidate).  This assumes ``knownprimes`` lists
    all primes up to its last element in ascending order.
    """
    if not knownprimes:
        return 2
    largest = knownprimes[-1]
    if largest == 2:
        return 3

    def is_prime(candidate):
        bound = 1.0 + candidate ** 0.5
        return finite_stream(knownprimes) \
            .take_while(lambda p: p < bound) \
            .for_all(lambda p: candidate % p != 0)

    odd_candidates = coinduction(largest + 2, lambda k: k + 2)
    return odd_candidates.find_first(is_prime)

## the trailing \ continues an expression onto the next line (wrapping the expression in parentheses also works)

def Primes(known_primes):
    """Return a stream of primes that replays ``known_primes`` and then extends it via ``nextprime``.

    ``known_primes`` should be empty or hold the first k primes in ascending
    order.  The list is grown in place as new primes are pulled.
    """
    return seeded_coinduction(known_primes, nextprime)

print("first 100 primes:")
# pull the first 100 primes into primes100 (prime_factors below seeds its stream from it)
primes100 = []
Primes([]).limit(100)\
          .foreach(primes100.append)
print(primes100)

def prime_factors(n):
    """Return the distinct prime factors of ``n`` in ascending order, without multiplicity.

    The prime stream is seeded with a *copy* of the module-level
    ``primes100``.  The seeded stream appends every prime it generates to its
    seed list, and without the copy it would silently grow ``primes100`` past
    100 entries.
    """
    factors = []
    Primes(list(primes100)).take_while(lambda p: p <= n) \
                           .filter(lambda p: n % p == 0) \
                           .foreach(factors.append)
    return factors

# 3000 == 2**3 * 3 * 5**3, so this prints the distinct factors [2, 3, 5]
print("\nprime factors of 3000:", prime_factors(3000))


def binary_factors(n):
    """Return the powers of two whose sum is ``n`` (its set bits), in ascending order.

    For example, 1000 -> [8, 32, 64, 128, 256, 512].  It walks the pairs
    (2**i, n // 2**i) and keeps 2**i whenever the remaining quotient is odd.
    """
    def halve(state):
        weight, remaining = state
        return weight * 2, remaining // 2

    collected = []
    set_bits = coinduction((1, n), halve) \
        .until(lambda state: state[1] < 1) \
        .filter(lambda state: state[1] % 2 == 1) \
        .map(lambda state: state[0])
    set_bits.foreach(collected.append)
    return collected
        
# 1000 == 0b1111101000, so this prints [8, 32, 64, 128, 256, 512]
print("\nbinary factors of 1000:", binary_factors(1000))

# 2x2 matrix multiplication
def mmult(M, N):
    """Multiply two 2x2 matrices, each given as a flat row-major 4-tuple (a, b, c, d)."""
    a, b, c, d = M
    q, r, s, t = N
    return (a * q + b * s, a * r + b * t,
            c * q + d * s, c * r + d * t)
#mmult

#### Fibonacci matrix and identity matrix
# FIB = [[1, 1], [1, 0]], stored as a flat row-major 4-tuple (the layout mmult uses)
FIB = (1, 1,
       1, 0)
# IDM: the 2x2 identity matrix in the same flat layout
IDM = (1, 0,
       0, 1)

# FIB**n gives (fib(n+1), fib(n), fib(n), fib(n-1))

# calculating M**n by binary factorization on n:

# easy case: define 2**n inductively on n.
# Base case: (0, 1) "2**0 is 1"
# Inductive case (n, m) -> (n+1, 2*m):
#  if m = 2**n then 2*m = 2*2**n = 2**(n+1):

def powers_of_2():
    """Return the infinite stream of pairs (n, 2**n): (0, 1), (1, 2), (2, 4), ..."""
    def double(n, m):
        # if m == 2**n then 2*m == 2**(n+1)
        return n + 1, 2 * m

    return inductive_definition((0, 1), double)

# exponentiation m**n by binary factorization
import math
def nth_power(m, n, multiplier, identity):
    """Return m**n under ``multiplier``, using binary (square-and-multiply) exponentiation.

    m          -- the base; any value ``multiplier`` accepts
    n          -- non-negative integer exponent
    multiplier -- associative binary operation, e.g. ``lambda x, y: x * y`` or ``mmult``
    identity   -- identity element of ``multiplier``; returned when n == 0

    Raises ValueError if n is negative.  Floor division keeps a negative
    exponent negative (-1 // 2 == -1), so the ``until`` below would never
    fire and the stream would run forever.
    """
    if n < 0:
        raise ValueError("nth_power: exponent must be non-negative, got {!r}".format(n))

    def step(state):
        # state = (accumulated product, current square of m, exponent bits still to process)
        acc, square, k = state
        next_square = multiplier(square, square)
        if k % 2 == 1:  # low bit set: fold the current square into the result
            return (multiplier(acc, square), next_square, k // 2)
        return (acc, next_square, k // 2)

    return coinduction((identity, m, n), step) \
        .until(lambda state: state[2] == 0) \
        .map(lambda state: state[0]) \
        .last()
#nth_power

# 2**15 == 32768
print( nth_power(2,15,(lambda x,y:x*y),1) )

def fastfib(n):
    """Return the n-th Fibonacci number (fib(0) == 0, fib(1) == 1) using O(log n) 2x2 matrix products.

    FIB**n == (fib(n+1), fib(n), fib(n), fib(n-1)), so entry 1 of FIB**n is
    fib(n).  Raising FIB to the n-th power (rather than n - 1) keeps the
    exponent non-negative for n == 0.

    Raises ValueError for negative n.
    """
    if n < 0:
        raise ValueError("fastfib: n must be non-negative, got {!r}".format(n))
    return nth_power(FIB, n, mmult, IDM)[1]

# fib(8), fib(9), fib(10), fib(50): prints 21, 34, 55, 12586269025
print(fastfib(8))
print(fastfib(9))
print(fastfib(10))
print(fastfib(50))

