{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Integration of Single Hyperbox-based Estimators with Grid-Search and Random-Search in sklearn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example shows how to integrate the GFMM classifier with the Grid Search Cross-Validation and Random Search Cross-Validation functionalities implemented by scikit-learn.\n", "\n", "Note that this example uses the original online learning algorithm of the GFMM model to demonstrate the integration of Grid Search and Random Search with a hyperbox-based model. However, the same approach can be applied to all of the other hyperbox-based machine learning algorithms." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore')\n", "from sklearn.model_selection import GridSearchCV\n", "from sklearn.model_selection import RandomizedSearchCV\n", "from sklearn.preprocessing import MinMaxScaler\n", "from sklearn.model_selection import train_test_split\n", "from hbbrain.numerical_data.incremental_learner.onln_gfmm import OnlineGFMM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load the Iris dataset, normalize it into the range of [0, 1], and build training and testing datasets" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_iris" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "df = load_iris()\n", "X = df.data\n", "y = df.target" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "scaler = MinMaxScaler()\n", "X = scaler.fit_transform(X)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "## 1. Using Grid Search with 5-fold cross-validation" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "parameters = {'theta': np.arange(0.05, 1.01, 0.05), 'theta_min':[1], 'gamma':[0.5, 1, 2, 4, 8, 16]}" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "onln_gfmm = OnlineGFMM()\n", "clf_grid_search = GridSearchCV(onln_gfmm, parameters, cv=5, scoring='accuracy', refit=True)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best average score = 0.9583333333333334\n", "Best params: {'gamma': 0.5, 'theta': 0.3, 'theta_min': 1}\n" ] } ], "source": [ "clf_grid_search.fit(X_train, y_train)\n", "print(\"Best average score = \", clf_grid_search.best_score_)\n", "print(\"Best params: \", clf_grid_search.best_params_)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "best_gfmm_grid_search = clf_grid_search.best_estimator_" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Testing the performance on the test set\n", "y_pred = best_gfmm_grid_search.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy (grid-search) = 96.67%\n" ] } ], "source": [ "acc_grid_search = accuracy_score(y_test, y_pred)\n", "print(f'Accuracy (grid-search) = {acc_grid_search * 100: .2f}%')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Try another way to create the best classifier\n", "best_gfmm_grid_search_2 = OnlineGFMM(**clf_grid_search.best_params_)\n", 
"#best_gfmm_grid_search_2.set_params(**clf_grid_search.best_params_)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OnlineGFMM(C=array([2, 1, 0, 1, 2, 2, 1, 2, 0, 0, 1, 0, 2, 2, 1]),\n", " V=array([[0.44444444, 0.29166667, 0.6440678 , 0.70833333],\n", " [0.25 , 0.125 , 0.42372881, 0.375 ],\n", " [0.11111111, 0.45833333, 0.03389831, 0.04166667],\n", " [0.16666667, 0. , 0.33898305, 0.375 ],\n", " [0.38888889, 0.08333333, 0.68221339, 0.58333333],\n", " [0.77777778, 0.41666667, 0.83050847, 0.70833333],\n", " [0.47222222, 0.375 , 0.55932203, 0.5 ],\n", " [0.166666...\n", " [0.16666667, 0.20833333, 0.59322034, 0.66666667],\n", " [0.19444444, 0.58333333, 0.10169492, 0.08333333],\n", " [0.41666667, 1. , 0.11864407, 0.125 ],\n", " [0.55555556, 0.20833333, 0.66101695, 0.58333333],\n", " [0.05555556, 0.125 , 0.05084746, 0.08333333],\n", " [0.94444444, 0.41666667, 1. , 0.91666667],\n", " [1. , 0.75 , 0.96610169, 0.875 ],\n", " [0.44444444, 0.5 , 0.6440678 , 0.70833333]]),\n", " gamma=0.5, theta=0.3, theta_min=0.3)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Training\n", "best_gfmm_grid_search_2.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# predict\n", "y_pred_2 = best_gfmm_grid_search_2.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy (grid-search) = 96.67%\n" ] } ], "source": [ "acc_grid_search_2 = accuracy_score(y_test, y_pred_2)\n", "print(f'Accuracy (grid-search) = {acc_grid_search_2 * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
Using Random Search with 5-fold cross-validation" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Using random search with only 20 random combinations of parameters\n", "onln_gfmm_rd_search = OnlineGFMM()\n", "clf_rd_search = RandomizedSearchCV(onln_gfmm_rd_search, parameters, n_iter=20, cv=5, random_state=0)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best average score = 0.9583333333333334\n", "Best params: {'theta_min': 1, 'theta': 0.3, 'gamma': 2}\n" ] } ], "source": [ "clf_rd_search.fit(X_train, y_train)\n", "print(\"Best average score = \", clf_rd_search.best_score_)\n", "print(\"Best params: \", clf_rd_search.best_params_)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "best_gfmm_rd_search = clf_rd_search.best_estimator_" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Testing the performance on the test set\n", "y_pred_rd_search = best_gfmm_rd_search.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy (random-search) = 96.67%\n" ] } ], "source": [ "acc_rd_search = accuracy_score(y_test, y_pred_rd_search)\n", "print(f'Accuracy (random-search) = {acc_rd_search * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Try to show explanation for an input sample" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "sample_need_explain = 10\n", "y_pred_input_0, mem_val_classes, min_points_classes, max_points_classes = best_gfmm_rd_search.get_sample_explanation(X_test[sample_need_explain], X_test[sample_need_explain])" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"Predicted class for sample X = [0.5 0.25 0.77966102 0.54166667] is 2 and real class is 2\n" ] } ], "source": [ "print(\"Predicted class for sample X = %s is %d and real class is %d\" % (X_test[sample_need_explain], y_pred_input_0, y_test[sample_need_explain]))" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Membership values:\n", "Class 0 has the maximum membership value = 0.000000\n", "Class 1 has the maximum membership value = 0.805085\n", "Class 2 has the maximum membership value = 0.916667\n", "Class 0 has the representative hyperbox: V = [0.11111111 0.45833333 0.03389831 0.04166667] and W = [0.38888889 0.75 0.11864407 0.20833333]\n", "Class 1 has the representative hyperbox: V = [0.25 0.125 0.42372881 0.375 ] and W = [0.5 0.41666667 0.68220339 0.625 ]\n", "Class 2 has the representative hyperbox: V = [0.38888889 0.08333333 0.68221339 0.58333333] and W = [0.66666667 0.33333333 0.81355932 0.79166667]\n" ] } ], "source": [ "print(\"Membership values:\")\n", "for key, val in mem_val_classes.items():\n", " print(\"Class %d has the maximum membership value = %f\" % (key, val))\n", " \n", "for key in min_points_classes:\n", " print(\"Class %d has the representative hyperbox: V = %s and W = %s\" % (key, min_points_classes[key], max_points_classes[key]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Show explanation results by parallel coordinates" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# Create a parallel coordinates graph\n", "best_gfmm_rd_search.show_sample_explanation(X_test[sample_need_explain], X_test[sample_need_explain], min_points_classes, max_points_classes, y_pred_input_0, file_path=\"par_cord/iris_par_cord.html\")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 26, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the parallel coordinates graph to display in the notebook\n", "from IPython.display import IFrame\n", "# We load the parallel coordinates graph from GitHub here for demonstration on Read the Docs\n", "# In a local notebook, we only need to load the graph stored at 'par_cord/iris_par_cord.html'\n", "IFrame('https://uts-caslab.github.io/hyperbox-brain/docs/tutorials/par_cord/iris_par_cord.html', width=820, height=520)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 4 }