{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Random Hyperboxes with Hyper-parameter Optimisation for Base Learners" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example shows how to use a random hyperboxes classifier in which each hyperbox-based base learner is trained on a random subset of features and a random subset of samples, and its hyperparameters are tuned by random search with k-fold cross-validation.\n", "\n", "The original random hyperboxes model (class RandomHyperboxesClassifier) trains all base learners with the same hyperparameters. In contrast, the cross-validation random hyperboxes model (class CrossValRandomHyperboxesClassifier) lets each base learner use its own hyperparameters, selected for its particular training data by a random search for the best combination of hyperparameters." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore')\n", "import numpy as np\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.model_selection import train_test_split\n", "from hbbrain.numerical_data.ensemble_learner.cross_val_random_hyperboxes import CrossValRandomHyperboxesClassifier\n", "from hbbrain.numerical_data.incremental_learner.onln_gfmm import OnlineGFMM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load dataset\n", "This example uses the breast cancer dataset available in sklearn to demonstrate how to use this ensemble classifier. 
" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_breast_cancer\n", "from sklearn.preprocessing import MinMaxScaler" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "df = load_breast_cancer()\n", "X = df.data\n", "y = df.target" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Normalise data into the range [0, 1], as hyperbox-based models only work in the unit cube\n", "scaler = MinMaxScaler()\n", "X = scaler.fit_transform(X)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Split data into training (60%), validation (20%) and testing (20%) sets\n", "Xtr_val, X_test, ytr_val, y_test = train_test_split(X, y, train_size=0.8, random_state=0)\n", "# The second split partitions only the training+validation part, so the test set stays held out\n", "Xtr, X_val, ytr, y_val = train_test_split(Xtr_val, ytr_val, train_size=0.75, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**This example uses the GFMM classifier with the original online learning algorithm as the base learner. However, any hyperbox-based learning algorithm in this library can be used to train the base learners.**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Using random subsampling to generate training sets for various base learners" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### a. 
The number of features used by each base learner varies and is bounded by a maximum number of features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Initialise parameters\n", "n_estimators = 20 # number of base learners\n", "max_samples = 0.5 # sampling rate for samples\n", "max_features = 0.5 # sampling rate to generate the maximum number of features\n", "class_balanced = False # do not use the class-balanced sampling mode\n", "feature_balanced = False # use different numbers of features for base learners\n", "n_jobs = 4 # number of processes used to build base learners\n", "n_iter = 20 # Number of parameter settings that are randomly sampled to choose the best combination of hyperparameters\n", "k_fold = 5 # Number of folds to conduct Stratified K-Fold cross-validation for hyperparameter tuning" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Init a hyperbox-based model used to train base learners\n", "# Using the GFMM classifier with the original online learning algorithm\n", "base_estimator = OnlineGFMM()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Init ranges for hyperparameters of base learners to perform a random search process for hyperparameter tuning\n", "base_estimator_params = {'theta': np.arange(0.05, 1.01, 0.05), 'theta_min':[1], 'gamma':[0.5, 1, 2, 4, 8, 16]}" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CrossValRandomHyperboxesClassifier(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64)),\n", " base_estimator_params={'gamma': [0.5, 1, 2,\n", " 4, 8, 16],\n", " 'theta': array([0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,\n", " 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 
0.9 , 0.95, 1. ]),\n", " 'theta_min': [1]},\n", " max_features=0.5, n_estimators=20, n_iter=20,\n", " n_jobs=4, random_state=0)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cross_val_rh_subsampling_diff_num_features_clf = CrossValRandomHyperboxesClassifier(base_estimator=base_estimator, base_estimator_params=base_estimator_params, n_estimators=n_estimators, max_samples=max_samples, max_features=max_features, class_balanced=class_balanced, feature_balanced=feature_balanced, n_iter=n_iter, k_fold=k_fold, n_jobs=n_jobs, random_state=0)\n", "cross_val_rh_subsampling_diff_num_features_clf.fit(Xtr, ytr)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training time: 37.453 (s)\n" ] } ], "source": [ "print(\"Training time: %.3f (s)\"%(cross_val_rh_subsampling_diff_num_features_clf.elapsed_training_time))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners = 1110\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners = %d'%cross_val_rh_subsampling_diff_num_features_clf.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy = 93.86%\n" ] } ], "source": [ "y_pred = cross_val_rh_subsampling_diff_num_features_clf.predict(X_test)\n", "acc = accuracy_score(y_test, y_pred)\n", "print(f'Testing accuracy = {acc * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pruning for base learners" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"CrossValRandomHyperboxesClassifier(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64)),\n", " base_estimator_params={'gamma': [0.5, 1, 2,\n", " 4, 8, 16],\n", " 'theta': array([0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,\n", " 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 0.95, 1. ]),\n", " 'theta_min': [1]},\n", " max_features=0.5, n_estimators=20, n_iter=20,\n", " n_jobs=4, random_state=0)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "acc_threshold=0.5 # minimum accuracy score of the unpruned hyperboxes\n", "keep_empty_boxes=False # False means hyperboxes that are not used to predict any validation sample during pruning are also eliminated\n", "cross_val_rh_subsampling_diff_num_features_clf.simple_pruning_base_estimators(X_val, y_val, acc_threshold, keep_empty_boxes)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners after pruning = 671\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners after pruning = %d'%cross_val_rh_subsampling_diff_num_features_clf.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction after pruning" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy (after pruning) = 95.61%\n" ] } ], "source": [ "y_pred_2 = cross_val_rh_subsampling_diff_num_features_clf.predict(X_test)\n", "acc_pruned = accuracy_score(y_test, y_pred_2)\n", "print(f'Testing accuracy (after pruning) = {acc_pruned * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### b. 
The number of features used by each base learner is the same and equals the given maximum number of features" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Initialise parameters\n", "n_estimators = 20 # number of base learners\n", "max_samples = 0.5 # sampling rate for samples\n", "max_features = 0.5 # sampling rate to generate the maximum number of features\n", "class_balanced = False # do not use the class-balanced sampling mode\n", "# use the same number of features for all base learners, equal to the given maximum number of features\n", "feature_balanced = True\n", "n_jobs = 4 # number of processes used to build base learners\n", "n_iter = 20 # Number of parameter settings that are randomly sampled to choose the best combination of hyperparameters\n", "k_fold = 5 # Number of folds to conduct Stratified K-Fold cross-validation for hyperparameter tuning" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Init a hyperbox-based model used to train base learners\n", "# Using the GFMM classifier with the original online learning algorithm\n", "base_estimator = OnlineGFMM()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Init ranges for hyperparameters of base learners to perform a random search process for hyperparameter tuning\n", "base_estimator_params = {'theta': np.arange(0.05, 1.01, 0.05), 'theta_min':[1], 'gamma':[0.5, 1, 2, 4, 8, 16]}" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CrossValRandomHyperboxesClassifier(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64)),\n", " base_estimator_params={'gamma': [0.5, 1, 2,\n", " 4, 8, 16],\n", " 'theta': array([0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,\n", " 0.6 , 0.65, 0.7 , 0.75, 0.8 , 
0.85, 0.9 , 0.95, 1. ]),\n", " 'theta_min': [1]},\n", " feature_balanced=True, max_features=0.5,\n", " n_estimators=20, n_iter=20, n_jobs=4,\n", " random_state=0)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cross_val_rh_subsampling_same_num_features_clf = CrossValRandomHyperboxesClassifier(base_estimator=base_estimator, base_estimator_params=base_estimator_params, n_estimators=n_estimators, max_samples=max_samples, max_features=max_features, class_balanced=class_balanced, feature_balanced=feature_balanced, n_iter=n_iter, k_fold=k_fold, n_jobs=n_jobs, random_state=0)\n", "cross_val_rh_subsampling_same_num_features_clf.fit(Xtr, ytr)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training time: 45.047 (s)\n" ] } ], "source": [ "print(\"Training time: %.3f (s)\"%(cross_val_rh_subsampling_same_num_features_clf.elapsed_training_time))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners = 973\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners = %d'%cross_val_rh_subsampling_same_num_features_clf.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy = 93.86%\n" ] } ], "source": [ "y_pred = cross_val_rh_subsampling_same_num_features_clf.predict(X_test)\n", "acc = accuracy_score(y_test, y_pred)\n", "print(f'Testing accuracy = {acc * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pruning for base learners" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"CrossValRandomHyperboxesClassifier(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64)),\n", " base_estimator_params={'gamma': [0.5, 1, 2,\n", " 4, 8, 16],\n", " 'theta': array([0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,\n", " 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 0.95, 1. ]),\n", " 'theta_min': [1]},\n", " feature_balanced=True, max_features=0.5,\n", " n_estimators=20, n_iter=20, n_jobs=4,\n", " random_state=0)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "acc_threshold=0.5 # minimum accuracy score of the unpruned hyperboxes\n", "keep_empty_boxes=False # False means hyperboxes that are not used to predict any validation sample during pruning are also eliminated\n", "cross_val_rh_subsampling_same_num_features_clf.simple_pruning_base_estimators(X_val, y_val, acc_threshold, keep_empty_boxes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction after pruning" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy (after pruning) = 94.74%\n" ] } ], "source": [ "y_pred_2 = cross_val_rh_subsampling_same_num_features_clf.predict(X_test)\n", "acc_pruned = accuracy_score(y_test, y_pred_2)\n", "print(f'Testing accuracy (after pruning) = {acc_pruned * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Using random undersampling to generate class-balanced training sets for various base learners" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### a. 
The number of features used by each base learner varies and is bounded by a maximum number of features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# Initialise parameters\n", "n_estimators = 20 # number of base learners\n", "max_samples = 0.5 # sampling rate for samples\n", "max_features = 0.5 # sampling rate to generate the maximum number of features\n", "class_balanced = True # use the class-balanced sampling mode\n", "feature_balanced = False # use different numbers of features for base learners\n", "n_jobs = 4 # number of processes used to build base learners\n", "n_iter = 20 # Number of parameter settings that are randomly sampled to choose the best combination of hyperparameters\n", "k_fold = 5 # Number of folds to conduct Stratified K-Fold cross-validation for hyperparameter tuning" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# Init a hyperbox-based model used to train base learners\n", "# Using the GFMM classifier with the original online learning algorithm\n", "base_estimator = OnlineGFMM()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# Init ranges for hyperparameters of base learners to perform a random search process for hyperparameter tuning\n", "base_estimator_params = {'theta': np.arange(0.05, 1.01, 0.05), 'theta_min':[1], 'gamma':[0.5, 1, 2, 4, 8, 16]}" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CrossValRandomHyperboxesClassifier(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64)),\n", " base_estimator_params={'gamma': [0.5, 1, 2,\n", " 4, 8, 16],\n", " 'theta': array([0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,\n", " 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 
0.95, 1. ]),\n", " 'theta_min': [1]},\n", " class_balanced=True, max_features=0.5,\n", " n_estimators=20, n_iter=20, n_jobs=4,\n", " random_state=0)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cross_val_rh_class_balanced_diff_num_features_clf = CrossValRandomHyperboxesClassifier(base_estimator=base_estimator, base_estimator_params=base_estimator_params, n_estimators=n_estimators, max_samples=max_samples, max_features=max_features, class_balanced=class_balanced, feature_balanced=feature_balanced, n_iter=n_iter, k_fold=k_fold, n_jobs=n_jobs, random_state=0)\n", "cross_val_rh_class_balanced_diff_num_features_clf.fit(Xtr, ytr)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training time: 33.372 (s)\n" ] } ], "source": [ "print(\"Training time: %.3f (s)\"%(cross_val_rh_class_balanced_diff_num_features_clf.elapsed_training_time))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners = 1123\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners = %d'%cross_val_rh_class_balanced_diff_num_features_clf.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy = 92.11%\n" ] } ], "source": [ "y_pred = cross_val_rh_class_balanced_diff_num_features_clf.predict(X_test)\n", "acc = accuracy_score(y_test, y_pred)\n", "print(f'Testing accuracy = {acc * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pruning for base learners" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"CrossValRandomHyperboxesClassifier(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64)),\n", " base_estimator_params={'gamma': [0.5, 1, 2,\n", " 4, 8, 16],\n", " 'theta': array([0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,\n", " 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 0.95, 1. ]),\n", " 'theta_min': [1]},\n", " class_balanced=True, max_features=0.5,\n", " n_estimators=20, n_iter=20, n_jobs=4,\n", " random_state=0)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "acc_threshold=0.5 # minimum accuracy score of the unpruned hyperboxes\n", "keep_empty_boxes=False # False means hyperboxes that are not used to predict any validation sample during pruning are also eliminated\n", "cross_val_rh_class_balanced_diff_num_features_clf.simple_pruning_base_estimators(X_val, y_val, acc_threshold, keep_empty_boxes)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners after pruning = 663\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners after pruning = %d'%cross_val_rh_class_balanced_diff_num_features_clf.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction after pruning" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy (after pruning) = 94.74%\n" ] } ], "source": [ "y_pred_2 = cross_val_rh_class_balanced_diff_num_features_clf.predict(X_test)\n", "acc_pruned = accuracy_score(y_test, y_pred_2)\n", "print(f'Testing accuracy (after pruning) = {acc_pruned * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### b. 
The number of features used by each base learner is the same and equals the given maximum number of features" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# Initialise parameters\n", "n_estimators = 20 # number of base learners\n", "max_samples = 0.5 # sampling rate for samples\n", "max_features = 0.5 # sampling rate to generate the maximum number of features\n", "class_balanced = True # use the class-balanced sampling mode\n", "# use the same number of features for all base learners, equal to the given maximum number of features\n", "feature_balanced = True\n", "n_jobs = 4 # number of processes used to build base learners\n", "n_iter = 20 # Number of parameter settings that are randomly sampled to choose the best combination of hyperparameters\n", "k_fold = 5 # Number of folds to conduct Stratified K-Fold cross-validation for hyperparameter tuning" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "# Init a hyperbox-based model used to train base learners\n", "# Using the GFMM classifier with the original online learning algorithm\n", "base_estimator = OnlineGFMM()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "# Init ranges for hyperparameters of base learners to perform a random search process for hyperparameter tuning\n", "base_estimator_params = {'theta': np.arange(0.05, 1.01, 0.05), 'theta_min':[1], 'gamma':[0.5, 1, 2, 4, 8, 16]}" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CrossValRandomHyperboxesClassifier(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64)),\n", " base_estimator_params={'gamma': [0.5, 1, 2,\n", " 4, 8, 16],\n", " 'theta': array([0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,\n", " 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 
0.9 , 0.95, 1. ]),\n", " 'theta_min': [1]},\n", " class_balanced=True, feature_balanced=True,\n", " max_features=0.5, n_estimators=20, n_iter=20,\n", " n_jobs=4, random_state=0)" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cross_val_rh_class_balanced_same_num_features_clf = CrossValRandomHyperboxesClassifier(base_estimator=base_estimator, base_estimator_params=base_estimator_params, n_estimators=n_estimators, max_samples=max_samples, max_features=max_features, class_balanced=class_balanced, feature_balanced=feature_balanced, n_iter=n_iter, k_fold=k_fold, n_jobs=n_jobs, random_state=0)\n", "cross_val_rh_class_balanced_same_num_features_clf.fit(Xtr, ytr)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training time: 30.501 (s)\n" ] } ], "source": [ "print(\"Training time: %.3f (s)\"%(cross_val_rh_class_balanced_same_num_features_clf.elapsed_training_time))" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners = 1623\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners = %d'%cross_val_rh_class_balanced_same_num_features_clf.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy = 91.23%\n" ] } ], "source": [ "y_pred = cross_val_rh_class_balanced_same_num_features_clf.predict(X_test)\n", "acc = accuracy_score(y_test, y_pred)\n", "print(f'Testing accuracy = {acc * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pruning for base learners" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"CrossValRandomHyperboxesClassifier(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64)),\n", " base_estimator_params={'gamma': [0.5, 1, 2,\n", " 4, 8, 16],\n", " 'theta': array([0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,\n", " 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 0.95, 1. ]),\n", " 'theta_min': [1]},\n", " class_balanced=True, feature_balanced=True,\n", " max_features=0.5, n_estimators=20, n_iter=20,\n", " n_jobs=4, random_state=0)" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "acc_threshold=0.5 # minimum accuracy score of the unpruned hyperboxes\n", "keep_empty_boxes=False # False means hyperboxes that are not used to predict any validation sample during pruning are also eliminated\n", "cross_val_rh_class_balanced_same_num_features_clf.simple_pruning_base_estimators(X_val, y_val, acc_threshold, keep_empty_boxes)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners after pruning = 1234\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners after pruning = %d'%cross_val_rh_class_balanced_same_num_features_clf.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction after pruning" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy (after pruning) = 95.61%\n" ] } ], "source": [ "y_pred_2 = cross_val_rh_class_balanced_same_num_features_clf.predict(X_test)\n", "acc_pruned = accuracy_score(y_test, y_pred_2)\n", "print(f'Testing accuracy (after pruning) = {acc_pruned * 100: .2f}%')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { 
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 4 }