%========================================================================
%Fractional gray value Cahn-Hilliard inpainting in 3D, Version 1.0
%Copyright(c) 2015 Jessica Bosch
%All Rights Reserved.
%
%This is an implementation of a fractional inpainting model based on the 
%vector-valued Cahn--Hilliard equation in three spatial dimensions.
%Please refer to the following paper:
%
%J. Bosch, and M. Stoll, "A fractional inpainting model based on the 
%vector-valued Cahn--Hilliard equation" 
%Preprint MPIMD/15-04, Max Planck Institute Magdeburg, March 2015.
%
%This is an example code for the three-dimensional visualization of medical
%images (MRI), see Section 6.4 in the paper above. We aim to create virtual 
%slices between the given slices of the MRI data set that comes with 
%MATLAB.
%
%Kindly report any suggestions or corrections to bosch@mpi-magdeburg.mpg.de

% Start from a clean session: close every open figure and wipe the
% workspace before anything is loaded.
close('all');
clear('all');

% Print numbers with full double precision.
format('long', 'g');

%%
% Stopwatch handle for the whole run; read again via toc(start) when the
% results are written at the end of the script.
start = tic;

%% read initial data
% MRI_data.mat is loaded without a variable list, so every variable stored
% in it lands in the workspace. The script expects the following entries:
%
%   given_slices : number of given slices
%   Ngap         : 2 + number of virtual slices between two given slices
%   N1, N2, N3   : number of mesh points in x-, y- and z-direction, with
%                  N3 = (given_slices-1)*(Ngap-2) + given_slices
%   nphase       : number of phases (gray values)
%   gray_values  : nphase x 1 vector holding the gray values
%   orig_im      : original MRI data set obtained via "load mri"
%                  (size: N1 x N2 x 27)
%   cluster_im   : orig_im after standard k-means clustering into nphase
%                  clusters (size: N1 x N2 x 27)
%   u_0          : initial phase variable (size: nphase x N1 x N2 x N3)
%   D            : indicator of the damaged parts (size: N1 x N2 x N3);
%                  D(i,j,k) = 0 in the given slices,
%                  D(i,j,k) = 1 in the virtual slices
%   map          : colormap obtained via "load mri", used for the figures
%%%
load('MRI_data.mat');
            
%% model parameters

% --- time stepping, fidelity and fractional order ------------------------
tau   = 1;              % time step size
omega = 10^9;           % fidelity parameter
C2    = 3*omega;        % convexity-splitting constant tied to the fidelity term
zeta  = 1.8;            % fractional power

% --- spatial discretisation ----------------------------------------------
mesh1 = 1/N1;           % mesh size in x-direction
mesh2 = 1/N2;           % mesh size in y-direction
mesh3 = 2.5*mesh1/4;    % mesh size in z-direction

a  = 0;                 % left boundary node (shared by all directions)
bx = N1*mesh1;          % right boundary node in x-direction
by = N2*mesh2;          % right boundary node in y-direction
bz = N3*mesh3;          % right boundary node in z-direction

% edge lengths of the computational box
Lx = bx - a;
Ly = by - a;
Lz = bz - a;

% --- interface parameters ------------------------------------------------
% epsilon       : used in the 1st simulation stage (smoothing);
%                 a relatively large value
% epsilon_small : used in the 2nd simulation stage (sharpening);
%                 a small value, proportional to the mesh size
epsilon       = 1000;
epsilon_small = mesh1;
C1 = 3/epsilon;         % convexity-splitting constant tied to epsilon

%% matrices

% storage for the derivative of the smooth potential (one slab per phase)
W = zeros([nphase, N1, N2, N3]);

% storage for the solution of the current time step
u_neu = zeros(size(W));

% solution of the previous time step, initialised with the given data
u_alt = u_0;

% Wavenumber grids, each of size N1 x N2 x N3.
% ky varies along array dimension 1 (N1 entries, scaled by Lx),
% kx varies along array dimension 2 (N2 entries, scaled by Ly),
% kz varies along array dimension 3 (N3 entries, scaled by Lz).
[ky, kx, kz] = ndgrid((0:N1-1)*pi/Lx, (0:N2-1)*pi/Ly, (0:N3-1)*pi/Lz);

% eigenvalues of the fractional Neumann Laplacian on the 3D grid:
% (kx^2 + ky^2 + kz^2)^(zeta/2)
M = (kx.^2 + ky.^2 + kz.^2).^(zeta/2);

