import math
from math import sin, cos
import numpy as np
from numpy import power as pw
import util
import pylab

class TemporalPlot:
    """Scrolling line plot showing the most recent `resolution` values.

    The plot is created in pylab's interactive mode and redrawn every time
    a new value is pushed with add_point().
    """

    def __init__(self, resolution):
        """Create the figure with `resolution` points, all starting at zero."""
        pylab.ion()
        # Vertical axis limits; the upper one is recomputed on every add_point().
        self.min   = 0.0
        self.max   = 10.0
        self.x     = range(resolution)
        self.y     = np.zeros((1, resolution)).tolist()[0]
        self.line, = pylab.plot(self.x, self.y)
        pylab.ylim(self.min, self.max)

    def add_point(self, point):
        """Drop the oldest value, append `point` and redraw the curve."""
        window = self.y[1:]
        window.append(point)
        self.y = window
        self.line.set_ydata(window)
        # NOTE(review): the upper limit tracks the window *mean*, so values
        # above the mean fall outside the visible range -- confirm that this
        # is intended rather than the window maximum.
        self.max = np.mean(window)
        pylab.ylim(self.min, self.max)
        pylab.draw()
    
# Gravitational acceleration used by the gravity term of f()
# (presumably in m/s^2 -- confirm units against the link parameters).
GRAVITY   = 9.8
# Number of rows in every sample column built by get_samples(); f() reads
# rows 0-2 as joint angles, 3-5 as joint velocities, 6-8 as joint accelerations.
INPUT_VAR = 9
# Reference parameter vector: the script generates its synthetic samples from
# it and reports the distance of every estimate to it.  Order as unpacked by f():
#            r1       m1   w1   r2        m2   w2   h2   l3   m3   w3   h3
X_STAR    = [.5,      0.7, 0.8, 0.5,      0.7, 0.8, 0.2, 1.0, 0.7, 0.8, 0.2]

# Same values as a 1-D numpy array, for vector arithmetic.
X_STAR_A  = np.array(X_STAR)
# Shared 100-point live plot fed by scipy_objective().
# NOTE: it is instantiated at import time, so importing this module calls pylab.
plotter   = TemporalPlot(100)

