{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Decision-level Bagging of Hyperbox-based Models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example shows how to use a bagging classifier whose base hyperbox-based models are each trained on the full set of features and a random subset of the samples." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore')\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.model_selection import train_test_split\n", "from hbbrain.numerical_data.ensemble_learner.decision_comb_bagging import DecisionCombinationBagging\n", "from hbbrain.numerical_data.incremental_learner.onln_gfmm import OnlineGFMM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load dataset\n", "This example uses the breast cancer dataset available in sklearn to demonstrate how to use this ensemble classifier." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_breast_cancer\n", "from sklearn.preprocessing import MinMaxScaler" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "df = load_breast_cancer()\n", "X = df.data\n", "y = df.target" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Normalise data into the range [0, 1], as hyperbox-based models only work in the unit cube\n", "scaler = MinMaxScaler()\n", "X = scaler.fit_transform(X)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Split data into training (60%), validation (20%) and testing (20%) sets\n", "Xtr_val, X_test, ytr_val, y_test = train_test_split(X, y, train_size=0.8, random_state=0)\n", "# Split only the training+validation portion so that the test set stays unseen\n", "Xtr, X_val, ytr, y_val = train_test_split(Xtr_val, ytr_val, train_size=0.75, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example uses the GFMM 
classifier with the original online learning algorithm for the base learners. However, any hyperbox-based learning algorithm in this library can be used to train the base learners." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Using random subsampling to generate a training set for each base learner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Initialise parameters\n", "n_estimators = 20 # number of base learners\n", "max_samples = 0.5 # sampling rate of training samples for each base learner\n", "bootstrap = False # random subsampling without replacement\n", "class_balanced = False # do not use the class-balanced sampling mode\n", "n_jobs = 4 # number of processes used to build the base learners" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Initialise the hyperbox-based model used to train the base learners:\n", "# a GFMM classifier with the original online learning algorithm and a maximum hyperbox size of 0.1\n", "base_estimator = OnlineGFMM(theta=0.1)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "dc_bagging_subsampling = DecisionCombinationBagging(base_estimator=base_estimator, n_estimators=n_estimators, max_samples=max_samples, bootstrap=bootstrap, class_balanced=class_balanced, n_jobs=n_jobs, random_state=0)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DecisionCombinationBagging(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64),\n", " theta=0.1),\n", " n_estimators=20, n_jobs=4, random_state=0)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dc_bagging_subsampling.fit(Xtr, ytr)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "Training time: 4.355 (s)\n" ] } ], "source": [ "print(\"Training time: %.3f (s)\"%(dc_bagging_subsampling.elapsed_training_time))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners = 3948\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners = %d'%dc_bagging_subsampling.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy = 93.86%\n" ] } ], "source": [ "y_pred = dc_bagging_subsampling.predict(X_test)\n", "acc = accuracy_score(y_test, y_pred)\n", "print(f'Testing accuracy = {acc * 100:.2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pruning to base learners" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DecisionCombinationBagging(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64),\n", " theta=0.1),\n", " n_estimators=20, n_jobs=4, random_state=0)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "acc_threshold=0.5 # minimum accuracy score of the unpruned hyperboxes\n", "keep_empty_boxes=False # False: hyperboxes that are not involved in predicting any validation sample during pruning are also removed\n", "dc_bagging_subsampling.simple_pruning_base_estimators(X_val, y_val, acc_threshold, keep_empty_boxes)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners after pruning = 2195\n" ] } ], "source": [ "print('Total number 
of hyperboxes from all base learners after pruning = %d'%dc_bagging_subsampling.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction after pruning" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy (after pruning) = 95.61%\n" ] } ], "source": [ "y_pred_2 = dc_bagging_subsampling.predict(X_test)\n", "acc_pruned = accuracy_score(y_test, y_pred_2)\n", "print(f'Testing accuracy (after pruning) = {acc_pruned * 100:.2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Using random undersampling to generate a class-balanced training set for each base learner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Initialise parameters\n", "n_estimators = 20 # number of base learners\n", "max_samples = 0.5 # sampling rate of training samples for each base learner\n", "bootstrap = False # random subsampling without replacement\n", "class_balanced = True # use the class-balanced sampling mode\n", "n_jobs = 4 # number of processes used to build the base learners" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Initialise the hyperbox-based model used to train the base learners:\n", "# a GFMM classifier with the original online learning algorithm and a maximum hyperbox size of 0.1\n", "base_estimator = OnlineGFMM(theta=0.1)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "dc_bagging_class_balanced = DecisionCombinationBagging(base_estimator=base_estimator, n_estimators=n_estimators, max_samples=max_samples, bootstrap=bootstrap, class_balanced=class_balanced, n_jobs=n_jobs, random_state=0)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"DecisionCombinationBagging(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64),\n", " theta=0.1),\n", " class_balanced=True, n_estimators=20, n_jobs=4,\n", " random_state=0)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dc_bagging_class_balanced.fit(Xtr, ytr)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training time: 0.271 (s)\n" ] } ], "source": [ "print(\"Training time: %.3f (s)\"%(dc_bagging_class_balanced.elapsed_training_time))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners = 4010\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners = %d'%dc_bagging_class_balanced.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy = 92.11%\n" ] } ], "source": [ "y_pred = dc_bagging_class_balanced.predict(X_test)\n", "acc = accuracy_score(y_test, y_pred)\n", "print(f'Testing accuracy = {acc * 100:.2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pruning to base learners" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DecisionCombinationBagging(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64),\n", " theta=0.1),\n", " class_balanced=True, n_estimators=20, n_jobs=4,\n", " random_state=0)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "acc_threshold=0.5 # minimum accuracy score of the unpruned 
hyperboxes\n", "keep_empty_boxes=False # False: hyperboxes that are not involved in predicting any validation sample during pruning are also removed\n", "dc_bagging_class_balanced.simple_pruning_base_estimators(X_val, y_val, acc_threshold, keep_empty_boxes)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes from all base learners after pruning = 2738\n" ] } ], "source": [ "print('Total number of hyperboxes from all base learners after pruning = %d'%dc_bagging_class_balanced.get_n_hyperboxes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction after pruning" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy (after pruning) = 94.74%\n" ] } ], "source": [ "y_pred_2 = dc_bagging_class_balanced.predict(X_test)\n", "acc_pruned = accuracy_score(y_test, y_pred_2)\n", "print(f'Testing accuracy (after pruning) = {acc_pruned * 100:.2f}%')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 4 }