import numpy as np
import matplotlib.pyplot as plt

# Exploratory look at the Advertising data set: load the CSV into a numpy
# array, split it into named columns, and scatter-plot sales against each
# advertising budget (TV, radio, newspaper) side by side.

# Path of the data file, resolved relative to the current working directory.
DATA_PATH = 'Advertising.csv'

# Import Advertising.csv as a numpy array X. Every field is parsed as a float.
# NOTE(review): assumes the CSV starts with exactly one header row (as in the
# ISLR Advertising data); adjust skip_header if the file differs.
X = np.genfromtxt(DATA_PATH, delimiter=',', skip_header=1)

# Extract columns from array X
# Column 0 - index
idx = X[:, 0]
# Column 1 - TV
tv = X[:, 1]
# Column 2 - radio
radio = X[:, 2]
# Column 3 - newspaper
news = X[:, 3]
# Column 4 - sales
sales = X[:, 4]


# Exploratory data analysis: one subplot per advertising channel, each a
# red-dot scatter of sales against that channel's budget.
fig1 = plt.figure()

panels = [
    (tv, 'TV budget', 'TV ads'),
    (radio, 'radio budget', 'Radio ads'),
    (news, 'newspaper budget', 'Newspaper ads'),
]
for position, (budget, xlabel, title) in enumerate(panels, start=1):
    # add_subplot positions are 1-based, hence start=1.
    ax = fig1.add_subplot(1, 3, position)
    ax.plot(budget, sales, 'ro')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('sales')
    ax.set_title(title)

# Three panels in one row crowd their labels; let matplotlib re-space them.
fig1.tight_layout()
plt.show()