%% start the iteration
%
% Two-stage time stepping. The loop first runs with the large interface
% parameter epsilon (smoothing). Once the relative change between two
% consecutive iterates drops below energy_tol_1 it switches to
% epsilon_small (sharpening) and keeps iterating until the relative change
% drops below energy_tol_2.
%
% Each step solves, per phase j and in DCT space,
%
%   u_neu_hat = ( M.*(-W_hat/epsilon + C1*u_hat) + P_hat + u_hat/tau )
%               ./ ( 1/tau + epsilon*M.^2 + C1*M + C2 )
%
% where u_hat, W_hat, P_hat are the transforms of the previous iterate, the
% potential derivative and the fidelity term.
% NOTE(review): dctn/idctn are not defined in this file; presumably the
% N-dimensional DCT and its inverse - confirm which implementation is on
% the path.

it = 0;                 % time step counter
eps_switch = 0;         % 0 while epsilon is used, 1 after the switch to
                        % epsilon_small

energy_tol_1 = 2e-4;    % tolerance for the 1st stopping criterion (switch)
energy_tol_2 = 2e-5;    % tolerance for the 2nd stopping criterion (exit)

norm_sol = 1;           % squared norm of the last update (numerator)
aux_var = 0;            % squared norm of the previous iterate (denominator)

% Logical mask of the damaged voxels (D > 0.5), replicated over all phases
% so that it can index the nphase x N1 x N2 x N3 arrays directly.
damaged = repmat(reshape(D > 0.5, [1, N1, N2, N3]), [nphase, 1, 1, 1]);

% Given data restricted to the undamaged voxels. It does not change during
% the iteration, so it is built once here instead of in every time step.
u0_known = u_0;
u0_known(damaged) = 0;

% eps_switch only takes the values 0 and 1, so "still in the first stage"
% or "second stage and not yet converged" reduces to the test below.
while (eps_switch < 0.5 || sqrt(norm_sol) > energy_tol_2*sqrt(aux_var))

    % switch to epsilon_small once the first stage has converged
    if ((sqrt(norm_sol) < energy_tol_1*sqrt(aux_var)) && (eps_switch < 0.5))
        epsilon = epsilon_small;
        eps_switch = 1;
        C1 = 3/epsilon;
    end


    % derivative of the smooth potential, evaluated for all phases at once;
    % writing into W(:,:,:,:) keeps the preallocated array and its class
    W(:,:,:,:) = u_alt.^(3) - (3/2)*u_alt.^(2) + 0.5*u_alt;

    % subtract the mean over the phases; the sum is accumulated phase by
    % phase to keep the summation order fixed
    W_sum = zeros(1, N1, N2, N3);
    for j = 1:nphase
        W_sum = W_sum + W(j,:,:,:);
    end
    W = bsxfun(@minus, W, (1/nphase)*W_sum);


    % fidelity term: omega*(u_0 - u_alt) on the undamaged voxels only,
    % plus the convexity-splitting contribution C2*u_alt everywhere
    u_known = u_alt;
    u_known(damaged) = 0;

    P = omega*u0_known + C2*u_alt - omega*u_known;


    % compute the current solution using the DCT
    % (the denominator is the same for every phase)
    denom = (1/tau) + epsilon*M.^(2) + C1*M + C2;
    for j = 1:nphase
        u_hat = dctn(squeeze(u_alt(j,:,:,:)));
        W_hat = dctn(squeeze(W(j,:,:,:)));
        P_hat = dctn(squeeze(P(j,:,:,:)));

        rhs_hat = M.*((-1/epsilon)*W_hat + C1*u_hat) + P_hat + (1/tau)*u_hat;
        u_neu(j,:,:,:) = idctn(squeeze(rhs_hat./denom));
    end


    % compute the norms for the stopping criterion
    % NOTE(review): norm(X,2) of an N1 x N2 slice is the matrix 2-norm
    % (largest singular value), not the Frobenius norm. Left unchanged
    % because the tolerances above were chosen for this measure - confirm
    % that this is the intended norm.
    norm_sol = 0;
    aux_var = 0;
    for j = 1:nphase
        norm_sol = norm_sol + norm(arrayfun(@(idx) norm(squeeze(u_alt(j,:,:,idx)-u_neu(j,:,:,idx)),2), 1:size(u_alt,4)),2)^(2);
        aux_var = aux_var + norm(arrayfun(@(idx) norm(squeeze(u_alt(j,:,:,idx)),2), 1:size(u_alt,4)),2)^(2);
    end

    % With NaN norms every comparison above is false: before the switch
    % the loop would never terminate, after the switch it would exit and
    % silently report a garbage solution. Stop with a clear message.
    if ~(isfinite(norm_sol) && isfinite(aux_var))
        error('CHinpaint3D:nonFiniteNorm', ...
              'Time step %d produced a non-finite norm (norm_sol = %g, aux_var = %g); aborting.', ...
              it+1, norm_sol, aux_var);
    end


    % update the old solution
    u_alt = u_neu;

    % update the time step
    it = it + 1;

