{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Store and Reload the Trained Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example shows how to store a trained hyperbox-based model on disk and reload it later to make predictions. A random hyperboxes model is used for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_classification\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "from hbbrain.numerical_data.incremental_learner.iol_gfmm import ImprovedOnlineGFMM\n",
    "from hbbrain.numerical_data.ensemble_learner.random_hyperboxes import RandomHyperboxesClassifier\n",
    "from hbbrain.utils.model_storage import store_model, load_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# File the trained model is written to and reloaded from\n",
    "MODEL_PATH = \"store_example_model.dummy\"\n",
    "# Input pattern used to compare predictions before and after storing the model\n",
    "SAMPLE_PATTERN = [1, 0.6, 0.5, 0.2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def report_prediction(model, pattern):\n",
    "    \"\"\"Predict the class of a single input pattern with `model` and print it.\"\"\"\n",
    "    y_pred = model.predict([pattern])\n",
    "    print(\"Predicted class for the input pattern %s is %d\" % (pattern, y_pred[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = make_classification(n_samples=100, n_features=4, n_informative=2, n_redundant=0, random_state=0, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise data into the range of [0, 1]; keep the raw X untouched\n",
    "scaler = MinMaxScaler()\n",
    "X_scaled = scaler.fit_transform(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a random hyperboxes model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = RandomHyperboxesClassifier(base_estimator=ImprovedOnlineGFMM(0.1), n_estimators=10, random_state=0).fit(X_scaled, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make a prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class for the input pattern [1, 0.6, 0.5, 0.2] is 1\n"
     ]
    }
   ],
   "source": [
    "report_prediction(clf, SAMPLE_PATTERN)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Store the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "store_model(clf, MODEL_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reload the trained model and make a prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_loaded = load_model(MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class for the input pattern [1, 0.6, 0.5, 0.2] is 1\n"
     ]
    }
   ],
   "source": [
    "report_prediction(clf_loaded, SAMPLE_PATTERN)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reloaded model should reproduce the original model's predictions on every training sample, not just the single pattern above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert list(clf_loaded.predict(X_scaled)) == list(clf.predict(X_scaled)), \"reloaded model predictions differ from the original\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}