/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]
 Modifications and extensions by Roman Hornung

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of divfor is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#include <set>
#include <Rcpp.h>
#include <math.h>
#include <algorithm>
#include <stdexcept>
#include <string>
#include <ctime>
#include <functional>
#include <thread>
#include <chrono>

#include "utility.h"
#include "Forest.h"
#include "DataChar.h"
#include "DataDouble.h"
#include "DataFloat.h"

//#include "debug_cp.h"

namespace unityForest
{

  // Default constructor. The members listed below receive neutral default
  // values; the actual configuration is applied later by initR()/init().
  // The initializer order is kept identical to the original list.
  Forest::Forest()
      : verbose_out(0),
        num_trees(DEFAULT_NUM_TREE),
        mtry(0),
        min_node_size(0),
        min_node_size_root(0),
        num_variables(0),
        num_independent_variables(0),
        seed(0),
        dependent_varID(0),
        num_samples(0),
        prediction_mode(false),
        memory_mode(MEM_DOUBLE),
        sample_with_replacement(true),
        memory_saving_splitting(false),
        splitrule(DEFAULT_SPLITRULE),
        predict_all(false),
        keep_inbag(false),
        sample_fraction({1}),
        holdout(false),
        prediction_type(DEFAULT_PREDICTIONTYPE),
        num_random_splits(DEFAULT_NUM_RANDOM_SPLITS),
        max_depth(DEFAULT_MAXDEPTH),
        max_depth_root(DEFAULT_MAXDEPTHROOT),
        num_cand_trees(DEFAULT_NUMCANDTREES),
        prop_var_root(0),
        prop_best_splits(DEFAULT_PROP_BEST_SPLITS),
        repr_tree_mode(false),
        repr_var_names{},
        alpha(DEFAULT_ALPHA),
        minprop(DEFAULT_MINPROP),
        num_threads(DEFAULT_NUM_THREADS),
        data{},
        overall_prediction_error(NAN),
        importance_mode(DEFAULT_IMPORTANCE_MODE),
        progress(0)
  {
    // Intentionally empty: all set-up happens in the initializer list.
  }

  // Initialisation entry point used by the R interface.
  // Stores the verbose stream, delegates the common set-up to init() (with
  // MEM_DOUBLE and an empty output prefix), and then applies the R-only
  // inputs: always-split variables, representative-tree variables, the list of
  // root-eligible variables, split select weights, case weights, manual inbag
  // counts and the keep_inbag flag.
  //
  // Throws std::runtime_error if case_weights is non-empty and its length does
  // not match the number of samples.
  void Forest::initR(std::string dependent_variable_name, std::unique_ptr<Data> input_data, uint mtry, uint num_trees,
                     std::ostream *verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_node_size_root,
                     std::vector<std::vector<double>> &split_select_weights, const std::vector<std::string> &always_split_variable_names,
                     std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
                     const std::vector<std::string> &unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
                     std::vector<double> &case_weights, std::vector<std::vector<size_t>> &manual_inbag, bool predict_all,
                     bool keep_inbag, std::vector<double> &sample_fraction, double prop_var_root, double alpha, double minprop, bool holdout,
                     PredictionType prediction_type, uint num_random_splits, bool order_snps, uint max_depth, uint max_depth_root, uint num_cand_trees, double prop_best_splits, bool repr_tree_mode, std::vector<std::string> repr_var_names)
  {
    this->verbose_out = verbose_out;

    // Shared initialisation (data ownership moves into the forest here).
    init(dependent_variable_name, MEM_DOUBLE, std::move(input_data), mtry, "", num_trees, seed, num_threads,
         importance_mode, min_node_size, min_node_size_root, status_variable_name, prediction_mode, sample_with_replacement,
         unordered_variable_names, memory_saving_splitting, splitrule, predict_all, sample_fraction, prop_var_root, alpha, minprop,
         holdout, prediction_type, num_random_splits, order_snps, max_depth, max_depth_root, num_cand_trees, prop_best_splits, repr_tree_mode, repr_var_names);

    if (!always_split_variable_names.empty())
      setAlwaysSplitVariables(always_split_variable_names);

    // Translate the names of the variables analysed in representative-tree
    // mode into column IDs.
    if (repr_tree_mode)
    {
      repr_vars.reserve(repr_var_names.size());
      for (const std::string &name : repr_var_names)
        repr_vars.push_back(data->getVariableID(name));
    }

    // Variables eligible for splitting in the random tree roots: every column
    // that is not in the data's no-split list.
    const std::vector<size_t> &excluded = data->getNoSplitVariables();
    const size_t n_cols = data->getNumCols();
    allowedVarIDs_.reserve(n_cols - excluded.size());
    for (size_t col = 0; col < n_cols; ++col)
    {
      const bool is_excluded = std::find(excluded.begin(), excluded.end(), col) != excluded.end();
      if (!is_excluded)
        allowedVarIDs_.push_back(col);
    }

    if (!split_select_weights.empty())
      setSplitWeightVector(split_select_weights);

    // Case weights must have exactly one entry per sample.
    if (!case_weights.empty())
    {
      if (case_weights.size() != num_samples)
        throw std::runtime_error("Number of case weights not equal to number of samples.");
      this->case_weights = case_weights;
    }

    if (!manual_inbag.empty())
      this->manual_inbag = manual_inbag;

    this->keep_inbag = keep_inbag;
  }

