% Fit and evaluate regression trees using the "fitrtree" function
% Author: Chia-Feng Lu, 2020.04.08
clear, close all

%% load Hitters data: Major League Baseball Data from the 1986 and 1987 seasons.
% Only the raw cell output of xlsread is used: its first row supplies the
% column headers and every following row becomes one row of the table.
[~,~,raw]=xlsread('Hitters.csv');

% construct Table array, taking the variable names from the header row
data=cell2table(raw(2:end,:),'VariableNames',raw(1,:));

% If any Salary cell is not a numeric scalar (for example a text marker used
% for a missing value), cell2table keeps the column as a cell array and the
% isnan() call below would error. Convert such a column to a numeric vector,
% using NaN for every non-numeric entry. A column that is already numeric is
% left untouched.
if iscell(data.Salary)
    salarycells=data.Salary;
    notnumeric=~cellfun(@(v) isnumeric(v) && isscalar(v),salarycells);
    salarycells(notnumeric)={NaN};
    data.Salary=cell2mat(salarycells);
end

excludeind=isnan(data.Salary);  % logical mask of the rows with missing Salary data
data(excludeind,:)=[]; % remove the players with missing Salary data

data.Salary=log(data.Salary); % log-transform to have a typical bell-shape distribution

%% Split the data into a training set (70%) and a test set (30%)
rng(0,'twister')  % fix the random generator state so the split is reproducible

% Hold-out partition over all rows of data; 30% of the rows form the test set
C = cvpartition(size(data,1),'holdout',0.30);
dataTest = data(test(C),:);
dataTrain = data(training(C),:);

%% [Method 1] Regression tree built from six variables/features
% Fit the tree on the training dataset. fitrtree is asked to optimize all of
% its eligible hyperparameters, with the optimization options given below.
predictors={'Years','Hits','RBI','PutOuts','Walks','Runs'};
hpoOptions=struct('AcquisitionFunctionName','expected-improvement-plus','kfold',5);
tree_6v = fitrtree(dataTrain,'Salary', ...
    'PredictorNames',predictors, ...
    'OptimizeHyperparameters','all', ...
    'HyperparameterOptimizationOptions',hpoOptions);
view(tree_6v,'Mode','graph')

% Evaluate the fitted tree on the test dataset
Salary_predict=predict(tree_6v,dataTest);
MSE=loss(tree_6v,dataTest);  % mean squared error, i.e. mean((Salary_predict-dataTest.Salary).^2)

% Coefficient of determination: R^2 = 1 - RSS/TSS
TSS=sum((dataTest.Salary-mean(dataTest.Salary)).^2); % total sum of squares
RSS=sum((Salary_predict-dataTest.Salary).^2);        % residual sum of squares
R_square=1-RSS/TSS;

% Overlay true (red) and predicted (blue) values, one point per test row
playerIdx=1:size(dataTest,1);
figure
plot(playerIdx,dataTest.Salary,'r.-',playerIdx,Salary_predict,'b.-')
xlabel('Players')
ylabel('Salaries')
legend('True salaries','Predicted salaries')
title(['[Method 1] MSE = ' num2str(MSE) ', R^2 = ' num2str(R_square)])

%% [Method 2] Construct a regression tree using six variables/features with pruning
% Prune the tree fitted in Method 1 (tree_6v) to a fixed pruning level and
% evaluate the pruned tree on the test dataset.
prunelevel=6; % requested pruning level

% prune() only accepts levels up to the largest level stored in the tree.
% tree_6v comes out of a hyperparameter search, so its depth is not known in
% advance; clamp the request instead of letting prune() error.
maxlevel=max(tree_6v.PruneList);
if ~isempty(maxlevel) && prunelevel>maxlevel
    warning('Requested prune level %d exceeds the largest level %d of tree_6v; using %d instead.',...
        prunelevel,maxlevel,maxlevel);
    prunelevel=maxlevel;
end
tree_prune = prune(tree_6v,'level',prunelevel);

% Alternative: prune by cost-complexity threshold instead of by level
% prunealpha=0.03;
% tree_prune = prune(tree_6v,'alpha',prunealpha);

view(tree_prune,'Mode','graph')

% Test the pruned tree model on the test dataset
Salary_predict=predict(tree_prune,dataTest);
MSE=loss(tree_prune,dataTest);  % mean squared error, MSE=mean((Salary_predict-dataTest.Salary).^2);

RSS=sum((Salary_predict-dataTest.Salary).^2); % residual sum of squares
TSS=sum((dataTest.Salary-mean(dataTest.Salary)).^2); % total sum of squares
R_square=1-RSS/TSS;

% Overlay true (red) and predicted (blue) values, one point per test row
figure,plot(1:size(dataTest,1),dataTest.Salary,'r.-',1:1:size(dataTest,1),Salary_predict,'b.-')
title(['[Method 2] MSE = ' num2str(MSE) ', R^2 = ' num2str(R_square)])
legend('True salaries','Predicted salaries')
xlabel('Players'),ylabel('Salaries')

%% Variable/Feature selection based on the importance scores
% Importance estimates of the predictors of the Method 1 tree (tree_6v)
imp = predictorImportance(tree_6v);

% Bar chart of the estimates, with the predictor names as tick labels
figure;
bar(imp);
h = gca;
set(h,'XTickLabel',tree_6v.PredictorNames, ...
      'XTickLabelRotation',45, ...
      'TickLabelInterpreter','none'); % 'none' shows the names as plain text
xlabel('Predictors');
ylabel('Estimates');
title('Predictor Importance Estimates');

%% [Method 3] Regression tree built from the key variables/features only
% Fit the tree on the training dataset, restricted to two predictors.
% fitrtree is asked to optimize all of its eligible hyperparameters, with
% the optimization options given below.
predictors={'Years','Hits'};  % only use the two key variables/features
hpoOptions=struct('AcquisitionFunctionName','expected-improvement-plus','kfold',5);
tree_2v = fitrtree(dataTrain,'Salary', ...
    'PredictorNames',predictors, ...
    'OptimizeHyperparameters','all', ...
    'HyperparameterOptimizationOptions',hpoOptions);
view(tree_2v,'Mode','graph')

% Evaluate the fitted tree on the test dataset
Salary_predict=predict(tree_2v,dataTest);
MSE=loss(tree_2v,dataTest);  % mean squared error, i.e. mean((Salary_predict-dataTest.Salary).^2)

% Coefficient of determination: R^2 = 1 - RSS/TSS
TSS=sum((dataTest.Salary-mean(dataTest.Salary)).^2); % total sum of squares
RSS=sum((Salary_predict-dataTest.Salary).^2);        % residual sum of squares
R_square=1-RSS/TSS;

% Overlay true (red) and predicted (blue) values, one point per test row
playerIdx=1:size(dataTest,1);
figure
plot(playerIdx,dataTest.Salary,'r.-',playerIdx,Salary_predict,'b.-')
xlabel('Players')
ylabel('Salaries')
legend('True salaries','Predicted salaries')
title(['[Method 3] MSE = ' num2str(MSE) ', R^2 = ' num2str(R_square)])
