import numpy as np
import matplotlib.pyplot as plt

# The following is only for preparing the exercise
# -------------------------------------------------- 

# Seed NumPy's global random generator so every run draws the same sample
np.random.seed(0)

# Sample size of the synthetic data set
n = 20

# Draw n uniform values from [0, 1) and keep two decimal places
x = np.random.rand(n).round(2)

def f_func(x):
    """Evaluate the line f(x) = 5*x + 4 at x (scalar or NumPy array)."""
    slope, intercept = 5, 4
    return slope * x + intercept
fx = f_func(x)

# Perturb the data with Gaussian noise of standard deviation sigma
sigma = 0.1
y = (fx + sigma * np.random.randn(n)).round(2)

# Least-squares fit of a degree-1 polynomial (a straight line)
deg = 1
beta = np.polyfit(x, y, deg=deg)

# Fitted values at the sample points
fhatx = np.polyval(beta, x)
# -------------------------------------------------- 

# Hard-coded exercise data: inputs x, observed targets y and the reference
# predictions fhatx (20 values each).
x = np.array([
    0.55, 0.72, 0.60, 0.54, 0.42,
    0.65, 0.44, 0.89, 0.96, 0.38,
    0.79, 0.53, 0.57, 0.93, 0.07,
    0.09, 0.02, 0.83, 0.78, 0.87,
])

y = np.array([
    6.90, 7.58, 7.03, 6.61, 5.84,
    7.32, 6.29, 8.38, 9.03, 5.75,
    7.95, 6.63, 7.00, 8.80, 4.37,
    4.49, 4.01, 7.95, 7.87, 8.37,
])

fhatx = np.array([
    6.74792024, 7.61454115, 7.00280875, 6.69694254, 6.08521014,
    7.25769725, 6.18716554, 8.48116205, 8.83800595, 5.88129934,
    7.97138505, 6.64596484, 6.84987564, 8.68507285, 4.30099064,
    4.40294604, 4.04610214, 8.17529585, 7.92040735, 8.37920665,
])

# Define function for mean squared error
def computeMSE(y, fhatx):
    """Return the mean squared error between targets and predictions.

    Parameters
    ----------
    y : array_like
        Observed values.
    fhatx : array_like
        Predicted values; must be broadcastable against ``y``.

    Returns
    -------
    mse
        Mean of the squared element-wise differences ``(y - fhatx) ** 2``.
    """
    # TODO Do something to compute mse
    # -------------------------------------------------- 
    # Accept plain sequences, and promote boolean/integer data to float64 so
    # the difference and its square cannot wrap around (e.g. uint8 0 - 20).
    y, fhatx = (
        arr.astype(np.float64) if arr.dtype.kind in "biu" else arr
        for arr in (np.asarray(y), np.asarray(fhatx))
    )
    mse = np.mean(np.power(y - fhatx, 2))
    # -------------------------------------------------- 
    return mse

print("MSE: ", computeMSE(y, fhatx))

# TODO Compute linear regression coefficients with function np.polyfit
# --------------------------------------------------
# Degree 1 fits a straight line; np.polyfit returns the coefficients with the
# highest power first, i.e. beta = [slope, intercept]
deg = 1
beta = np.polyfit(x, y, deg)
# -------------------------------------------------- 

# TODO Compute prediction of x by yourself, the values should coincide with those of fhatx (by at most 1e-8), you can use the function np.polyval
# --------------------------------------------------
myfhatx = np.polyval(beta, x)
# Print the largest deviation so the agreement with fhatx is actually visible
# when the file runs as a script (a bare expression would be discarded).
max_deviation = np.max(np.abs(fhatx - myfhatx))
print("Max |fhatx - myfhatx|: ", max_deviation, "(should be at most 1e-8)")
# -------------------------------------------------- 
# -------------------------------------------------- 

# Draw the samples together with the fitted regression line
fig = plt.figure()

# TODO Plot data points
# -------------------------------------------------- 
# 'k+' draws black plus markers without a connecting line
plt.plot(x, y, 'k+', label='data points')
# -------------------------------------------------- 

# TODO Plot least squares line
# -------------------------------------------------- 
# Evaluate the fit on 100 evenly spaced points covering [0, 1]
xr = np.linspace(0.0, 1.0, num=100)
plt.plot(xr, np.polyval(beta, xr), 'b--', label='least squares line')
# -------------------------------------------------- 
plt.legend()
plt.show()