##
# Three link planar robot model taken from page 172 of Murray, Li, Sastry.
#
# Evaluates the manipulator equation
#     tau = M(theta)*ddtheta + C(theta, dtheta)*dtheta + N(theta)
# for one joint state and one set of link parameters.
#
# @param constants 9x1 np.matrix holding one joint state: rows 0-2 are the
#                  joint angles, rows 3-5 the joint velocities and rows 6-8
#                  the joint accelerations (the layout built by get_samples).
# @param x         link parameters in the order
#                  [r1, m1, w1, r2, m2, w2, h2, l3, m3, w3, h3]; either an
#                  11x1 np.matrix or a flat sequence/array, which is turned
#                  into a column matrix.
# @return          3x1 np.matrix of joint torques.
def f(constants, x):
    if not isinstance(x, np.matrix):
        x = np.matrix(x).T

    # theta1 appears in none of the terms below, so it is not read.
    theta2   = constants[1,0]
    theta3   = constants[2,0]

    dtheta1  = constants[3,0]
    dtheta2  = constants[4,0]
    dtheta3  = constants[5,0]
    dtheta   = constants[3:6,0]

    ddtheta  = constants[6:9,0]

    #Parameters for link 1 (length is not a free parameter: l1 = 2*r1)
    r1 = x[0,0]
    l1 = r1*2.0
    m1 = x[1,0]
    w1 = x[2,0]

    #Parameters for link 2 (length is not a free parameter: l2 = 2*r2)
    r2 = x[3,0]
    l2 = r2*2.0
    m2 = x[4,0]
    w2 = x[5,0]
    h2 = x[6,0]

    #Parameters for link 3
    l3 = x[7,0]
    m3 = x[8,0]
    w3 = x[9,0]
    h3 = x[10,0]

    #Moments of inertia of the links (12.0 keeps the division in floating point)
    Ix2 = (m2/12.0) * (pw(w2,2) + pw(h2,2))
    Ix3 = (m3/12.0) * (pw(w3,2) + pw(h3,2))

    Iy2 = (m2/12.0) * (pw(l2,2) + pw(h2,2))
    Iy3 = (m3/12.0) * (pw(l3,2) + pw(h3,2))

    Iz1 = (m1/12.0) * (pw(l1,2) + pw(w1,2))
    Iz2 = (m2/12.0) * (pw(l2,2) + pw(w2,2))
    Iz3 = (m3/12.0) * (pw(l3,2) + pw(w3,2))

    s2  = sin(theta2)
    c2  = cos(theta2)
    s3  = sin(theta3)
    c3  = cos(theta3)
    s23 = sin(theta2+theta3)
    c23 = cos(theta2+theta3)

    #Inertia matrix M(theta); it is symmetric, so M32 reuses M23
    M11 = Iy2*pw(s2,2) + Iy3*pw(s23, 2) + Iz1 + Iz2*pw(c2,2) + Iz3*pw(c23,2) + m2*pw(r1,2)*pw(c2,2) + m3*pw(l1*c2 + r2*c23, 2)
    M22 = Ix2 + Ix3 + m3*pw(l1, 2) + m2*pw(r1, 2) + m3*pw(r2, 2) + 2*m3*l1*r2*c3
    M23 = Ix3 + m3*pw(r2, 2) + m3*l1*r2*c3
    M33 = Ix3 + m3*pw(r2, 2)

    M = np.matrix([[M11, 0.0, 0.0],
                   [0.0, M22, M23],
                   [0.0, M23, M33]])

    #Non-zero Christoffel symbols Gijk of the inertia matrix
    G112 = (Iy2-Iz2-m2*pw(r1,2))*c2*s2 + (Iy3-Iz3)*c23*s23 - m3*(l1*c2+r2*c23)*(l1*s2+r2*s23)
    G113 = (Iy3-Iz3)*c23*s23 - m3*r2*s23*(l1*c2+r2*c23)
    G121 = G112
    G131 = G113

    G211 = (Iz2-Iy2+m2*pw(r1,2))*c2*s2 + (Iz3-Iy3)*c23*s23 + m3*(l1*c2 + r2*c23)*(l1*s2 + r2*s23)
    G223 = -l1*m3*r2*s3
    G232 = G223
    G233 = G223

    G311 = (Iz3 - Iy3)*c23*s23 + m3*r2*s23*(l1*c2 + r2*c23)
    G322 = -G223

    #Coriolis matrix: C[i][j] = sum_k Gijk * dtheta_k.  Every symbol has to be
    #weighted by the joint velocity named by its last index, and entry (1,1)
    #combines G112 with G113 (not G112 twice).
    C = np.matrix([[G112*dtheta2 + G113*dtheta3,   G121*dtheta1,   G131*dtheta1],
                   [G211*dtheta1,                  G223*dtheta3,   G232*dtheta2 + G233*dtheta3],
                   [G311*dtheta1,                  G322*dtheta2,   0.0]])

    #Gravity vector: every term is a weight (mass * GRAVITY) times a lever arm
    N = np.matrix([0.0,
                   -(m2*GRAVITY*r1 + m3*GRAVITY*l1)*c2 - m3*GRAVITY*r2*c23,
                   -m3*GRAVITY*r2*c23]).T

    return M*ddtheta + C*dtheta + N

##
# Draws random joint states and pairs each with the torque the model predicts.
#
# @param x            link parameter vector handed to f() unchanged
# @param num_samples  number of (state, torque) pairs to generate
# @return list of (INPUT_VAR x 1 state matrix, f(state, x)) tuples; every
#         state entry is drawn uniformly from [0, 2*pi)
def get_samples(x, num_samples):
    states = np.matrix(np.random.random_sample((INPUT_VAR, num_samples))) * (2 * np.pi)
    pairs = []
    for col in range(num_samples):
        state = states[:,col]
        pairs.append((state, f(state, x)))
    return pairs

