{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Integration of Ensemble Models with Sklearn Pipeline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example shows how to integrate the random hyperboxes classifier into the `Pipeline` class provided by scikit-learn.\n", "\n", "Note that this example uses the random hyperboxes model with the original online learning algorithm for training the base learners. However, the same approach works for any ensemble model of GFMM classifiers that uses other learning algorithms." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore')\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import MinMaxScaler\n", "from sklearn.model_selection import train_test_split\n", "from hbbrain.numerical_data.incremental_learner.onln_gfmm import OnlineGFMM\n", "from hbbrain.numerical_data.ensemble_learner.random_hyperboxes import RandomHyperboxesClassifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load dataset\n", "This example uses the breast cancer dataset from scikit-learn for illustration purposes." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_breast_cancer" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "df = load_breast_cancer()\n", "X = df.data\n", "y = df.target" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a pipeline consisting of a pre-processing step (i.e., normalization of the data to the range [0, 1]) and a random hyperboxes model\n", "**Note:** GFMM classifiers require the input data to be in the range [0, 1]."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "theta = 0.1\n", "theta_min = 0.1\n", "base_estimator = OnlineGFMM(theta=theta, theta_min=theta_min)\n", "n_estimators = 50\n", "max_samples = 0.5\n", "max_features = 0.5\n", "class_balanced = False\n", "feature_balanced = False\n", "n_jobs = 4\n", "# Initialize the classifier\n", "rh_clf = RandomHyperboxesClassifier(base_estimator=base_estimator, n_estimators=n_estimators, max_samples=max_samples, max_features=max_features, class_balanced=class_balanced, feature_balanced=feature_balanced, n_jobs=n_jobs, random_state=0)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Create a pipeline that chains data pre-processing and the classifier\n", "pipe = Pipeline([('scaler', MinMaxScaler()), ('rh_clf', rh_clf)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Pipeline(steps=[('scaler', MinMaxScaler()),\n", " ('rh_clf',\n", " RandomHyperboxesClassifier(base_estimator=OnlineGFMM(C=array([], dtype=float64),\n", " V=array([], dtype=float64),\n", " W=array([], dtype=float64),\n", " theta=0.1,\n", " theta_min=0.1),\n", " max_features=0.5, n_estimators=50,\n", " n_jobs=4, random_state=0))])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipe.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing accuracy = 96.49%\n" ] } ], "source": [ "# Pipeline.score scales X_test with the fitted scaler and returns the classifier's accuracy\n", "acc = pipe.score(X_test, y_test)\n", "print(f'Testing accuracy = {acc * 100:.2f}%')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", 
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 4 }