{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multi-resolution Hierarchical Granular Representation based Classifier using GFMM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example shows how to use the multi-resolution hierarchical granular representation based classifier built on the general fuzzy min-max (GFMM) neural network." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Execute directly from the Python file" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib notebook" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Get the path to the folder containing this Jupyter notebook" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'C:\\\\hyperbox-brain\\\\docs\\\\tutorials'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "this_notebook_dir = os.path.dirname(os.path.abspath(\"__file__\"))\n", "this_notebook_dir" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Get the home folder of the Hyperbox-Brain project" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "WindowsPath('C:/hyperbox-brain')" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pathlib import Path\n", "project_dir = Path(this_notebook_dir).parent.parent\n", "project_dir" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create the path to the Python file containing the implementation of the multi-resolution hierarchical granular representation based classifier built on the general fuzzy min-max neural network" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { 
"data": { "text/plain": [ "'C:\\\\hyperbox-brain\\\\hbbrain\\\\numerical_data\\\\multigranular_learner\\\\multi_resolution_gfmm.py'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "multi_resolution_gfmm_file_path = os.path.join(project_dir, Path(\"hbbrain/numerical_data/multigranular_learner/multi_resolution_gfmm.py\"))\n", "multi_resolution_gfmm_file_path" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Run the found file by showing the execution directions" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "usage: multi_resolution_gfmm.py [-h] -training_file TRAINING_FILE\n", " -testing_file TESTING_FILE\n", " [--val_file VAL_FILE]\n", " [--n_partitions N_PARTITIONS]\n", " [--granular_theta GRANULAR_THETA]\n", " [--gamma GAMMA]\n", " [--min_membership_aggregation MIN_MEMBERSHIP_AGGREGATION]\n", "\n", "The description of parameters\n", "\n", "required arguments:\n", " -training_file TRAINING_FILE\n", " A required argument for the path to training data file\n", " (including file name)\n", " -testing_file TESTING_FILE\n", " A required argument for the path to testing data file\n", " (including file name)\n", "\n", "optional arguments:\n", " --val_file VAL_FILE The path to validation data file (including file name)\n", " --n_partitions N_PARTITIONS\n", " Number of disjoint partitions to train base learners\n", " (default: 4)\n", " --granular_theta GRANULAR_THETA\n", " Granular maximum hyperbox sizes (default: [0.1, 0.2,\n", " 0.3, 0.4, 0.5])\n", " --gamma GAMMA A sensitivity parameter describing the speed of\n", " decreasing of the membership function in each\n", " dimension (larger than 0) (default: 1)\n", " --min_membership_aggregation MIN_MEMBERSHIP_AGGREGATION\n", " Minimum membership value for hyperbox aggregration at\n", " higher granular levels (in the range of [0, 1])\n", " (default: 0)\n" ] } ], "source": [ 
"!python \"{multi_resolution_gfmm_file_path}\" -h" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create the path to training and testing datasets stored in the dataset folder" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'C:\\\\hyperbox-brain\\\\dataset\\\\syn_num_train.csv'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "training_data_file = os.path.join(project_dir, Path(\"dataset/syn_num_train.csv\"))\n", "training_data_file" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'C:\\\\hyperbox-brain\\\\dataset\\\\syn_num_test.csv'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "testing_data_file = os.path.join(project_dir, Path(\"dataset/syn_num_test.csv\"))\n", "testing_data_file" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Run a demo program" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### If the argument 'validation_file' gets the value of validation file path, the pruning procedure will be used after merging all hyperboxes from base learners. Otherwise, the pruning procedure will not be used." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training time: 3.847 (s)\n", "Testing accuracy (using voting from all granularity levels) = 86.70%\n", "Prediction of each base learner at a given partition:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.\n", "[Parallel(n_jobs=4)]: Done 2 out of 4 | elapsed: 3.6s remaining: 3.6s\n", "[Parallel(n_jobs=4)]: Done 4 out of 4 | elapsed: 3.7s finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Partition 0 - Testing accuracy = 84.00% - No boxes = 27\n", "Partition 1 - Testing accuracy = 87.80% - No boxes = 29\n", "Partition 2 - Testing accuracy = 85.40% - No boxes = 27\n", "Partition 3 - Testing accuracy = 87.10% - No boxes = 26\n", "Prediction for each granularity level:\n", "Level 1 - Testing accuracy = 85.10% - No boxes = 101\n", "Level 2 - Testing accuracy = 88.20% - No boxes = 38\n", "Level 3 - Testing accuracy = 87.10% - No boxes = 27\n", "Level 4 - Testing accuracy = 86.10% - No boxes = 20\n", "Level 5 - Testing accuracy = 86.20% - No boxes = 14\n", "Level 6 - Testing accuracy = 82.60% - No boxes = 10\n" ] } ], "source": [ "!python \"{multi_resolution_gfmm_file_path}\" -training_file \"{training_data_file}\" -testing_file \"{testing_data_file}\" --n_partitions 4 --granular_theta \"[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]\" --gamma 1 --min_membership_aggregation 0.1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
Using the multi-resolution hierarchical granular representation based classifier with a general fuzzy min-max neural network through the init, fit, and predict functions" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from hbbrain.numerical_data.multigranular_learner.multi_resolution_gfmm import MultiGranularGFMM\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create training and testing data sets" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "df_train = pd.read_csv(training_data_file, header=None)\n", "df_test = pd.read_csv(testing_data_file, header=None)\n", "\n", "Xy_train = df_train.to_numpy()\n", "Xy_test = df_test.to_numpy()\n", "\n", "Xtr = Xy_train[:, :-1]\n", "ytr = Xy_train[:, -1]\n", "\n", "Xtest = Xy_test[:, :-1]\n", "ytest = Xy_test[:, -1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Initialize parameters" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# number of disjoint partitions used to build the base learners\n", "n_partitions = 4\n", "# a list of maximum hyperbox sizes, one per granularity level\n", "granular_theta = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]\n", "# minimum membership value between two hyperboxes for them to be aggregated at higher abstraction levels\n", "min_membership_aggregation = 0.1\n", "# sensitivity parameter controlling how fast membership values decrease\n", "gamma = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.\n", "[Parallel(n_jobs=4)]: Done 2 out of 4 | elapsed: 2.9s remaining: 2.9s\n", "[Parallel(n_jobs=4)]: Done 4 out of 4 | elapsed: 2.9s finished\n" ] }, { "data": { "text/plain": [ "MultiGranularGFMM(granular_theta=[0.1, 0.2, 0.3, 
0.4, 0.5, 0.6],\n", " min_membership_aggregation=0.1)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from hbbrain.constants import HETEROGENEOUS_CLASS_LEARNING\n", "multi_granular_gfmm_clf = MultiGranularGFMM(n_partitions=n_partitions, granular_theta=granular_theta, gamma=gamma, min_membership_aggregation=min_membership_aggregation)\n", "# Training using the heterogeneous model for class labels.\n", "multi_granular_gfmm_clf.fit(Xtr, ytr, learning_type=HETEROGENEOUS_CLASS_LEARNING)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### The code below shows how to display the hyperboxes and the decision boundaries among classes at a given granularity level when the input data are two-dimensional" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# show hyperboxes and decision boundaries at the last granularity level (level 6, i.e. index 5)\n", "multi_granular_gfmm_clf.draw_2D_hyperbox_and_boundary_granular_level(window_name=\"Hyperbox-based classifier and its decision boundaries at level 6\", level=5)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes at all granularity levels = 210\n" ] } ], "source": [ "# Get the total number of hyperboxes at all granularity levels\n", "print(\"Total number of hyperboxes at all granularity levels = %d\"%multi_granular_gfmm_clf.get_n_hyperboxes(level=-1))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of hyperboxes at the first granularity level = 101\n", "Total number of hyperboxes at the last granularity level = 10\n" ] } ], "source": [ "# Get the number of hyperboxes at a given granularity level\n", "print(\"Total number of hyperboxes at the first granularity level = %d\"%multi_granular_gfmm_clf.get_n_hyperboxes(level=0))\n", "print(\"Total number of hyperboxes at the last granularity level = %d\"%multi_granular_gfmm_clf.get_n_hyperboxes(level=5))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Use the GFMM models from all granularity levels to make the final prediction by majority voting: each granularity level contributes one predicted label for each input pattern, and the final prediction is the class receiving the most votes across all granularity levels." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy (majority voting) = 86.70%\n" ] } ], "source": [ "y_pred = multi_granular_gfmm_clf.predict(Xtest, level=-1)\n", "acc = accuracy_score(ytest, y_pred)\n", "print(f'Accuracy (majority voting) = {acc * 100: .2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Use a specific granularity level to make predictions" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction for each granularity level:\n", "Level 1 - Testing accuracy = 85.10% - No hyperboxes = 101\n", "Level 2 - Testing accuracy = 88.20% - No hyperboxes = 38\n", "Level 3 - Testing accuracy = 87.10% - No hyperboxes = 27\n", "Level 4 - Testing accuracy = 86.10% - No hyperboxes = 20\n", "Level 5 - Testing accuracy = 86.20% - No hyperboxes = 14\n", "Level 6 - Testing accuracy = 82.60% - No hyperboxes = 10\n" ] } ], "source": [ "print(\"Prediction for each granularity level:\")\n", "for i in range(len(granular_theta)):\n", " y_pred_lv = multi_granular_gfmm_clf.predict(Xtest, level=i)\n", " acc_lv = accuracy_score(ytest, y_pred_lv)\n", " n_boxes = multi_granular_gfmm_clf.get_n_hyperboxes(i)\n", " 
print(f'Level {i + 1} - Testing accuracy = {acc_lv * 100: .2f}% - No hyperboxes = {n_boxes}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explaining the predicted result for an input sample by showing the membership values and hyperboxes of each class for the model at a given granularity level" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "sample_need_explain = 0\n", "# Use the trained model at the sixth granularity level to make the prediction and explanation. Note that the level parameter is zero-based, so level 5 is the sixth level.\n", "level_explain = 5\n", "y_pred_input_0, mem_val_classes, min_points_classes, max_points_classes = multi_granular_gfmm_clf.get_sample_explanation_granular_level(Xtest[sample_need_explain], Xtest[sample_need_explain], level_explain)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted class for sample X = [0.752930, 0.385920] is 1 and real class is 2\n" ] } ], "source": [ "print(\"Predicted class for sample X = [%f, %f] is %d and real class is %d\" % (Xtest[sample_need_explain, 0], Xtest[sample_need_explain, 1], y_pred_input_0, ytest[sample_need_explain]))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Membership values:\n", "Class 1 has the maximum membership value = 0.989125\n", "Class 2 has the maximum membership value = 0.915132\n", "Class 1 has the representative hyperbox: V = [0.763805 0.369765] and W = [0.91185 0.48598]\n", "Class 2 has the representative hyperbox: V = [0.563065 0.17003 ] and W = [0.6680625 0.65662 ]\n" ] } ], "source": [ "print(\"Membership values:\")\n", "for key, val in mem_val_classes.items():\n", " print(\"Class %d has the maximum membership value = %f\" % (key, val))\n", " \n", "for key in min_points_classes:\n", " print(\"Class %d has the representative hyperbox: V = 
%s and W = %s\" % (key, min_points_classes[key], max_points_classes[key]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Show the input sample and the hyperboxes belonging to each class. In 2D, we can draw them as rectangles or use parallel coordinates" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Using rectangles to show explanations" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('