  void Forest::init(std::string dependent_variable_name, MemoryMode memory_mode, std::unique_ptr<Data> input_data,
                    uint mtry, std::string output_prefix, uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode,
                    uint min_node_size, uint min_node_size_root, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
                    const std::vector<std::string> &unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
                    bool predict_all, std::vector<double> &sample_fraction, double prop_var_root, double alpha, double minprop, bool holdout,
                    PredictionType prediction_type, uint num_random_splits, bool order_snps, uint max_depth, uint max_depth_root, uint num_cand_trees,
                    double prop_best_splits, bool repr_tree_mode, std::vector<std::string> repr_var_names)
  {

    // Initialize data with memmode
    this->data = std::move(input_data);

    // Initialize random number generator and set seed
    if (seed == 0)
    {
      std::random_device random_device;
      random_number_generator.seed(random_device());
    }
    else
    {
      random_number_generator.seed(seed);
    }

// Set number of threads
uint hc = std::thread::hardware_concurrency();
if (hc == 0) {
  hc = 1;  // fallback if not detectable
}

if (num_threads == DEFAULT_NUM_THREADS) {
  this->num_threads = hc;
} else {
  // user requested a fixed number -> cap at available cores
  this->num_threads = std::min(num_threads, hc);
}

    // Set number of samples and variables
    num_samples = data->getNumRows();
    num_variables = data->getNumCols();

    // If the proportion of the variables randomyl sampled for each tree is not
    // specified, set that number to the square root of the number of variables divided by the number of variables:
    if (prop_var_root == 0)
    {
      double temp = sqrt((double)(num_variables - 1));
      prop_var_root = temp / (double)(num_variables - 1);
      // If prop_var_root is smaller than 0.1, set it to 0.1:
      if (prop_var_root < 0.1)
      {
        prop_var_root = 0.1;
      }
    }

    // Set member variables
    this->num_trees = num_trees;
    this->mtry = mtry;
    this->seed = seed;
    this->output_prefix = output_prefix;
    this->importance_mode = importance_mode;
    this->min_node_size = min_node_size;
    this->min_node_size_root = min_node_size_root;
    this->memory_mode = memory_mode;
    this->prediction_mode = prediction_mode;
    this->sample_with_replacement = sample_with_replacement;
    this->memory_saving_splitting = memory_saving_splitting;
    this->splitrule = splitrule;
    this->predict_all = predict_all;
    this->sample_fraction = sample_fraction;
    this->holdout = holdout;
    this->alpha = alpha;
    this->minprop = minprop;
    this->prediction_type = prediction_type;
    this->num_random_splits = num_random_splits;
    this->max_depth = max_depth;
    this->max_depth_root = max_depth_root;
    this->num_cand_trees = num_cand_trees;
    this->prop_var_root = prop_var_root;
    this->prop_best_splits = prop_best_splits;
    this->repr_tree_mode = repr_tree_mode;
    this->repr_var_names = repr_var_names;

    // Convert dependent variable name to ID
    if (!prediction_mode && !dependent_variable_name.empty())
    {
      dependent_varID = data->getVariableID(dependent_variable_name);
    }

    // Set unordered factor variables
    if (!prediction_mode)
    {
      data->setIsOrderedVariable(unordered_variable_names);
    }

    data->addNoSplitVariable(dependent_varID);

    // Set minimal node size of the tree roots
    if (min_node_size_root == 0)
    {
      min_node_size_root = DEFAULT_MIN_NODE_SIZE_ROOT;
    }

    initInternal(status_variable_name);

    num_independent_variables = num_variables - data->getNoSplitVariables().size();

    // Init split select weights
    split_select_weights.push_back(std::vector<double>());

    // Init manual inbag
    manual_inbag.push_back(std::vector<size_t>());

    // Check if mtry is in valid range
    if (this->mtry > num_variables - 1)
    {
      throw std::runtime_error("mtry can not be larger than number of variables in data.");
    }

    // Check if any observations samples
    if ((size_t)num_samples * sample_fraction[0] < 1)
    {
      throw std::runtime_error("sample_fraction too small, no observations sampled.");
    }

    // Permute samples for corrected Gini importance
    if (importance_mode == IMP_GINI_CORRECTED)
    {
      data->permuteSampleIDs(random_number_generator);
    }

    // Order SNP levels if in "order" splitting
    if (!prediction_mode && order_snps)
    {
      data->orderSnpLevels(dependent_variable_name, (importance_mode == IMP_GINI_CORRECTED));
    }
  }

  // Top-level driver. Dispatches on the configured mode:
  //  - repr_tree_mode:  covariate-representative tree analysis (repr_trees());
  //                     no trees are grown and nothing is predicted,
  //  - prediction_mode: prediction with the existing trees (predict()),
  //  - otherwise:       grow the forest, optionally compute the OOB prediction
  //                     error, and compute the unity VIM for the permutation /
  //                     MUWIMP importance modes.
  // verbose enables status messages on verbose_out (skipped if it is null).
  void Forest::run(bool verbose, bool compute_oob_error)
  {
    if (repr_tree_mode)
    {
      repr_trees();
      return;
    }

    const bool report = verbose && verbose_out;

    if (prediction_mode)
    {
      if (report)
      {
        *verbose_out << "Predicting .." << std::endl;
      }
      predict();
      return;
    }

    if (report)
    {
      *verbose_out << "Growing trees .." << std::endl;
    }
    grow();

    // Only announce the prediction error computation when it actually runs.
    if (compute_oob_error)
    {
      if (report)
      {
        *verbose_out << "Computing prediction error .." << std::endl;
      }
      computePredictionError();
    }

    const bool needs_unity_vim =
        importance_mode == IMP_PERM_BREIMAN || importance_mode == IMP_PERM_LIAW ||
        importance_mode == IMP_PERM_RAW || importance_mode == MUWIMP_BOTH ||
        importance_mode == MUWIMP_MULTIWAY || importance_mode == MUWIMP_DISCR;
    if (needs_unity_vim)
    {
      if (report)
      {
        *verbose_out << "Computing permutation variable importance .." << std::endl;
      }
      computeUnityVIM();
    }
  }

  void Forest::grow()
  {

    // CP();

    // Create thread ranges
    equalSplit(thread_ranges, 0, num_trees - 1, num_threads);

    // CP();

    // Call special grow functions of subclasses. There trees must be created.
    growInternal();

    // CP();

    // Init trees, create a seed for each tree, based on main seed
    std::uniform_int_distribution<uint> udist;
    for (size_t i = 0; i < num_trees; ++i)
    {
      uint tree_seed;
      if (seed == 0)
      {
        tree_seed = udist(random_number_generator);
      }
      else
      {
        tree_seed = (i + 1) * seed;
      }

      // Get split select weights for tree
      std::vector<double> *tree_split_select_weights;
      if (split_select_weights.size() > 1)
      {
        tree_split_select_weights = &split_select_weights[i];
      }
      else
      {
        tree_split_select_weights = &split_select_weights[0];
      }

      // Get inbag counts for tree
      std::vector<size_t> *tree_manual_inbag;
      if (manual_inbag.size() > 1)
      {
        tree_manual_inbag = &manual_inbag[i];
      }
      else
      {
        tree_manual_inbag = &manual_inbag[0];
      }

      trees[i]->init(data.get(), mtry, prop_var_root, dependent_varID, num_samples, tree_seed, &deterministic_varIDs,
                     &split_select_varIDs, tree_split_select_weights, importance_mode, min_node_size, min_node_size_root, sample_with_replacement,
                     memory_saving_splitting, splitrule, &case_weights, tree_manual_inbag, keep_inbag, &sample_fraction, alpha,
                     minprop, holdout, num_random_splits, max_depth, max_depth_root, num_cand_trees, repr_vars);
    }

    // CP();

    // Init variable importance
    variable_importance.resize(num_independent_variables, 0);

    // Grow trees in multiple threads
    progress = 0;
#ifdef R_BUILD
    aborted = false;
    aborted_threads = 0;
#endif

    std::vector<std::thread> threads;
    threads.reserve(num_threads);

    // Initialize importance per thread
    std::vector<std::vector<double>> variable_importance_threads(num_threads);

    for (uint i = 0; i < num_threads; ++i)
    {
      if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED)
      {
        variable_importance_threads[i].resize(num_independent_variables, 0);
      }
      threads.emplace_back(&Forest::growTreesInThread, this, i, &(variable_importance_threads[i]));
    }

