% Run fast and memory efficient graph optimization via ICM (Iterated
% Conditional Modes)
% 
% S ... pairwise similarity matrix between database images (rows, M) and
%       query images (columns, N)
% S_DB ... pairwise intra-database similarity matrix (M x M)
% S_Q ... pairwise intra-query similarity matrix (N x N)
% setup ... selects which information is leveraged. Valid options are
% ('DB', 'DB-Q', 'DB-Q-Seq')
% f_excl ... selects a multiplicative or a minimum-based cost function
% for the factors f^DB_excl and f^Q_excl. Valid options are ('mul', 'min')
%
%
%   =====================================================================
%   Copyright (C) 2021  Stefan Schubert, stefan.schubert@etit.tu-chemnitz.de
%   
%   This program is free software: you can redistribute it and/or modify
%   it under the terms of the GNU General Public License as published by
%   the Free Software Foundation, either version 3 of the License, or
%   (at your option) any later version.
%   
%   This program is distributed in the hope that it will be useful,
%   but WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%   GNU General Public License for more details.
%   
%   You should have received a copy of the GNU General Public License
%   along with this program.  If not, see <http://www.gnu.org/licenses/>.
%   =====================================================================
% 
function S_icm = run_ICM(S, S_DB, S_Q, setup, f_excl)
  % Configure the ICM weights for the chosen setup and cost function,
  % normalize all similarity matrices to [0,1] and run the ICM optimization.
  %
  % Inputs are described in the file header. S_icm is the optimized
  % similarity matrix (same size as S).
  %
  % Errors:
  %   'run_ICM:invalidFExcl' ... f_excl is neither 'mul' nor 'min'
  %   'run_ICM:invalidSetup' ... setup is not 'DB', 'DB-Q' or 'DB-Q-Seq'

  % Validate f_excl first. An unknown value used to leave params.w3
  % undefined for 'DB-Q-Seq' and to fail later with an unrelated error.
  if ~(strcmp(f_excl, 'mul') || strcmp(f_excl, 'min'))
    error('run_ICM:invalidFExcl', ...
      'Chosen ''f_excl'' (%s) does not exist. Please choose between ''mul'' and ''min''', ...
      char(f_excl));
  end

  % parameters
  params.f_excl = f_excl;

  params.w1_db = 1;
  params.w2_db = 20;
  if strcmp(setup, "DB")
    params.w1_q = 0;
    params.w2_q = 0;
    params.w3 = 0;
  elseif strcmp(setup, "DB-Q")
    params.w1_q = 1;
    params.w2_q = 20;
    params.w3 = 0;
  elseif strcmp(setup, "DB-Q-Seq")
    params.w1_q = 1;
    params.w2_q = 20;
    if strcmp(f_excl, 'mul')
      params.w3 = 0.5;
    else % 'min' (validated above)
      params.w3 = 0.2;
    end
  else
    % Format the message instead of concatenating ["...", setup, "..."]:
    % in MATLAB that builds a string ARRAY, which error() rejects.
    error('run_ICM:invalidSetup', ...
      'Chosen ''setup'' (%s) does not exist. Please choose between ''DB'', ''DB-Q'' and ''DB-Q-Seq''', ...
      char(setup));
  end

  % normalize similarity matrices
  S = norm_S(S);
  S_DB = norm_S(S_DB);
  S_Q = norm_S(S_Q);

  % run ICM
  S_icm = ICM(S, S_DB, S_Q, params);
end

%% normalize similarity matrix to [0,1]
function S = norm_S(S)
  % Linearly rescale all entries of S to [0,1] and return the result as
  % 'single'.
  %
  % A matrix without range (constant or 1x1) cannot be rescaled. The plain
  % formula divides 0 by 0 and yields NaNs. ICM multiplies these NaNs even
  % with zero weights, which turns its whole result into NaN. Such a matrix
  % is therefore mapped to all zeros.
  if isempty(S)
    S = single(S);
    return
  end

  s_min = min(S(:));
  s_range = max(S(:)) - s_min;
  if s_range == 0
    S = zeros(size(S), 'single');
  else
    S = single((S - s_min) / s_range);
  end
end