##
# calculates least squares of || f(x) - t ||
#
# @param f     model function called as f(angles, x); must return a column
#              np.matrix comparable with the recorded torques
# @param data  iterable of (angles, recorded_torque) pairs
# @param x     current parameter estimate (needs a .T attribute for printing,
#              e.g. an 11x1 np.matrix)
# @return one-element np.array holding the summed squared torque error
def torque_objective(f, data, x):
    v = 0.0
    for angles, recorded_torque in data:
        diff = recorded_torque - f(angles, x)
        v   += (diff.T * diff)[0,0]

    # Flatten both operands before subtracting: subtracting a column matrix
    # from the flat X_STAR list broadcasts to an NxN matrix whose norm is not
    # the distance between the two parameter vectors.
    dist = np.linalg.norm(np.asarray(X_STAR, dtype=float) - np.asarray(x, dtype=float).ravel())

    # Single-argument print calls give the same output on Python 2 and 3.
    print(x.T)
    print('       >>>> objective %s dist %s' % (v, dist))
    return np.array([v])

def scipy_objective(f, data, x):
    """Return the squared torque error of every sample for parameters `x`.

    f    -- model function called as f(angles, x)
    data -- iterable of (angles, recorded_torque) pairs
    x    -- current parameter estimate (flat sequence or array)

    Side effects: pushes the summed error to the module-level `plotter` and
    prints it together with the distance of `x` to X_STAR_A.

    NOTE(review): scipy.optimize.leastsq squares and sums whatever this
    returns, so handing it squared errors makes the fit minimise a sum of
    fourth powers -- confirm that this is intended.
    """
    def squared_error(angles, recorded_torque):
        diff = recorded_torque - f(angles, x)
        return (diff.T * diff)[0,0]

    r = np.array([squared_error(angles, torque) for angles, torque in data])
    objective = np.sum(r)
    plotter.add_point(objective)
    # Single-argument print call gives the same output on Python 2 and 3.
    print('objective %s dist %s' % (objective, np.linalg.norm(np.array(x) - X_STAR_A)))
    return r

    #print x.T
    #print '       >>>> objective', v, 'dist', np.linalg.norm(X_STAR - x)
    #return ut.list_mat_to_mat(diffs, 1)
    #return np.array([v])