    showProgress("Growing trees..", num_trees);
    for (auto &thread : threads)
    {
      thread.join();
    }

#ifdef R_BUILD
    if (aborted_threads > 0)
    {
      throw std::runtime_error("User interrupt.");
    }
#endif

    // Sum thread importances
    if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED)
    {
      variable_importance.resize(num_independent_variables, 0);
      for (size_t i = 0; i < num_independent_variables; ++i)
      {
        for (uint j = 0; j < num_threads; ++j)
        {
          variable_importance[i] += variable_importance_threads[j][i];
        }
      }
      variable_importance_threads.clear();
    }

    // Divide importance by number of trees
    if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED)
    {
      for (auto &v : variable_importance)
      {
        v /= num_trees;
      }
    }

    // CP();
  }

  void Forest::predict()
  {

    // Predict trees in multiple threads and join the threads with the main thread
    progress = 0;
#ifdef R_BUILD
    aborted = false;
    aborted_threads = 0;
#endif

    // Predict
    std::vector<std::thread> threads;
    threads.reserve(num_threads);
    for (uint i = 0; i < num_threads; ++i)
    {
      threads.emplace_back(&Forest::predictTreesInThread, this, i, data.get(), false);
    }
    showProgress("Predicting..", num_trees);
    for (auto &thread : threads)
    {
      thread.join();
    }

    // Aggregate predictions
    allocatePredictMemory();
    threads.clear();
    threads.reserve(num_threads);
    progress = 0;
    for (uint i = 0; i < num_threads; ++i)
    {
      threads.emplace_back(&Forest::predictInternalInThread, this, i);
    }
    showProgress("Aggregating predictions..", num_samples);
    for (auto &thread : threads)
    {
      thread.join();
    }

#ifdef R_BUILD
    if (aborted_threads > 0)
    {
      throw std::runtime_error("User interrupt.");
    }
#endif
  }

  // Perform covariate-representative tree (CRTR) analysis
  void Forest::repr_trees()
  {

    // CP();

    progress = 0;
#ifdef R_BUILD
    aborted = false;
    aborted_threads = 0;
#endif

    // First parallel pass: compute node-wise split criterion
    std::vector<std::thread> threads;
    threads.reserve(num_threads);

    for (uint i = 0; i < num_threads; ++i)
    {
      threads.emplace_back(
          &Forest::computeOOBSplitCriterionValuesInThread, this, i);
    }
    // CP();
    showProgress("Calculating split criterion values for variable importance..", num_trees);
    for (auto &th : threads)
    {
      th.join();
    }

    // CP();

    // Result of the above: For each internal node in the tree roots, the OOB split criterion value was calculated
    // and saved in a vector called split_criterion for each tree. This vector has a value for each node in
    // the tree, but has the value -1 for nodes which are not internal nodes of the tree roots or do not use variables in repr_vars.

    threads.clear();

#ifdef R_BUILD
    if (aborted_threads > 0)
    {
      throw std::runtime_error("User interrupt.");
    }
#endif

    // CP();

    // Second part (sequential): figure out the best splits per variable and fill in is_in_best:
    std::vector<std::vector<size_t>> bestTreesPerVariable(repr_vars.size());
    determineAndMarkBestOOBSplitsPerVariable(bestTreesPerVariable);

    // CP();
    //  For each bestTreesPerVariable[i], sort the trees indices:
    for (size_t i = 0; i < repr_vars.size(); ++i)
    {
      std::sort(bestTreesPerVariable[i].begin(), bestTreesPerVariable[i].end());
    }

    // CP();
    //  Calculate the vector of relative frequencies of the variables in the trees:
    std::vector<size_t> var_counts(num_variables, 0);

    // Loop over all trees and collect all splits per variable:
    for (size_t i = 0; i < num_trees; ++i)
    {
      trees[i]->countVariables(var_counts);
    }

    // CP();

    // CP();
    //  Divide the counts by the num over var_counts:
    std::vector<double> var_rel_freqs(num_variables, 0.0);
    // Compute the sum of the counts:
    double sum_counts = 0.0;
    for (size_t i = 0; i < num_variables; ++i)
    {
      sum_counts += var_counts[i];
    }
    // Compute the relative frequencies:
    for (size_t i = 0; i < num_variables; ++i)
    {
      if (sum_counts > 0)
      {
        var_rel_freqs[i] = static_cast<double>(var_counts[i]) / sum_counts;
      }
      else
      {
        var_rel_freqs[i] = 0.0;
      }
    }

    // CP();

    // Loop through rel_vars and determine the representative tree for each variable:
    std::vector<size_t> repr_trees_per_var(repr_vars.size(), 0);
    for (size_t i = 0; i < repr_vars.size(); ++i)
    {
      repr_trees_per_var[i] = getReprTree(bestTreesPerVariable[i]);
    }

    // CP();

    // For each variable in repr_vars calculate the vector of relative frequencies of the variables in the corresponding best trees:
    std::vector<std::vector<double>> var_rel_freqs_best_trees(repr_vars.size(), std::vector<double>(num_variables, 0.0));

    // Loop through rel_vars:
    for (size_t i = 0; i < repr_vars.size(); ++i)
    {
      // Loop through the trees in the best trees for the variable:
      std::vector<size_t> var_counts_best_trees(num_variables, 0);
      for (size_t j = 0; j < bestTreesPerVariable[i].size(); ++j)
      {
        trees[bestTreesPerVariable[i][j]]->countVariables(var_counts_best_trees);
      }

      // Divide the counts by the sum of the counts:
      double sum_counts = 0.0;
      for (size_t k = 0; k < num_variables; ++k)
      {
        sum_counts += var_counts_best_trees[k];
      }
      // Compute the relative frequencies:
      for (size_t k = 0; k < num_variables; ++k)
      {
        if (sum_counts > 0)
        {
          var_rel_freqs_best_trees[i][k] = static_cast<double>(var_counts_best_trees[k]) / sum_counts;
        }
        else
        {
          var_rel_freqs_best_trees[i][k] = 0.0;
        }
      }
    }

    // CP();

    // Calculate score vector by dividing var_rel_freqs_best_trees[i][k] by (var_rel_freqs_best_trees[i][k] + var_rel_freqs[k]) for each variable k:
    std::vector<std::vector<double>> score_vector(repr_vars.size(), std::vector<double>(num_variables, 0.0));
    for (size_t i = 0; i < repr_vars.size(); ++i)
    {
      for (size_t k = 0; k < num_variables; ++k)
      {
        if (var_rel_freqs[k] > 0)
        {
          score_vector[i][k] = var_rel_freqs_best_trees[i][k] / (var_rel_freqs_best_trees[i][k] + var_rel_freqs[k]);
        }
        else
        {
          score_vector[i][k] = 0.0;
        }
      }
    }

    // CP();

    // Set the score vector for each represenative tree:
    for (size_t i = 0; i < repr_vars.size(); ++i)
    {
      trees[repr_trees_per_var[i]]->setScoreVector(score_vector[i]);
    }

    // CP();
    //  Subset the trees to only keep the representative trees:
    keep_trees(repr_trees_per_var);

    // CP();
  }

  // Keep only the trees that are the representative trees
  // in the CRTR analysis.
  void Forest::keep_trees(const std::vector<size_t> &keep_idx)
  {
    /* 1. map original index → first position moved into 'ordered' */
    std::unordered_map<size_t, size_t> first_pos;
    std::vector<std::unique_ptr<Tree>> ordered;
    ordered.reserve(keep_idx.size());

    /* 2. walk through the user-supplied list in EXACT order */
    for (size_t pos : keep_idx)
    {
      if (pos >= trees.size())
        throw std::out_of_range("keep_trees(): index out of range");

      auto it = first_pos.find(pos);
      if (it == first_pos.end())
      {
        /* first time we see this tree → move it */
        first_pos[pos] = ordered.size();
        ordered.push_back(std::move(trees[pos]));
      }
      else
      {
        /* duplicate index → deep-copy the already-moved tree */
        ordered.push_back(ordered[it->second]->clone());
      }
    }

    /* 3. make the new order the active forest */
    trees.swap(ordered);
    num_trees = trees.size();
  }

  // Compute the OOB prediction error of the forest.
  void Forest::computePredictionError()
  {

    // CP();

    // Predict trees in multiple threads
    std::vector<std::thread> threads;
    threads.reserve(num_threads);
    progress = 0;
    for (uint i = 0; i < num_threads; ++i)
    {
      threads.emplace_back(&Forest::predictTreesInThread, this, i, data.get(), true);
    }
    showProgress("Computing prediction error..", num_trees);
    for (auto &thread : threads)
    {
      thread.join();
    }

#ifdef R_BUILD
    if (aborted_threads > 0)
    {
      throw std::runtime_error("User interrupt.");
    }
#endif

    // CP();

    // Call special function for subclasses
    computePredictionErrorInternal();
  }

  // Compute the unity VIM values.
  void Forest::computeUnityVIM()
  {

    // Compute tree permutation importance in multiple threads
    progress = 0;
#ifdef R_BUILD
    aborted = false;
    aborted_threads = 0;
#endif

    // First parallel pass: compute node-wise split criterion
    std::vector<std::thread> threads;
    threads.reserve(num_threads);

    for (uint i = 0; i < num_threads; ++i)
    {
      threads.emplace_back(
          &Forest::computeSplitCriterionValuesInThread, this, i);
    }
    showProgress("Calculating split criterion values for variable importance..", num_trees);

    for (auto &th : threads)
    {
      th.join();
    }

    // Result of the above: For each internal node in the tree roots, the split criterion value was calculated
    // and saved in a vector called split_criterion for each tree. This vector has a value for each node in
    // the tree, but has the value -1 for nodes which are not internal nodes of the tree roots.

    threads.clear();

    // Second part (sequential): figure out the best splits per variable and fill in is_in_best:
    std::vector<std::vector<size_t>> bestTreesPerVariable(num_variables);
    determineAndMarkBestSplitsPerVariable(bestTreesPerVariable);

    // Result of the above: For each tree there is a vector is_in_best which takes the value 1 if the corresponding split
    // is in the top splits for the respective variable and 0 otherwise.
    // Moreover, bestTreesPerVariable will for each variable contain the indices of the trees that contain the best splits
    // in these variables.

    threads.reserve(num_threads);

    // Initailize importance and variance
    std::vector<std::vector<double>> variable_importance_threads(num_threads);

    // Compute importance
    for (uint i = 0; i < num_threads; ++i)
    {
      variable_importance_threads[i].resize(num_variables, 0);
      threads.emplace_back(&Forest::computeTreeImportanceInThread, this, i,
                           std::ref(variable_importance_threads[i]));
    }
    showProgress("Computing permutation importance..", num_trees);
    for (auto &th : threads)
    {
      th.join();
    }

    // Result of the above: The variable importance was calculated and stored in variable_importance_threads.

#ifdef R_BUILD
    if (aborted_threads > 0)
    {
      throw std::runtime_error("User interrupt.");
    }
#endif

    // Sum thread importances
    variable_importance.resize(num_variables, 0);
    for (size_t i = 0; i < num_variables; ++i)
    {
      for (uint j = 0; j < num_threads; ++j)
      {
        variable_importance[i] += variable_importance_threads[j][i];
      }
    }
    variable_importance_threads.clear();

    // Determine the indices of the independent variables:
    std::vector<size_t> all_vars;
    for (size_t i = 0; i < data->getNumCols(); ++i)
    {
      // If the variable is not in data->getNoSplitVariables(), add it to all_vars:
      if (std::find(data->getNoSplitVariables().begin(), data->getNoSplitVariables().end(), i) == data->getNoSplitVariables().end())
      {
        all_vars.push_back(i);
      }
    }

    // Keep only those elements of variable_importance that have in all_vars:
    std::vector<double> variable_importance_indep(num_independent_variables);
    for (size_t i = 0; i < num_independent_variables; ++i)
    {
      variable_importance_indep[i] = variable_importance[all_vars[i]];
    }
    variable_importance = variable_importance_indep;

    for (size_t i = 0; i < variable_importance.size(); ++i)
    {
      variable_importance[i] /= num_trees;
    }
  }

  // Compute the split scores for the unity VIM values.
  void Forest::computeSplitCriterionValuesInThread(uint thread_idx)
  {
    // Worker for step 1 of computeUnityVIM(): calls
    // computeSplitCriterionValues() on every tree in
    // [thread_ranges[thread_idx], thread_ranges[thread_idx + 1]) and reports
    // one unit of progress per finished tree.
    if (thread_ranges.size() <= thread_idx + 1)
    {
      return; // no range assigned to this thread
    }

    const size_t first_tree = thread_ranges[thread_idx];
    const size_t end_tree = thread_ranges[thread_idx + 1];
    for (size_t tree_idx = first_tree; tree_idx < end_tree; ++tree_idx)
    {
      trees[tree_idx]->computeSplitCriterionValues();

#ifdef R_BUILD
      // On user interrupt: register this thread as aborted and stop.
      if (aborted)
      {
        std::unique_lock<std::mutex> lock(mutex);
        ++aborted_threads;
        condition_variable.notify_one();
        return;
      }
#endif

      // One more tree done.
      std::unique_lock<std::mutex> lock(mutex);
      ++progress;
      condition_variable.notify_one();
    }
  }

  // Compute the split scores for the CRTR analysis.
  void Forest::computeOOBSplitCriterionValuesInThread(uint thread_idx)
  {
    // Worker for the first pass of repr_trees(): calls
    // computeOOBSplitCriterionValues() on every tree in
    // [thread_ranges[thread_idx], thread_ranges[thread_idx + 1]) and reports
    // one unit of progress per finished tree.
    if (thread_ranges.size() <= thread_idx + 1)
    {
      return; // no range assigned to this thread
    }

    const size_t first_tree = thread_ranges[thread_idx];
    const size_t end_tree = thread_ranges[thread_idx + 1];
    for (size_t tree_idx = first_tree; tree_idx < end_tree; ++tree_idx)
    {
      trees[tree_idx]->computeOOBSplitCriterionValues();

#ifdef R_BUILD
      // On user interrupt: register this thread as aborted and stop.
      if (aborted)
      {
        std::unique_lock<std::mutex> lock(mutex);
        ++aborted_threads;
        condition_variable.notify_one();
        return;
      }
#endif

      // One more tree done.
      std::unique_lock<std::mutex> lock(mutex);
      ++progress;
      condition_variable.notify_one();
    }
  }

  // Determine the top-scoring splits for each covariate in the unity VIM computation.
  void Forest::determineAndMarkBestSplitsPerVariable(std::vector<std::vector<size_t>> &bestTreesPerVariable)
  {

    // Initialize a SplitData object for each variable (hint: the number of entries in all_splits_per_variable is equal to the number of columns in the data matrix,
    // not the number of independent variables):
    std::vector<std::vector<SplitData>> all_splits_per_variable(num_variables);

    // Loop over all trees and collect all splits per variable:
    for (size_t i = 0; i < num_trees; ++i)
    {
      trees[i]->collectSplits(i, all_splits_per_variable);
    }

    // Sort the splits per variable in descending order of the split criterion:
    for (size_t varID = 0; varID < num_variables; ++varID)
    {
      std::sort(all_splits_per_variable[varID].begin(), all_splits_per_variable[varID].end(),
                [](const SplitData &a, const SplitData &b)
                {
                  return a.split_criterion > b.split_criterion;
                });
    }

    // Remove from all_splits_per_variable all splits with a split criterion of -1 (note: -1 is the smallest possible value,
    // so for each variable, beginning with the first split with a split criterion of -1, all following splits will have a
    // split criterion of -1):
    for (size_t varID = 0; varID < num_variables; ++varID)
    {
      size_t first_split_with_criterion_minus1 = all_splits_per_variable[varID].size();
      for (size_t i = 0; i < all_splits_per_variable[varID].size(); ++i)
      {
        if (all_splits_per_variable[varID][i].split_criterion == -1)
        {
          first_split_with_criterion_minus1 = i;
          break;
        }
      }
      all_splits_per_variable[varID].resize(first_split_with_criterion_minus1);
    }

    // Determine the best splits per variable:
    bool few_splits = false;
    for (size_t varID = 0; varID < num_variables; ++varID)
    {
      size_t top_k = static_cast<size_t>(prop_best_splits * all_splits_per_variable[varID].size());
      // If top_k is smaller than and the variable is not a no-split variable, set top_k to the minimum of 5 and the number of splits for the variable:
      if (top_k < 5 && std::find(data->getNoSplitVariables().begin(), data->getNoSplitVariables().end(), varID) == data->getNoSplitVariables().end())
      {
        // Set top_k to the minimum of 5 and the number of splits for the variable:
        top_k = std::min<size_t>(5, all_splits_per_variable[varID].size());
        few_splits = true;
      }
      all_splits_per_variable[varID].resize(top_k);
    }
    if (few_splits)
    {
      Rcpp::Rcout << "Warning: The specified prop.best.splits value would have resulted in fewer than 5 best splits being used for some variables in the variable importance. For these, 5 best splits were used. Consider increasing the number of trees or the value of prop.best.splits." << std::endl
                  << std::flush;
    }

    // Make a container for the best splits per variable:
    std::vector<std::set<std::pair<size_t, size_t>>> bestSplits(num_variables);

    // Fill the container with the best splits per variable:
    for (size_t varID = 0; varID < num_variables; ++varID)
    {
      for (size_t i = 0; i < all_splits_per_variable[varID].size(); ++i)
      {
        bestSplits[varID].insert(std::make_pair(all_splits_per_variable[varID][i].tree_idx,
                                                all_splits_per_variable[varID][i].node_idx));
      }
    }

    // Make a vector of type std::vector<std::vector<size_t>> that contains for each variable a vector of the tree indices,
    // where the best splits for that variable are located:
    for (size_t varID = 0; varID < num_variables; ++varID)
    {
      bestTreesPerVariable[varID].resize(all_splits_per_variable[varID].size());
      for (size_t i = 0; i < all_splits_per_variable[varID].size(); ++i)
      {
        bestTreesPerVariable[varID][i] = all_splits_per_variable[varID][i].tree_idx;
      }
    }
    // For each variable, keep only the unique tree indices:
    for (size_t varID = 0; varID < num_variables; ++varID)
    {
      std::unordered_set<size_t> seen;
      std::vector<size_t> unique_values;
      for (size_t val : bestTreesPerVariable[varID])
      {
        if (seen.insert(val).second)
        { // insert returns true if the element was not already in the set
          unique_values.push_back(val);
        }
      }
      bestTreesPerVariable[varID] = std::move(unique_values);
    }

    // Mark the best splits per variable in the trees:
    for (size_t i = 0; i < num_trees; ++i)
    {
      trees[i]->markBestSplits(i, bestSplits);
    }
  }

  // Determine the top-scoring splits for each covariate in the CRTR analysis.
  void Forest::determineAndMarkBestOOBSplitsPerVariable(std::vector<std::vector<size_t>> &bestTreesPerVariable)
  {

    // Initialize a SplitData object for each variable (hint: the number of entries in all_splits_per_variable is equal to the number of columns in the data matrix,
    // not the number of independent variables):
    std::vector<std::vector<SplitData>> all_splits_per_variable(repr_vars.size());

    // Loop over all trees and collect all splits per variable:
    for (size_t i = 0; i < num_trees; ++i)
    {
      trees[i]->collectOOBSplits(i, all_splits_per_variable);
    }

    // Sort the splits per variable in descending order of the split criterion:
    for (size_t varID = 0; varID < repr_vars.size(); ++varID)
    {
      std::sort(all_splits_per_variable[varID].begin(), all_splits_per_variable[varID].end(),
                [](const SplitData &a, const SplitData &b)
                {
                  return a.split_criterion > b.split_criterion;
                });
    }

    // Remove from all_splits_per_variable all splits with a split criterion of -1 (note: -1 is the smallest possible value,
    // so for each variable, beginning with the first split with a split criterion of -1, all following splits will have a
    // split criterion of -1):
    for (size_t varID = 0; varID < repr_vars.size(); ++varID)
    {
      size_t first_split_with_criterion_minus1 = all_splits_per_variable[varID].size();
      for (size_t i = 0; i < all_splits_per_variable[varID].size(); ++i)
      {
        if (all_splits_per_variable[varID][i].split_criterion == -1)
        {
          first_split_with_criterion_minus1 = i;
          break;
        }
      }
      all_splits_per_variable[varID].resize(first_split_with_criterion_minus1);
    }

    // Determine the best splits per variable:
    bool few_splits = false;
    for (size_t varID = 0; varID < repr_vars.size(); ++varID)
    {
      size_t top_k = static_cast<size_t>(prop_best_splits * all_splits_per_variable[varID].size());
      // If top_k is smaller than and the variable is not a no-split variable, set top_k to the minimum of 5 and the number of splits for the variable:
      if (top_k < 5 && std::find(data->getNoSplitVariables().begin(), data->getNoSplitVariables().end(), varID) == data->getNoSplitVariables().end())
      {
        // Set top_k to the minimum of 5 and the number of splits for the variable:
        top_k = std::min<size_t>(5, all_splits_per_variable[varID].size());
        few_splits = true;
      }
      all_splits_per_variable[varID].resize(top_k);
    }
    if (few_splits)
    {
      Rcpp::Rcout << "Warning: The specified prop.best.splits value would have resulted in fewer than 5 best splits being used for some variables in the variable importance. For these, 5 best splits were used. Consider increasing the number of trees or the value of prop.best.splits." << std::endl
                  << std::flush;
    }

    // Make a container for the best splits per variable:
    std::vector<std::set<std::pair<size_t, size_t>>> bestSplits(repr_vars.size());

    // Fill the container with the best splits per variable:
    for (size_t varID = 0; varID < repr_vars.size(); ++varID)
    {
      for (size_t i = 0; i < all_splits_per_variable[varID].size(); ++i)
      {
        bestSplits[varID].insert(std::make_pair(all_splits_per_variable[varID][i].tree_idx,
                                                all_splits_per_variable[varID][i].node_idx));
      }
    }

    // Make a vector of type std::vector<std::vector<size_t>> that contains for each variable a vector of the tree indices,
    // where the best splits for that variable are located:
    for (size_t varID = 0; varID < repr_vars.size(); ++varID)
    {
      bestTreesPerVariable[varID].resize(all_splits_per_variable[varID].size());
      for (size_t i = 0; i < all_splits_per_variable[varID].size(); ++i)
      {
        bestTreesPerVariable[varID][i] = all_splits_per_variable[varID][i].tree_idx;
      }
    }
    // For each variable, keep only the unique tree indices:
    for (size_t varID = 0; varID < repr_vars.size(); ++varID)
    {
      std::unordered_set<size_t> seen;
      std::vector<size_t> unique_values;
      for (size_t val : bestTreesPerVariable[varID])
      {
        if (seen.insert(val).second)
        { // insert returns true if the element was not already in the set
          unique_values.push_back(val);
        }
      }
      bestTreesPerVariable[varID] = std::move(unique_values);
    }

    // Mark the best splits per variable in the trees:
    for (size_t i = 0; i < num_trees; ++i)
    {
      trees[i]->markBestOOBSplits(i, bestSplits);
    }
  }

  // Select the representative tree among the candidate trees in best_tree_vec.
  //
  // For every candidate, trees[tree_idx]->computeUv(i, Uv) fills row i of Uv, a vector of
  // length num_variables. The distance between two candidates is the sum of squared
  // differences of their Uv rows. The candidate with the smallest summed distance to all other
  // candidates is returned; ties go to the candidate listed first.
  //
  // best_tree_vec: indices into `trees` of the candidate trees.
  // Returns the index (into `trees`) of the representative tree. The result is always an
  // element of best_tree_vec, except that 0 is returned if best_tree_vec is empty.
  size_t Forest::getReprTree(std::vector<size_t> best_tree_vec)
  {
    const size_t num_candidates = best_tree_vec.size();
    if (num_candidates == 0)
    {
      return 0;
    }

    // One Uv row (length num_variables) per candidate tree.
    std::vector<std::vector<double>> Uv(num_candidates, std::vector<double>(num_variables, 0.0));
    for (size_t i = 0; i < num_candidates; ++i)
    {
      trees[best_tree_vec[i]]->computeUv(i, Uv);
    }

    // Summed distance of each candidate to all others. Each pair is evaluated once. Because
    // pairs are visited in lexicographic order, every total receives its terms in ascending
    // order of the partner index, i.e. in the same order as a row sum over a full distance
    // matrix, without storing that n x n matrix.
    std::vector<double> total_dist(num_candidates, 0.0);
    for (size_t i = 0; i < num_candidates; ++i)
    {
      for (size_t j = i + 1; j < num_candidates; ++j)
      {
        double dist = 0.0;
        for (size_t k = 0; k < num_variables; ++k)
        {
          const double diff = Uv[i][k] - Uv[j][k];
          dist += diff * diff;
        }
        total_dist[i] += dist;
        total_dist[j] += dist;
      }
    }

    // Default to the first candidate so the result is a member of best_tree_vec even if no
    // total is finite (inf/NaN never compare less than the running minimum).
    size_t best_pos = 0;
    double min_dist = std::numeric_limits<double>::infinity();
    for (size_t i = 0; i < num_candidates; ++i)
    {
      if (total_dist[i] < min_dist)
      {
        min_dist = total_dist[i];
        best_pos = i;
      }
    }

    return best_tree_vec[best_pos];
  }

  // Grow the trees.
  void Forest::growTreesInThread(uint thread_idx, std::vector<double> *variable_importance)
  {

    // CP();

    if (thread_ranges.size() > thread_idx + 1)
    {
      for (size_t i = thread_ranges[thread_idx]; i < thread_ranges[thread_idx + 1]; ++i)
      {
        trees[i]->grow(variable_importance);

        // Check for user interrupt
#ifdef R_BUILD
        if (aborted)
        {
          std::unique_lock<std::mutex> lock(mutex);
          ++aborted_threads;
          condition_variable.notify_one();
          return;
        }
#endif

        // CP();

        // Increase progress by 1 tree
        std::unique_lock<std::mutex> lock(mutex);
        ++progress;
        condition_variable.notify_one();
      }
    }
  }

  // Predicting using constructed unity forest.
  void Forest::predictTreesInThread(uint thread_idx, const Data *prediction_data, bool oob_prediction)
  {
    if (thread_ranges.size() > thread_idx + 1)
    {
      for (size_t i = thread_ranges[thread_idx]; i < thread_ranges[thread_idx + 1]; ++i)
      {
        trees[i]->predict(prediction_data, oob_prediction);

        // Check for user interrupt
#ifdef R_BUILD
        if (aborted)
        {
          std::unique_lock<std::mutex> lock(mutex);
          ++aborted_threads;
          condition_variable.notify_one();
          return;
        }
#endif

        // Increase progress by 1 tree
        std::unique_lock<std::mutex> lock(mutex);
        ++progress;
        condition_variable.notify_one();
      }
    }
  }

  // Worker for the per-sample prediction step.
  //
  // The samples 0 .. num_samples - 1 are divided into num_threads contiguous ranges via
  // equalSplit. predictInternal(i) is called for every sample i in this thread's range. After
  // each sample, `progress` is incremented under `mutex` and `condition_variable` is notified.
  // In R builds a set `aborted` flag makes the worker increment `aborted_threads`, notify, and
  // return early.
  void Forest::predictInternalInThread(uint thread_idx)
  {
    // No samples, nothing to predict. Returning here also avoids handing the wrapped-around
    // value num_samples - 1 to equalSplit as the last sample index.
    if (num_samples == 0)
    {
      return;
    }

    // Create thread ranges over the samples
    std::vector<uint> predict_ranges;
    equalSplit(predict_ranges, 0, num_samples - 1, num_threads);

    if (predict_ranges.size() > thread_idx + 1)
    {
      for (size_t i = predict_ranges[thread_idx]; i < predict_ranges[thread_idx + 1]; ++i)
      {
        predictInternal(i);

        // Check for user interrupt
#ifdef R_BUILD
        if (aborted)
        {
          std::unique_lock<std::mutex> lock(mutex);
          ++aborted_threads;
          condition_variable.notify_one();
          return;
        }
#endif

        // Increase progress by 1 sample
        std::unique_lock<std::mutex> lock(mutex);
        ++progress;
        condition_variable.notify_one();
      }
    }
  }

  // Compute unity VIM values.
  void Forest::computeTreeImportanceInThread(uint thread_idx, std::vector<double> &importance)
  {
    if (thread_ranges.size() > thread_idx + 1)
    {
      for (size_t i = thread_ranges[thread_idx]; i < thread_ranges[thread_idx + 1]; ++i)
      {
        trees[i]->computeUFImportance(importance);

        // Check for user interrupt
#ifdef R_BUILD
        if (aborted)
        {
          std::unique_lock<std::mutex> lock(mutex);
          ++aborted_threads;
          condition_variable.notify_one();
          return;
        }
#endif

        // Increase progress by 1 tree
        std::unique_lock<std::mutex> lock(mutex);
        ++progress;
        condition_variable.notify_one();
      }
    }
  }

  // Install user-supplied split selection weights.
  //
  // split_select_weights has either one row (shared by all trees) or num_trees rows (one per
  // tree). Each row must hold num_independent_variables weights in [0, 1]. Column j refers to
  // the j-th independent variable; its data column ID is obtained by skipping the no-split
  // variables.
  //
  // For the first row:
  //   - weight 1 appends the column ID to deterministic_varIDs;
  //   - weight 0 is counted as excluded;
  //   - a weight in (0, 1) is stored in this->split_select_weights[0][j], together with the
  //     column ID in split_select_varIDs[j].
  // For further rows, only weights in (0, 1) are stored.
  //
  // With IMP_GINI_CORRECTED, every row's weights are duplicated into the second half (the
  // shadow variables). Shadow IDs are added once to split_select_varIDs and deterministic_varIDs.
  //
  // Throws std::runtime_error if the number of rows or columns is wrong, if a weight lies
  // outside [0, 1], or if fewer than mtry variables remain for weighted selection.
  void Forest::setSplitWeightVector(std::vector<std::vector<double>> &split_select_weights)
  {

    // Size should be 1 x num_independent_variables or num_trees x num_independent_variables
    if (split_select_weights.size() != 1 && split_select_weights.size() != num_trees)
    {
      throw std::runtime_error("Size of split select weights not equal to 1 or number of trees.");
    }

    // Reserve space (doubled for the shadow variables of the corrected impurity importance)
    size_t num_weights = num_independent_variables;
    if (importance_mode == IMP_GINI_CORRECTED)
    {
      num_weights = 2 * num_independent_variables;
    }
    if (split_select_weights.size() == 1)
    {
      this->split_select_weights[0].resize(num_weights);
    }
    else
    {
      this->split_select_weights.clear();
      this->split_select_weights.resize(num_trees, std::vector<double>(num_weights));
    }
    this->split_select_varIDs.resize(num_weights);
    deterministic_varIDs.reserve(num_weights);

    // Split up in deterministic and weighted variables, ignore zero weights.
    // Deterministic and zero-weight variables are taken from the first row only.
    size_t num_zero_weights = 0;
    for (size_t i = 0; i < split_select_weights.size(); ++i)
    {

      // Size should be 1 x num_independent_variables or num_trees x num_independent_variables
      if (split_select_weights[i].size() != num_independent_variables)
      {
        throw std::runtime_error("Number of split select weights not equal to number of independent variables.");
      }

      for (size_t j = 0; j < split_select_weights[i].size(); ++j)
      {
        double weight = split_select_weights[i][j];

        if (i == 0)
        {
          // Map the independent-variable index j to its data column ID by skipping the
          // no-split variables (assumes the no-split list is sorted ascending).
          size_t varID = j;
          for (auto &skip : data->getNoSplitVariables())
          {
            if (varID >= skip)
            {
              ++varID;
            }
          }

          if (weight == 1)
          {
            deterministic_varIDs.push_back(varID);
          }
          else if (weight < 1 && weight > 0)
          {
            this->split_select_varIDs[j] = varID;
            this->split_select_weights[i][j] = weight;
          }
          else if (weight == 0)
          {
            ++num_zero_weights;
          }
          else if (weight < 0 || weight > 1)
          {
            throw std::runtime_error("One or more split select weights not in range [0,1].");
          }
        }
        else
        {
          if (weight < 1 && weight > 0)
          {
            this->split_select_weights[i][j] = weight;
          }
          else if (weight < 0 || weight > 1)
          {
            throw std::runtime_error("One or more split select weights not in range [0,1].");
          }
        }
      }

      // Copy weights for corrected impurity importance
      if (importance_mode == IMP_GINI_CORRECTED)
      {
        std::vector<double> &sw = this->split_select_weights[i];
        std::copy_n(sw.begin(), num_independent_variables, sw.begin() + num_independent_variables);

        // Shadow IDs depend only on the first row. Appending them for every row would
        // re-append shadow entries once per tree and would treat previously appended shadow
        // IDs as originals, producing out-of-range variable IDs.
        if (i == 0)
        {
          for (size_t k = 0; k < num_independent_variables; ++k)
          {
            split_select_varIDs[num_independent_variables + k] = num_variables + k;
          }

          size_t num_deterministic_varIDs = deterministic_varIDs.size();
          for (size_t k = 0; k < num_deterministic_varIDs; ++k)
          {
            // Convert the column ID back to its independent-variable index, then offset it
            // into the shadow-variable ID range.
            size_t varID = deterministic_varIDs[k];
            for (auto &skip : data->getNoSplitVariables())
            {
              if (varID >= skip)
              {
                --varID;
              }
            }
            deterministic_varIDs.push_back(varID + num_variables);
          }
        }
      }
    }

    // Written as a sum so that the check cannot be bypassed by unsigned underflow when the
    // deterministic and zero-weight variables outnumber the weights.
    if (num_weights < deterministic_varIDs.size() + num_zero_weights + mtry)
    {
      throw std::runtime_error("Too many zeros or ones in split select weights. Need at least mtry variables to split at.");
    }
  }

  // Register the variables that are always considered for splitting.
  //
  // Each name is resolved with data->getVariableID and the resulting ID is appended to
  // deterministic_varIDs. Throws std::runtime_error if these variables together with mtry
  // exceed num_independent_variables. With IMP_GINI_CORRECTED, a shadow ID is appended for
  // every deterministic ID: the ID is shifted back past the no-split variables and then
  // offset by num_variables.
  void Forest::setAlwaysSplitVariables(const std::vector<std::string> &always_split_variable_names)
  {
    deterministic_varIDs.reserve(num_independent_variables);

    for (const std::string &name : always_split_variable_names)
    {
      const size_t resolved_id = data->getVariableID(name);
      deterministic_varIDs.push_back(resolved_id);
    }

    const size_t requested = deterministic_varIDs.size() + this->mtry;
    if (requested > num_independent_variables)
    {
      throw std::runtime_error(
          "Number of variables to be always considered for splitting plus mtry cannot be larger than number of independent variables.");
    }

    // Shadow variables only exist for the corrected impurity importance.
    if (importance_mode != IMP_GINI_CORRECTED)
    {
      return;
    }

    const size_t num_original_ids = deterministic_varIDs.size();
    for (size_t pos = 0; pos < num_original_ids; ++pos)
    {
      size_t shifted_id = deterministic_varIDs[pos];
      for (auto &skip : data->getNoSplitVariables())
      {
        if (shifted_id >= skip)
        {
          --shifted_id;
        }
      }
      deterministic_varIDs.push_back(shifted_id + num_variables);
    }
  }

  // Block the calling thread until `progress` reaches max_progress, reporting along the way.
  //
  // Holds `mutex` except while waiting on `condition_variable`. After each wake-up, if some
  // progress has been made and more than STATUS_INTERVAL seconds have passed since the last
  // report, it writes the percentage done and an estimated remaining time to *verbose_out
  // (when verbose_out is set). In R builds it also polls checkInterrupt(), sets `aborted` on an
  // interrupt, and returns once aborted_threads >= num_threads.
  void Forest::showProgress(std::string operation, size_t max_progress)
  {
    using std::chrono::duration_cast;
    using std::chrono::seconds;
    using std::chrono::steady_clock;

    const steady_clock::time_point start_time = steady_clock::now();
    steady_clock::time_point last_report = steady_clock::now();
    std::unique_lock<std::mutex> lock(mutex);

    while (progress < max_progress)
    {
      // Sleep until a worker signals progress (or an abort).
      condition_variable.wait(lock);
      const seconds since_last_report = duration_cast<seconds>(steady_clock::now() - last_report);

#ifdef R_BUILD
      // Propagate a user interrupt to the workers and stop once all of them have quit.
      if (!aborted && checkInterrupt())
      {
        aborted = true;
      }
      if (aborted && aborted_threads >= num_threads)
      {
        return;
      }
#endif

      const bool report_due = progress > 0 && since_last_report.count() > STATUS_INTERVAL;
      if (!report_due)
      {
        continue;
      }

      const double relative_progress = static_cast<double>(progress) / static_cast<double>(max_progress);
      const seconds time_from_start = duration_cast<seconds>(steady_clock::now() - start_time);
      // Linear extrapolation of the elapsed time.
      const uint remaining_time = (1 / relative_progress - 1) * time_from_start.count();
      if (verbose_out)
      {
        *verbose_out << operation << " Progress: " << round(100 * relative_progress) << "%. Estimated remaining time: "
                     << beautifyTime(remaining_time) << "." << std::endl;
      }
      last_report = steady_clock::now();
    }
  }

} // namespace unityForest