end

%%
% Assemble the gray-value volume from the phase variables: every phase
% contributes its phase field weighted by the associated gray value.
phase_layer = @(p) gray_values(p)*squeeze(u_alt(p,:,:,:));

final_im = imadd(phase_layer(1), phase_layer(2));
for k = 3:nphase
    final_im = imadd(final_im, phase_layer(k));
end

% store the volume together with the number of time steps; the step count
% is also part of the file name
solution_file = sprintf('MRI-solution-%d', it);
save(solution_file, 'final_im', 'it');

%% save figures of the reconstructed slices
%
% The first six slices of final_im are exported, each as .fig and .eps:
%   slice 1       - reconstruction of the lower given slice
%   slices 2 to 5 - virtual slices 1.2, 1.4, 1.6 and 1.8
%   slice 6       - reconstruction of the upper given slice (slice 2)
% File names follow the pattern CH_<label>-<it>.<ext>, where <it> is the
% number of time steps taken.

slice_labels = {'slice1', 'slice1p2', 'slice1p4', 'slice1p6', 'slice1p8', 'slice2'};

for slice_idx = 1:numel(slice_labels)
    % figure number = slice index; kept invisible because it is only
    % written to disk
    fig = figure(slice_idx);
    set(fig,'visible','off');
    colormap(map)
    image(final_im(:,:,slice_idx))
    axis image
    set(gca,'XTickLabel',[])
    set(gca,'YTickLabel',[])
    saveas(fig, sprintf('CH_%s-%d.fig', slice_labels{slice_idx}, it));
    saveas(fig, sprintf('CH_%s-%d.eps', slice_labels{slice_idx}, it));
    close(fig);
end


%% print results:
% Appends three lines to MRI_results.txt:
%   1) iteration count, elapsed time since "start" (toc measures elapsed
%      wall-clock time; the label in the file is "CPU"), the final relative
%      stopping norm, the last squared update norm, the number of degrees
%      of freedom and the model parameters used
%   2) PSNR of the reconstructed slice 1 with respect to the original slice 1
%   3) SSIM of the reconstructed slice 1 with respect to the original slice 1
% NOTE(review): ssim is called with the signature
% (img1, img2, K, window, L); confirm that the implementation on the path
% accepts these positional arguments.

[fileID, open_msg] = fopen('MRI_results.txt','a');
if fileID < 0
    % fail with the reason reported by fopen instead of an anonymous
    % "invalid file identifier" error from fprintf
    error('CHinpaint3D:resultsFile', ...
          'Could not open MRI_results.txt for appending: %s', open_msg);
end

try
    fprintf(fileID,'iter %d CPU %f norm_stop %e norm_U %e dofs %d omega %f eps %f eps_small %f tau %f C1 %f C2 %f phases %d frac %e\n', it, toc(start), sqrt(norm_sol)/sqrt(aux_var), norm_sol, nphase*N1*N2*N3, omega, epsilon, epsilon_small, tau, C1, C2,nphase,zeta);
    fprintf(fileID,'PSNR %f\n', 20*log10(max(max(abs(orig_im(:,:,1))))/(sqrt(mean(mean((orig_im(:,:,1)-final_im(:,:,1)).^2))))));
    fprintf(fileID,'SSIM %f\n', ssim(orig_im(:,:,1),final_im(:,:,1), [0.01 0.03], fspecial('gaussian', 11, 1.5), 1));
catch write_err
    % make sure the file is closed even if evaluating or writing one of
    % the quality measures fails, then pass the original error on
    fclose(fileID);
    rethrow(write_err);
end
fclose(fileID);