import numpy as np
import matplotlib.pyplot as plt

# TODO Copy the function computeMSE from previous exercise or import it
# -------------------------------------------------- 
def computeMSE(y, fhatx):
    """Return the mean squared error between targets and predictions.

    Parameters
    ----------
    y : array_like
        Observed target values.
    fhatx : array_like
        Predicted values; must be broadcastable against ``y``.

    Returns
    -------
    mse : NumPy floating-point scalar
        Mean of the squared element-wise differences ``(y - fhatx) ** 2``.
    """
    # Convert up front so that plain Python sequences are accepted as well;
    # subtracting two lists or tuples directly would raise a TypeError.
    residuals = np.asarray(y) - np.asarray(fhatx)
    mse = np.mean(np.power(residuals, 2))
    return mse
# -------------------------------------------------- 

# Sample sizes: n points are generated in total, ntrain of them are used
# for training
n, ntrain = 200, 100

# Scale factor of the random noise that is added to the function values
eps = 1

# Define "unknown" function f
def f_func(x):
    """Evaluate f(x) = sin(10 x) + 5 cos(3 x) element-wise.

    This is the noise-free function from which the data are generated.
    """
    sine_part = 1 * np.sin(10 * x)
    cosine_part = 5 * np.cos(3 * x)
    return sine_part + cosine_part

# Seed NumPy's global random generator so that every run draws the same data
np.random.seed(0)

# Draw n inputs uniformly from [0, 1)
x = np.random.random_sample(n)

# Noise-free function values at the sampled inputs
fx = f_func(x)

# Observations: function values plus standard-normal noise scaled by eps
noise = np.random.standard_normal(n)
y = fx + eps * noise

# Figure 1: noisy samples, the generating function and one polynomial fit
fig = plt.figure(1)
fig.clf()  # start from an empty figure in case figure 1 already exists

# Noisy observations drawn as black circles
plt.plot(x, y, 'ko', label='data points')

# Generating function on an evenly spaced grid over [0, 1]
# (in general this curve is unknown)
xr = np.linspace(0, 1, num=100)
plt.plot(xr, f_func(xr), 'r-', label='population line')

# Least-squares polynomial of degree 'deg', fitted to all samples
deg = 20
beta = np.polyfit(x, y, deg=deg)
plt.plot(xr, np.polyval(beta, xr), 'b--', label='regression line')

plt.legend()
plt.show()

# Split the data set into a training and a test set using np.split:
#   - the first ntrain samples form the training set
#   - the remaining n - ntrain samples form the test set
# --------------------------------------------------
# Indices of the training samples (not referenced by the code shown here)
idx_train = np.arange(ntrain)

# Cut inputs and targets at the same position so the pairs stay aligned
(xtrain, xtest), (ytrain, ytest) = (
    np.split(values, [ntrain]) for values in (x, y)
)
# --------------------------------------------------

# Training and test mean squared errors, one entry per polynomial degree
MSEtrain, MSEtest = [], []

# Fit polynomials of degree 1 through 19 on the training set only
deg_range = range(1, 20)
for degree in deg_range:
    beta = np.polyfit(xtrain, ytrain, degree)

    # Predictions of the fitted polynomial on both subsets
    fitted_train = np.polyval(beta, xtrain)
    fitted_test = np.polyval(beta, xtest)

    MSEtrain.append(computeMSE(ytrain, fitted_train))
    MSEtest.append(computeMSE(ytest, fitted_test))

# Figure 2: training and test MSE against the polynomial degree
fig = plt.figure(2)
fig.clf()  # start from an empty figure in case figure 2 already exists

# Draw both error curves over the same degree axis
for curve, line_style, curve_label in (
    (MSEtrain, 'r-', 'train MSE'),
    (MSEtest, 'b--', 'test MSE'),
):
    plt.plot(deg_range, curve, line_style, label=curve_label)

plt.xlabel("Polynomial degree")
plt.ylabel("MSE")
plt.legend()
plt.show()
