{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Integration of Single Hyperbox-based Estimators with Grid-Search and Random-Search in sklearn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This example shows how to integrate the GFMM classifier with the Grid Search Cross-Validation and Random Search Cross-Validation functionalities implemented by scikit-learn.\n",
"\n",
"Note that this example uses the original online learning algorithm of the GFMM model to demonstrate the integration of Grid Search and Random Search with a hyperbox-based model. However, the same approach can be applied to all of the other hyperbox-based machine learning algorithms."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.model_selection import train_test_split\n",
"from hbbrain.numerical_data.incremental_learner.onln_gfmm import OnlineGFMM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Load the Iris dataset, normalize it into the range [0, 1], and build the training and testing datasets"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"df = load_iris()\n",
"X = df.data\n",
"y = df.target"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"scaler = MinMaxScaler()\n",
"X = scaler.fit_transform(X)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Using Grid Search with 5-fold cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"parameters = {'theta': np.arange(0.05, 1.01, 0.05), 'theta_min':[1], 'gamma':[0.5, 1, 2, 4, 8, 16]}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"onln_gfmm = OnlineGFMM()\n",
"clf_grid_search = GridSearchCV(onln_gfmm, parameters, cv=5, scoring='accuracy', refit=True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best average score = 0.9583333333333334\n",
"Best params: {'gamma': 0.5, 'theta': 0.3, 'theta_min': 1}\n"
]
}
],
"source": [
"clf_grid_search.fit(X_train, y_train)\n",
"print(\"Best average score = \", clf_grid_search.best_score_)\n",
"print(\"Best params: \", clf_grid_search.best_params_)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"best_gfmm_grid_search = clf_grid_search.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Testing the performance on the test set\n",
"y_pred = best_gfmm_grid_search.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy (grid-search) = 96.67%\n"
]
}
],
"source": [
"acc_grid_search = accuracy_score(y_test, y_pred)\n",
"print(f'Accuracy (grid-search) = {acc_grid_search * 100: .2f}%')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Try another way to create the best classifier\n",
"best_gfmm_grid_search_2 = OnlineGFMM(**clf_grid_search.best_params_)\n",
"#best_gfmm_grid_search_2.set_params(**clf_grid_search.best_params_)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OnlineGFMM(C=array([2, 1, 0, 1, 2, 2, 1, 2, 0, 0, 1, 0, 2, 2, 1]),\n",
" V=array([[0.44444444, 0.29166667, 0.6440678 , 0.70833333],\n",
" [0.25 , 0.125 , 0.42372881, 0.375 ],\n",
" [0.11111111, 0.45833333, 0.03389831, 0.04166667],\n",
" [0.16666667, 0. , 0.33898305, 0.375 ],\n",
" [0.38888889, 0.08333333, 0.68221339, 0.58333333],\n",
" [0.77777778, 0.41666667, 0.83050847, 0.70833333],\n",
" [0.47222222, 0.375 , 0.55932203, 0.5 ],\n",
" [0.166666...\n",
" [0.16666667, 0.20833333, 0.59322034, 0.66666667],\n",
" [0.19444444, 0.58333333, 0.10169492, 0.08333333],\n",
" [0.41666667, 1. , 0.11864407, 0.125 ],\n",
" [0.55555556, 0.20833333, 0.66101695, 0.58333333],\n",
" [0.05555556, 0.125 , 0.05084746, 0.08333333],\n",
" [0.94444444, 0.41666667, 1. , 0.91666667],\n",
" [1. , 0.75 , 0.96610169, 0.875 ],\n",
" [0.44444444, 0.5 , 0.6440678 , 0.70833333]]),\n",
" gamma=0.5, theta=0.3, theta_min=0.3)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Training\n",
"best_gfmm_grid_search_2.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# predict\n",
"y_pred_2 = best_gfmm_grid_search_2.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy (grid-search) = 96.67%\n"
]
}
],
"source": [
"acc_grid_search_2 = accuracy_score(y_test, y_pred_2)\n",
"print(f'Accuracy (grid-search) = {acc_grid_search_2 * 100: .2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Using Random Search with 5-fold cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Using random search with only 20 random combinations of parameters\n",
"onln_gfmm_rd_search = OnlineGFMM()\n",
"clf_rd_search = RandomizedSearchCV(onln_gfmm_rd_search, parameters, n_iter=20, cv=5, random_state=0)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best average score = 0.9583333333333334\n",
"Best params: {'theta_min': 1, 'theta': 0.3, 'gamma': 2}\n"
]
}
],
"source": [
"clf_rd_search.fit(X_train, y_train)\n",
"print(\"Best average score = \", clf_rd_search.best_score_)\n",
"print(\"Best params: \", clf_rd_search.best_params_)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"best_gfmm_rd_search = clf_rd_search.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Testing the performance on the test set\n",
"y_pred_rd_search = best_gfmm_rd_search.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy (random-search) = 96.67%\n"
]
}
],
"source": [
"acc_rd_search = accuracy_score(y_test, y_pred_rd_search)\n",
"print(f'Accuracy (random-search) = {acc_rd_search * 100: .2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Show the explanation for an input sample"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"sample_need_explain = 10\n",
"y_pred_input_0, mem_val_classes, min_points_classes, max_points_classes = best_gfmm_rd_search.get_sample_explanation(X_test[sample_need_explain], X_test[sample_need_explain])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted class for sample X = [0.5 0.25 0.77966102 0.54166667] is 2 and real class is 2\n"
]
}
],
"source": [
"print(\"Predicted class for sample X = %s is %d and real class is %d\" % (X_test[sample_need_explain], y_pred_input_0, y_test[sample_need_explain]))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Membership values:\n",
"Class 0 has the maximum membership value = 0.000000\n",
"Class 1 has the maximum membership value = 0.805085\n",
"Class 2 has the maximum membership value = 0.916667\n",
"Class 0 has the representative hyperbox: V = [0.11111111 0.45833333 0.03389831 0.04166667] and W = [0.38888889 0.75 0.11864407 0.20833333]\n",
"Class 1 has the representative hyperbox: V = [0.25 0.125 0.42372881 0.375 ] and W = [0.5 0.41666667 0.68220339 0.625 ]\n",
"Class 2 has the representative hyperbox: V = [0.38888889 0.08333333 0.68221339 0.58333333] and W = [0.66666667 0.33333333 0.81355932 0.79166667]\n"
]
}
],
"source": [
"print(\"Membership values:\")\n",
"for key, val in mem_val_classes.items():\n",
" print(\"Class %d has the maximum membership value = %f\" % (key, val))\n",
" \n",
"for key in min_points_classes:\n",
" print(\"Class %d has the representative hyperbox: V = %s and W = %s\" % (key, min_points_classes[key], max_points_classes[key]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Show explanation results by parallel coordinates"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# Create a parallel coordinates graph\n",
"best_gfmm_rd_search.show_sample_explanation(X_test[sample_need_explain], X_test[sample_need_explain], min_points_classes, max_points_classes, y_pred_input_0, file_path=\"par_cord/iris_par_cord.html\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the parallel coordinates graph to display in the notebook\n",
"from IPython.display import IFrame\n",
"# We load the parallel coordinates from GitHub here for demonstration in readthedocs\n",
"# On the local notebook, we only need to load the graph stored at 'par_cord/iris_par_cord.html'\n",
"IFrame('https://uts-caslab.github.io/hyperbox-brain/docs/tutorials/par_cord/iris_par_cord.html', width=820, height=520)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}