# Command-line driver.  -c runs the nqcg conjugate-gradient solver; otherwise
# any combination of -s / -n / -r runs the corresponding scipy leastsq study.
if __name__ == '__main__':
    import functools as ft
    import time
    import scipy.optimize
    import optparse

    p = optparse.OptionParser()
    p.add_option('-c', action='store_true', dest='conjug', help='whether to use nqcg optimizer')
    p.add_option('-s', action='store_true', dest='sample', help='tests sample complexity')
    p.add_option('-n', action='store_true', dest='noise', help='tests noise resitance')
    p.add_option('-r', action='store_true', dest='start', help='tests reliability of convergence wrt to starting estimate')
    opt, args   = p.parse_args()
    use_conjug  = opt.conjug
    test_sample = opt.sample
    test_noise  = opt.noise
    test_start  = opt.start

    # Rules used to frame the reports.
    heavy_rule = '===================================================================='
    light_rule = '--------------------------------------------------------------------'
    short_rule = '--------------------------------'

    # Initial estimate, same parameter order as X_STAR:
    #            r1       m1   w1   r2        m2   w2   h2   l3   m3   w3   h3
    x0        = [.6,      0.9, 0.6, 0.4,      0.9, 0.6, 0.3, 1.2, 0.5, 0.6, 0.4]

    def fit(samples, start_x):
        # Least-squares parameter estimate for `samples`, started at `start_x`.
        return scipy.optimize.leastsq(ft.partial(scipy_objective, f, samples), start_x)[0]

    # All print calls below take one pre-formatted string so that the output
    # is the same on Python 2 and Python 3.
    if use_conjug:
        # Imported only here so the scipy experiments do not require nqcg.
        import nqcg as cg

        # Instantiate the solver with the objective function and initial value
        samples        = get_samples(X_STAR, 1000)
        objective_func = ft.partial(torque_objective, f, samples)
        my_solver      = cg.nonquadratic_conjugate_gradient_solver(np.matrix(x0).T, objective_function = objective_func )

        # Run Solver
        start  = time.time()
        x,fval = my_solver.run()
        print('x* =  %s' % (x,))
        print('f(x*) =  %s' % (fval,))
        print('time =  %s' % (time.time()-start,))
    else:
        if test_sample:
            print(heavy_rule)
            print('Testing accuracy wrt to sample SIZE')
            print(heavy_rule)
            num_experiments = 5

            answers = []
            for i in range(num_experiments):
                sample_size = i*500 + 500
                samples = get_samples(X_STAR, sample_size)
                start   = time.time()
                x_hat   = fit(samples, x0)
                answers.append((sample_size, x_hat, np.linalg.norm(x_hat - X_STAR_A), time.time() - start))

            print(light_rule)
            print(' result of accuracy wrt to sample SIZE')
            print('sample size, answer, norm, time')
            for sample_size, answer, norm, secs in answers:
                print('sample size %s answer %s' % (sample_size, answer))
                print('norm %s secs %s' % (norm, secs))
                print('       >> answer')
            print(light_rule)

        if test_noise:
            print(heavy_rule)
            print('Testing accuracy wrt to sample NOISE')
            print(heavy_rule)
            num_experiments = 5
            answers         = []
            for i in range(num_experiments):
                clean = get_samples(X_STAR, 1500)
                std   = .1 + i*.1
                start = time.time()
                # Build a new sample list: rebinding the loop variable (as the
                # previous code did) leaves the stored torques noise-free.
                # The unit-variance draw is scaled by std, not std*std, so the
                # noise really has the standard deviation that is reported.
                samples = [(angles, torque + std * np.random.standard_normal(torque.shape))
                           for angles, torque in clean]
                x_hat   = fit(samples, x0)
                answers.append((std, x_hat, np.linalg.norm(x_hat - X_STAR_A), time.time() - start))

            print(light_rule)
            print(' result of accuracy wrt to sample NOISE')
            print('std, answer, norm, time')
            # `secs`, not `time`: the loop variable must not shadow the time
            # module, which the next experiment still needs.
            for std, answer, norm, secs in answers:
                print(short_rule)
                print('std %s' % (std,))
                print('x_hat dist %s' % (norm,))
                print('x_hat %s' % (answer,))
                print('time %s' % (secs,))
            print(light_rule)

        if test_start:
            print(heavy_rule)
            print('Testing accuracy wrt to different STARTING POINTS')
            print(heavy_rule)
            samples  = get_samples(X_STAR, 1500)
            num_experiments = 5
            answers         = []
            for i in range(num_experiments):
                # Random start: X_STAR plus a uniform offset in [0, 5) per entry.
                start_x = (X_STAR_A + 5.0 * np.random.random_sample((1, len(X_STAR)))).tolist()[0]
                start   = time.time()
                x_hat   = fit(samples, start_x)
                answers.append((np.linalg.norm(np.array(start_x)-x_hat), x_hat, np.linalg.norm(x_hat - X_STAR_A), start_x, time.time() - start))

            print(light_rule)
            print(' result of accuracy wrt to STARTING POINTS')
            print('start distance, answer, norm, x0, time')
            for start_dist, answer, norm, start_x, secs in answers:
                print(short_rule)
                print('start_dist %s' % (start_dist,))
                print('x_hat dist %s' % (norm,))
                print('x_hat %s' % (answer,))
                print('x0 %s' % (start_x,))
                print('time %s' % (secs,))
            print(light_rule)




























        #sample = False
        #if sample:
        #    x         = np.arange(0,6e-2,6e-2/30)  
        #    A,k,theta = 10, 1.0/3e-2, np.pi/6  
        #    y_true    = A*np.sin(2*np.pi*k*x+theta)  
        #    y_meas    = y_true + 2*np.random.randn(len(x))  
        #     
        #    def residuals(p, y, x):  
        #            A,k,theta = p  
        #            err = y - A*np.sin(2*np.pi*k*x+theta)  
        #            #print 'x is', x
        #            #print 'x class', x.__class__
        #            return err  
        #     
        #    p0 = [8, 1/2.3e-2, np.pi/3]
        #    print np.array(p0)  
        #    from scipy.optimize import leastsq  
        #    print 'p0', p0.__class__
        #    plsq = leastsq(residuals, p0, args=(y_meas, x))
        #    print 'plsq[0]                ',  plsq[0], plsq[0].__class__
        #    print 'np.array([A, k, theta])',  np.array([A, k, theta])  
        #else:

        #def my_f(x):
        #    x1 = x[0]
        #    x2 = x[1]
        #    return np.array([[x1*x2 + 10 + x2, x1*x2 + 10 + x2],
        #                     [x1*x2 + 10 + x2, x1*x2 + 10 + x2]])
        #plsq = scipy.optimize.leastsq(my_f, np.array([100.0, 300.0]))
