This page explains how to load the MNIST database of handwritten digits with Matlab. It is based on the converter developed by Markus Mayer, available on GitHub. You don't have to download and run the converter yourself: I already did it for you. Markus Mayer's converter outputs a .mat file containing the data, which you can download below:
Once the mnist.mat file is downloaded, run the following command to load the dataset:

```matlab
load('mnist.mat')
```
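Called with an output argument, load returns a struct whose fields are the variables stored in the file, instead of creating them in the workspace. This avoids silently overwriting existing variables named training or test. The sketch below assumes nothing about the real file: so that it runs on its own, it first writes a small hypothetical stand-in file to a temporary path, and the names trainSet and testSet are illustrative only.

```matlab
% Hypothetical stand-ins so this sketch runs without mnist.mat;
% with the real file, skip these lines and use S = load('mnist.mat') below.
training = struct('count', 2, 'labels', [5; 0]);
test = struct('count', 1, 'labels', 7);
f = [tempname '.mat'];
save(f, 'training', 'test', '-v7');

% Load the file into a struct instead of the workspace
S = load(f);                  % with the real data: S = load('mnist.mat');
disp(fieldnames(S));          % lists the variables stored in the file
trainSet = S.training;        % hypothetical names for the two datasets
testSet = S.test;
delete(f);                    % remove the temporary demo file
```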
Once the dataset is loaded, type the who command to list the variables:

```matlab
>> who
Your variables are:
test training
```
The variable training contains the training set, and the variable test contains the test set.
Each variable is a structure with the following fields:
Name | Description |
---|---|
count | Number of images in the dataset |
width | Width of each image, in pixels |
height | Height of each image, in pixels |
images | Pixel data of the images (see below) |
labels | Label of each image, i.e. the digit written in it |
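Since the fields describe each other, a quick consistency check can catch a truncated or wrong file. The sketch below uses a small hypothetical struct ds with the same fields as training and test; after load('mnist.mat'), replace ds with one of the real variables. It only relies on the layout described on this page (images stored as width x height x count, values between 0 and 1).

```matlab
% Hypothetical stand-in with the same fields as training/test;
% replace ds with training or test after load('mnist.mat').
ds = struct('count', 3, 'width', 28, 'height', 28, ...
            'images', rand(28, 28, 3), 'labels', [5; 0; 4]);

% The image array must be width x height x count
assert(isequal(size(ds.images), [ds.width, ds.height, ds.count]));
% One label per image
assert(numel(ds.labels) == ds.count);
% Pixel values are encoded between 0 and 1
assert(all(ds.images(:) >= 0 & ds.images(:) <= 1));

fprintf('%d images of %dx%d pixels\n', ds.count, ds.width, ds.height);
```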
The images variable is a width x height x count array containing the grayscale value of each pixel, encoded between 0 and 1.
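Many learning tools expect one sample per column (or per row) rather than a 3-D array. A minimal sketch, assuming the width x height x count layout described above: reshape turns the images into a (width*height) x count matrix without copying pixels around, because MATLAB stores arrays in column-major order. The struct ds is a hypothetical stand-in for training or test.

```matlab
% Hypothetical stand-in for training; use the real struct after load('mnist.mat')
ds = struct('count', 3, 'width', 28, 'height', 28, 'images', rand(28, 28, 3));

% One column per image: column k holds the pixels of ds.images(:,:,k)
X = reshape(ds.images, ds.width * ds.height, ds.count);

% Check one image: reshape keeps the column-major order of img(:)
k = 2;
img = ds.images(:, :, k);
assert(isequal(X(:, k), img(:)));
```

Use X' instead of X if your tool expects one image per row.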
To check that the dataset is properly loaded, you can display a digit with the following command:

```matlab
image(training.images(:,:,18)*255);
```
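Since the pixel values are stored between 0 and 1, they have to be scaled to the range 0 to 255. The Matlab function rescale (available since R2017b) is dedicated to this purpose. Note that rescale maps the image's own minimum and maximum to 0 and 255, so it gives the same result as multiplying by 255 only when the darkest pixel is 0 and the brightest is 1:

```matlab
image(rescale(training.images(:,:,18),0,255));
```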
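To see the difference between the two approaches, the sketch below applies both to a hypothetical 2x2 image whose brightest pixel is 0.5. The range-based version writes out the linear map that rescale(img, 0, 255) applies with its default input range (the minimum and maximum of the image); it is a hand-written illustration, not the internal implementation of rescale.

```matlab
% Hypothetical 2x2 image whose brightest pixel is 0.5 instead of 1
img = [0 0.25; 0.5 0.1];

% Fixed scaling: the brightest pixel becomes 0.5 * 255 = 127.5
byFactor = img * 255;

% Range-based scaling, the linear map rescale(img, 0, 255) applies with its
% default input range [min(img(:)) max(img(:))]: the brightest pixel becomes 255
lo = min(img(:));
hi = max(img(:));
byRange = (img - lo) ./ (hi - lo) * 255;

disp(byFactor)
disp(byRange)
```

For a typical MNIST digit, with a black background and at least one fully white pixel, both give the same image.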
Here is the result:
Note that to map the pixel values linearly onto the whole colormap, I set the CDataMapping property of the image to 'scaled':

```matlab
im = image(rescale(training.images(:,:,18),0,255));
im.CDataMapping = 'scaled';
```
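With the default 'direct' mapping, pixel values are used as row indices into the colormap, so every value larger than the number of colormap rows is drawn with the last color. With 'scaled', the color limits [cmin cmax] are instead spread linearly over all rows. The sketch below reproduces that index computation on a few hypothetical values; it is written after the formula given in the MATLAB documentation for caxis/clim, assuming a 256-row colormap.

```matlab
% Hypothetical pixel values after scaling to 0..255
C = [0 64 128 255];

m = 256;            % number of colormap rows, as with colormap(gray(256))
cmin = 0;           % CLim(1): by default the smallest value in the image
cmax = 255;         % CLim(2): by default the largest value in the image

% 'scaled' mapping: [cmin cmax] is spread linearly over rows 1..m
idx = fix((C - cmin) / (cmax - cmin) * m) + 1;
idx = min(max(idx, 1), m);   % clamp; cmax itself lands on row m

disp(idx)           % row of the colormap used for each value
```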
To get the same scale on each axis, use the following command:

```matlab
axis square equal
```
To get a grayscale image, use the following:

```matlab
colormap(gray)
```

You can also display the color bar with this command:

```matlab
colorbar
```
Since the labels are stored in the dataset, getting the label (the digit) associated with an image is straightforward:

```matlab
>> training.labels(18)
ans =
8
```
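The labels also make it easy to select images by digit or to check how the digits are distributed. A minimal sketch on a hypothetical label vector; with the real data, use training.labels or test.labels instead. The double conversion is there in case the labels are stored with an integer type, which this page does not specify.

```matlab
% Hypothetical labels; with the real data use training.labels
labels = [5; 0; 4; 1; 9; 2; 1; 3; 1; 4];

% Positions of every image showing the digit 1
idx = find(labels == 1);

% Number of images per digit: counts(k+1) is the count for digit k
counts = accumarray(double(labels(:)) + 1, 1, [10 1]);

disp(idx')
disp(counts')
```

With the real data, training.images(:,:,idx) then gives every image of the selected digit.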
For example, you can display the label in the figure title:

```matlab
title(sprintf('Digit: %d', training.labels(18)))
```
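Putting the steps of this page together, the script below draws one digit with a scaled gray colormap, a color bar and its label as title. So that it runs without mnist.mat, it uses a hypothetical 28x28 stand-in image (a vertical bar) labeled 1; with the real data, use training.images(:,:,18) and training.labels(18). It uses set/get on the image handle and axis image (equal scale, limits fitted to the image) so that it also works outside recent MATLAB releases; these are choices of this sketch, not of the original page.

```matlab
% Hypothetical 28x28 stand-in for training.images(:,:,18): a vertical bar
img = zeros(28, 28);
img(5:24, 13:16) = 1;
label = 1;                               % stand-in for training.labels(18)

figure;
im = image(img * 255);                   % pixel values scaled to 0..255
set(im, 'CDataMapping', 'scaled');       % spread the data range over the colormap
axis image                               % same scale on both axes, tight limits
colormap(gray(256));                     % 256 gray levels
colorbar
title(sprintf('Digit: %d', label));
```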