%% run the ICM optimization on the normalized similarity matrices
function S_icm = ICM(S, S_DB, S_Q, params)
  % Optimize the similarity matrix S via ICM (Iterated Conditional Modes).
  %
  % S      ... M x N similarity matrix (rows: database, columns: query)
  % S_DB   ... M x M intra-database similarity matrix
  % S_Q    ... N x N intra-query similarity matrix
  % params ... struct with fields f_excl ('mul' | 'min'), w1_db, w2_db,
  %            w1_q, w2_q and w3
  %
  % Each iteration sets all entries at once to -b./(2*a). This is the
  % stationary point of the per-entry quadratic a*s^2 + b*s (its minimum
  % for a > 0), whose coefficients are computed from the previous estimate.
  %
  % A run stops in one of three cases:
  %   - after 200 iterations,
  %   - on convergence (max. absolute change < 1e-4),
  %   - on divergence (max. absolute change > 2).
  % A diverged run with w3 > 0 is restarted with w3 reduced by 0.1.
  % Otherwise the last estimate is returned.
  %
  % Errors: 'run_ICM:invalidFExcl' if params.f_excl is neither 'mul' nor 'min'.

  % ICM settings
  max_iter = 200;   % max. number of iterations per run
  tol_conv = 1e-4;  % converged if the max. absolute change falls below this
  tol_div  = 2;     % diverged if the max. absolute change exceeds this
  seq_len  = 11;    % second argument of SeqConv (presumably the sequence length)
  w3_step  = 0.1;   % reduction of w3 after a diverged run

  if strcmp(params.f_excl, 'mul')
    use_min = false;   % f_excl is (s1*s2)^2
  elseif strcmp(params.f_excl, 'min')
    use_min = true;    % f_excl is min(s1,s2)^2
  else
    error('run_ICM:invalidFExcl', ...
      'Chosen ''f_excl'' (%s) does not exist. Please choose between ''mul'' and ''min''', ...
      char(params.f_excl));
  end

  % determine size of database (M) and query (N)
  M = size(S,1);
  N = size(S,2);

  % Normalize the weights by the number of neighbors.
  % With a single image (M or N equal to 1) there are no pairwise factors,
  % and 2/(K-1) would be Inf. Even a zero weight then becomes NaN (0*Inf),
  % which turns the whole result into NaN.
  w1_db = pair_weight(params.w1_db, M);
  w2_db = pair_weight(params.w2_db, M);
  w1_q = pair_weight(params.w1_q, N);
  w2_q = pair_weight(params.w2_q, N);
  w3 = params.w3;
  use_q = w1_q > 0;   % query factors are only evaluated if weighted

  % Loop-invariant terms, computed once rather than in every iteration
  % (and, for f_excl = 'min', once per column of every iteration).
  notS_DB = 1 - S_DB;
  a_db = w1_db * (sum(S_DB,2) - 1);   % M x 1
  if use_q
    notS_Q = 1 - S_Q;
    a_q = w1_q * (sum(S_Q,1) - 1);    % 1 x N
  else
    a_q = 0;   % S_Q is not used at all, so it cannot inject NaNs
  end

  %% start optimization
  reset_run = true;
  while reset_run
    reset_run = false;

    %% init
    % initialize nodes and allocate memory
    S_icm = S;
    w1sum_q_all = zeros(size(S_icm), 'single');
    w2sqsum_q_all = zeros(size(S_icm), 'single');
    Seq = zeros(size(S_icm), 'single');
    delta = 0;

    %% ICM iterations
    for iter = 1:max_iter
      % store S_(t-1); all terms below are computed from it
      S_icm_last = S_icm;

      % invoke the sequence-based term if it is weighted
      if w3 > 0
        Seq = SeqConv(S_icm_last, seq_len);
      end

      % pairwise terms entering b (identical for both cost functions)
      w1sum_db_all = w1_db*2 * S_DB * S_icm_last - w1_db*2 * S_icm_last;
      if use_q
        w1sum_q_all = w1_q*2 * S_icm_last * S_Q - w1_q*2 * S_icm_last;
      end

      % exclusion terms entering a
      if use_min
        w2sqsum_db_all = zeros(size(S_icm_last), 'single');
        for j = 1:N
          s = S_icm_last(:,j);
          w2sqsum_db_all(:,j) = w2_db * sum(single(s < s') .* notS_DB, 2);
        end
        if use_q
          w2sqsum_q_all = zeros(size(S_icm_last), 'single');
          for i = 1:M
            s = S_icm_last(i,:);
            w2sqsum_q_all(i,:) = w2_q * sum(single(s' > s) .* notS_Q, 1);
          end
        end
      else
        w2sqsum_db_all = w2_db * notS_DB * S_icm_last.^2;
        if use_q
          w2sqsum_q_all = w2_q * S_icm_last.^2 * notS_Q;
        end
      end

      b = -2*S + w3*(-2)*Seq - w1sum_db_all - w1sum_q_all;
      a = 1 + w3 + (a_db + w2sqsum_db_all) + (a_q + w2sqsum_q_all);

      % arguments of the minimum of the current ICM iteration
      S_icm = -b./(2*a);

      % check for convergence or divergence
      delta = max(abs(S_icm_last(:) - S_icm(:)));
      if delta < tol_conv || delta > tol_div
        break
      end
    end

    % repeat ICM with a smaller sequence weight in case of divergence
    if delta > tol_div && w3 > 0
      w3 = w3 - w3_step;
      % Snap floating-point residue to zero (e.g. 0.5 - 5*0.1 = 2.8e-17).
      % Otherwise an extra run with a meaningless w3 would be started.
      if w3 < w3_step * 1e-6
        w3 = 0;
      end
      reset_run = true;
    end
  end
end

%% scale a pairwise weight by 2/(K-1); with K <= 1 (no neighbors) it is 0
function w = pair_weight(w, K)
  if K > 1
    w = w * 2/(K-1);
  else
    w = 0;
  end
end
