# Copyright (C) 2019 by Stefan Schubert
# https://www.tu-chemnitz.de/etit/proaut/en/team/stefanSchubert.html

# Project:
# https://www.tu-chemnitz.de/etit/proaut/ccnn

# If you use this source code in your work, please cite the following paper:
# Schubert, S., Neubert, P., Pöschmann, J. & Protzel, P. (2019) Circular Convolutional Neural
# Networks for Panoramic Images and Laser Data. In Proc. of Intelligent Vehicles Symposium (IV)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.


print('''############################################################################################################################
Demo: Performance evaluation of CNN and CCNN on the circularly shifted MNIST dataset. The demo performs the following steps:
1) train CNN on MNIST training images
2) copy weights from CNN to CCNN (without retraining)
3) evaluate CNN and CCNN on all circularly shifted MNIST testing images
4) plot results

For more details see: www.tu-chemnitz.de/etit/proaut/ccnn
############################################################################################################################\n''')


# includes
import numpy as np

import keras
from keras.layers import Conv2D, Input, Activation, GlobalAveragePooling2D
from keras.datasets import mnist

from ccnn_layers import CConv2D

from matplotlib import pyplot as plt
plt.ion()


##### PARAMETERS ##############################################################
# choose number of kernels in each convolutional layer
nb_kernels = 4

# choose training data type: 'orig' for unshifted MNIST images; 'shift' for MNIST images with all possible horizontal shifts (i.e., 28 times more training images)
training_data = 'orig'

# choose number of epochs depending on training data type
if training_data == 'orig':
    epochs = 28
elif training_data == 'shift':
    epochs = 1
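else:
    # guard added for clarity (not part of the original demo): any other value would
    # leave `epochs` undefined further below
    raise ValueError("training_data must be 'orig' or 'shift'")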


##### LOAD MNIST TRAINING AND TEST DATA #######################################
# load training and test data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = np.expand_dims(x_train, -1)
if training_data == 'shift':
    x_train = np.concatenate([np.roll(x_train, i, axis=2)
                              for i in range(28)], axis=0)

y_train = keras.utils.to_categorical(y_train, 10)
if training_data == 'shift':
    y_train = np.tile(y_train, (28, 1))  # replicate labels for the 28 shifted copies

x_test = np.expand_dims(x_test, -1)
y_test = keras.utils.to_categorical(y_test, 10)
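
# optional sanity check (not part of the original demo): np.roll along axis=2 shifts
# the image columns circularly, so a shift by the full image width (28 px) must
# reproduce the original tensor
assert np.array_equal(np.roll(x_test, 28, axis=2), x_test)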


##### CREATE MODELS FOR CNN AND CCNN ##########################################


# function to create the classification model for a given convolution layer type (Conv2D or CConv2D)
def create_model(conv2d_layer):
    input1 = Input((28, 28, 1))

    x = conv2d_layer(nb_kernels, (3, 3),
                     activation='relu', padding='same')(input1)
    x = conv2d_layer(nb_kernels, (3, 3), strides=(2, 2),
                     activation='relu', padding='same')(x)
    x = conv2d_layer(nb_kernels, (3, 3),
                     activation='relu', padding='same')(x)
    x = conv2d_layer(nb_kernels, (3, 3), strides=(2, 2),
                     activation='relu', padding='same')(x)

    x = Conv2D(10, 1, activation='relu')(x)
    x = GlobalAveragePooling2D()(x)
    out = Activation('softmax')(x)

    return keras.models.Model(input1, out)


# create models
m_cnn = create_model(Conv2D)
m_ccnn = create_model(CConv2D)
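
# optionally inspect both architectures (not part of the original demo); they should
# report identical layer shapes and parameter counts, which is what makes the direct
# weight transfer below possible
m_cnn.summary()
m_ccnn.summary()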


##### TRAIN CNN ###############################################################
# train CNN
m_cnn.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
m_cnn.fit(x=x_train, y=y_train,
          validation_data=(x_test, y_test),
          batch_size=256,
          epochs=epochs)

##### WEIGHT TRANSFER CNN->CCNN ###############################################
m_ccnn.compile(optimizer='adam',
               loss='categorical_crossentropy',
               metrics=['accuracy'])
m_ccnn.set_weights(m_cnn.get_weights())
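
# optional check (not part of the original demo): verify that every weight tensor was
# copied unchanged from the CNN to the CCNN
for w_cnn, w_ccnn in zip(m_cnn.get_weights(), m_ccnn.get_weights()):
    assert np.array_equal(w_cnn, w_ccnn)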

##### EVALUATE CNN & CCNN #####################################################
# evaluate trained CNN
accuracies_cnn = []
for shift in range(-14, 14):
    x_test_shift = np.roll(x_test, shift, axis=2)
    _, acc = m_cnn.evaluate(x_test_shift, y_test)
    accuracies_cnn.append(acc)

accuracies_cnn = np.array(accuracies_cnn)

# evaluate transferred CCNN
accuracies_ccnn = []
for shift in range(-14, 14):
    x_test_shift = np.roll(x_test, shift, axis=2)
    _, acc = m_ccnn.evaluate(x_test_shift, y_test)
    accuracies_ccnn.append(acc)

accuracies_ccnn = np.array(accuracies_ccnn)
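
# optional textual summary (not part of the original demo): the weight-transferred CCNN
# is expected to keep a roughly constant accuracy over all shifts, while the plain CNN
# degrades away from shift 0
print('CNN : mean acc = %.4f, min acc = %.4f' % (accuracies_cnn.mean(), accuracies_cnn.min()))
print('CCNN: mean acc = %.4f, min acc = %.4f' % (accuracies_ccnn.mean(), accuracies_ccnn.min()))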

##### PLOT CNN VS TRANSFERRED CCNN ############################################
plt.plot(np.arange(-14, 14), accuracies_cnn)
plt.plot(np.arange(-14, 14), accuracies_ccnn)

plt.grid()
plt.legend(['CNN', 'CNN -> CCNN (weight-transferred CCNN)'])
plt.xlabel('horizontal shift [pixels]')
plt.ylabel('accuracy')
plt.title('Performance of CNN vs CCNN on the circularly shifted MNIST dataset')
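
# keep the figure window open when the script is run non-interactively (assumption:
# with plt.ion() alone the window would close as soon as the script terminates)
plt.show(block=True)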
