1. Overview

In this tutorial, we'll build and train a convolutional neural network model using the Deeplearning4j library in Java.

For further information on how to set up the library, please refer to our guide on Deeplearning4j.

2. Image Classification

2.1. Problem Statement

Suppose we have a set of images. Each image represents an object of a particular class, and the object in each image belongs to exactly one known class. So, the problem is to build a model that can recognize the class of the object in a given image.

For example, let's say we have a set of images with ten hand gestures. We build a model and train it to classify them. Then, after training, we can pass other images to the model and classify the hand gestures in them. Of course, each gesture should belong to one of the known classes.

2.2. Image Representation

In computer memory, an image can be represented as a matrix of numbers. Each number is a pixel value, ranging from 0 to 255 for 8-bit images.

A grayscale image is a 2D matrix. Similarly, an RGB image is a 3D matrix with width, height, and depth dimensions, where the depth holds the three color channels (red, green, and blue).
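To make the matrix view concrete, here is a minimal sketch that copies a standard `java.awt.image.BufferedImage` into plain Java arrays. It uses only the JDK, not Deeplearning4j. The class name `ImageMatrix` and its `[row][column][channel]` layout are illustrative choices, not the layout Deeplearning4j uses internally.

```java
import java.awt.image.BufferedImage;
import java.util.Arrays;

// Illustrative sketch: an image is just numbers in [0, 255].
class ImageMatrix {

    // RGB image -> 3D array indexed as [row][column][channel]
    // (channel 0 = red, 1 = green, 2 = blue)
    public static int[][][] toRgbMatrix(BufferedImage image) {
        int height = image.getHeight();
        int width = image.getWidth();
        int[][][] matrix = new int[height][width][3];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int argb = image.getRGB(x, y); // packed as 0xAARRGGBB
                matrix[y][x][0] = (argb >> 16) & 0xFF;
                matrix[y][x][1] = (argb >> 8) & 0xFF;
                matrix[y][x][2] = argb & 0xFF;
            }
        }
        return matrix;
    }

    // Grayscale image -> 2D array indexed as [row][column]
    public static int[][] toGrayMatrix(BufferedImage image) {
        int height = image.getHeight();
        int width = image.getWidth();
        int[][] matrix = new int[height][width];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                // a TYPE_BYTE_GRAY raster stores one sample (band 0) per pixel
                matrix[y][x] = image.getRaster().getSample(x, y, 0);
            }
        }
        return matrix;
    }

    public static void main(String[] args) {
        BufferedImage image = new BufferedImage(4, 3, BufferedImage.TYPE_INT_RGB);
        image.setRGB(1, 2, 0xFF8000); // an orange pixel at column 1, row 2
        int[][][] matrix = toRgbMatrix(image);
        // prints "3 x 4 x 3" (height x width x depth)
        System.out.println(matrix.length + " x " + matrix[0].length + " x " + matrix[0][0].length);
        // prints "[255, 128, 0]"
        System.out.println(Arrays.toString(matrix[2][1]));
    }
}
```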

As we can see, an image is just a set of numbers. Therefore, we can feed images to multi-layer network models and train these models to classify them.

3. Convolutional Neural Networks

A Convolutional Neural Network (CNN) is a multi-layer network model that has a specific structure. The structure of a CNN may be divided into two blocks: convolutional layers and fully connected (or dense) layers. Let's look at each of them.

3.1. Convolutional Layer

Each convolutional layer is a set of square matrices called kernels. We use them to perform convolution on the input image. Their number and size may vary depending on the given dataset. We mostly use 3×3 or 5×5 kernels, and rarely 7×7 ones. The exact size and number are selected by trial and error.

In addition, we initialize the values of the kernel matrices randomly at the beginning of training. These values are the weights of the network.

To perform convolution, we use the kernel as a sliding window. We multiply the kernel weights by the corresponding image pixels and sum the products. Then we move the kernel to cover the next chunk of the image. The stride defines how many pixels the kernel shifts at each step, both horizontally and vertically. Padding adds extra pixels (usually zeros) around the image border so that the kernel can also cover the edges. As a result, we get values that will be used in further computations.
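The sliding-window computation can be sketched in plain Java for a single-channel image and a single square kernel. This is an illustrative sketch, not Deeplearning4j code. Like most deep learning libraries, it does not flip the kernel, so strictly speaking it computes a cross-correlation. Out-of-image positions are treated as zero padding. Each output dimension has size `(input - kernel + 2 * padding) / stride + 1`.

```java
// Illustrative sketch: single-channel 2D convolution with stride and zero padding.
class Convolution2D {

    public static double[][] convolve(double[][] image, double[][] kernel, int stride, int padding) {
        int h = image.length;
        int w = image[0].length;
        int k = kernel.length; // square k x k kernel
        // output size per dimension: (input - kernel + 2 * padding) / stride + 1
        int outH = (h - k + 2 * padding) / stride + 1;
        int outW = (w - k + 2 * padding) / stride + 1;
        double[][] out = new double[outH][outW];
        for (int i = 0; i < outH; i++) {
            for (int j = 0; j < outW; j++) {
                double sum = 0;
                for (int di = 0; di < k; di++) {
                    for (int dj = 0; dj < k; dj++) {
                        // position in the original image; outside of it we read a padded zero
                        int y = i * stride + di - padding;
                        int x = j * stride + dj - padding;
                        if (y >= 0 && y < h && x >= 0 && x < w) {
                            sum += image[y][x] * kernel[di][dj];
                        }
                    }
                }
                out[i][j] = sum;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] image = new double[5][5];
        for (int y = 0; y < 5; y++) {
            for (int x = 0; x < 5; x++) {
                image[y][x] = y * 5 + x; // pixel values 0..24
            }
        }
        double[][] kernel = {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; // sums a 3x3 neighborhood
        double[][] out = convolve(image, kernel, 1, 0);
        System.out.println(out.length + "x" + out[0].length); // prints "3x3"
        System.out.println(out[0][0]); // 0+1+2+5+6+7+10+11+12, prints "54.0"
    }
}
```

With a 5×5 input, a 3×3 kernel, stride 2, and padding 1, the formula gives (5 - 3 + 2) / 2 + 1 = 3 outputs per dimension.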

In short, this layer produces a convolved image. Some of its values might be less than zero. Applying the ReLU function, max(0, x), sets these values to zero. This adds non-linearity to the network and produces sparse outputs, which keeps further computations simpler.
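ReLU is simple enough to show in a few lines. The following illustrative sketch applies it element-wise to a 2D feature map, such as the output of a convolution.

```java
// Illustrative sketch: element-wise ReLU, f(x) = max(0, x), over a feature map.
class Relu {

    public static double[][] relu(double[][] featureMap) {
        double[][] out = new double[featureMap.length][];
        for (int i = 0; i < featureMap.length; i++) {
            out[i] = new double[featureMap[i].length];
            for (int j = 0; j < featureMap[i].length; j++) {
                // negative values become 0, positive values pass through unchanged
                out[i][j] = Math.max(0.0, featureMap[i][j]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] out = relu(new double[][] {{-3, 2}, {0.5, -1}});
        System.out.println(java.util.Arrays.deepToString(out)); // prints "[[0.0, 2.0], [0.5, 0.0]]"
    }
}
```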

3.2. Subsampling Layer

The subsampling (or pooling) layer is a layer of the network that usually follows the convolutional one. After the convolution, we get a lot of computed values. However, our task is to keep only the most valuable among them.

The approach is to apply a sliding window algorithm to the convolved image. With max pooling, at each step we choose the maximum value in a square window of a predefined size, usually between 2×2 and 5×5 pixels. As a result, we get fewer values, which reduces further computations.
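Here is a minimal, illustrative sketch of max pooling in plain Java. It takes a square window size and a stride and uses no padding. A 2×2 window with stride 2 halves each dimension of the input.

```java
// Illustrative sketch: max pooling over a 2D feature map (no padding).
class MaxPooling {

    public static double[][] maxPool(double[][] input, int size, int stride) {
        int outH = (input.length - size) / stride + 1;
        int outW = (input[0].length - size) / stride + 1;
        double[][] out = new double[outH][outW];
        for (int i = 0; i < outH; i++) {
            for (int j = 0; j < outW; j++) {
                double max = Double.NEGATIVE_INFINITY;
                // keep only the largest value inside the current window
                for (int di = 0; di < size; di++) {
                    for (int dj = 0; dj < size; dj++) {
                        max = Math.max(max, input[i * stride + di][j * stride + dj]);
                    }
                }
                out[i][j] = max;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] input = {
            {1, 3, 2, 4},
            {5, 6, 1, 2},
            {7, 2, 9, 1},
            {3, 4, 0, 8}
        };
        // prints "[[6.0, 4.0], [7.0, 9.0]]"
        System.out.println(java.util.Arrays.deepToString(maxPool(input, 2, 2)));
    }
}
```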

3.3. Dense Layer

A dense (or fully connected) layer consists of multiple neurons, each connected to every output of the previous layer. We need this layer to perform classification. Moreover, there might be two or more such consecutive layers. Importantly, the last layer should have as many neurons as there are classes.
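In a dense layer, each neuron computes a weighted sum of all its inputs plus a bias. The following is an illustrative plain-Java sketch of that forward pass. It assumes the outputs of the previous layer have already been flattened into a 1D array, and it leaves out the activation function.

```java
// Illustrative sketch: forward pass of a dense layer, out[j] = bias[j] + sum_i weights[j][i] * input[i].
class DenseLayer {

    // weights[j][i] connects input i to neuron j
    public static double[] forward(double[] input, double[][] weights, double[] bias) {
        double[] out = new double[weights.length];
        for (int j = 0; j < weights.length; j++) {
            double sum = bias[j];
            for (int i = 0; i < input.length; i++) {
                sum += weights[j][i] * input[i];
            }
            out[j] = sum;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] input = {1, 2, 3};
        double[][] weights = {{1, 0, 0}, {0, 1, 1}};
        double[] bias = {0.5, -1};
        // prints "[1.5, 4.0]"
        System.out.println(java.util.Arrays.toString(forward(input, weights, bias)));
    }
}
```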

The output of the network is the probability of the image belonging to each of the classes. To turn the outputs of the last layer into probabilities, we'll use the Softmax activation function.
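Softmax maps the raw outputs z of the last layer to probabilities, exp(z_i) / sum_j exp(z_j). The illustrative sketch below subtracts the maximum value before exponentiating. This common trick avoids overflow for large inputs and does not change the result.

```java
// Illustrative sketch: numerically stable softmax.
class Softmax {

    public static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max); // shifting by max keeps exp() from overflowing
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum; // probabilities now add up to 1
        }
        return out;
    }

    public static void main(String[] args) {
        // the largest input gets the largest probability
        System.out.println(java.util.Arrays.toString(softmax(new double[] {1.0, 2.0, 3.0})));
    }
}
```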

3.4. Optimization Techniques

To train the network, we need to optimize the weights. Remember, we choose these values randomly at first. A neural network is essentially a big function with lots of unknown parameters: our weights.

When we pass an image to the network, it gives us an answer. Then we can build a loss function that depends on this answer. In supervised learning, we also know the actual answer: the true class. Our mission is to minimize this loss function. If we succeed, our model is well-trained.
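The text doesn't name a specific loss function. A common choice for classification with a softmax output is cross-entropy (negative log-likelihood), which we assume here for illustration. For a single image, it is -log(p), where p is the probability the network assigned to the true class. It is 0 when the network is fully confident and correct, and it grows as p shrinks.

```java
// Illustrative sketch, assuming cross-entropy as the loss: L = -log(probability of the true class).
class CrossEntropy {

    // probabilities: softmax output; trueClass: index of the correct label
    public static double loss(double[] probabilities, int trueClass) {
        // clamp to a tiny positive value so that log(0) never happens
        return -Math.log(Math.max(probabilities[trueClass], 1e-12));
    }

    public static void main(String[] args) {
        double[] prediction = {0.9, 0.1};
        System.out.println(loss(prediction, 0)); // small loss: confident and correct
        System.out.println(loss(prediction, 1)); // larger loss: confident but wrong
    }
}
```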

To minimize the function, we have to update the weights of the network. In order to do that, we can compute the derivative of the loss function with respect to each of these unknown parameters. Then, we can update each weight.

Because we know the slope, we know whether to increase or decrease each weight to move toward a local minimum of the loss function. This iterative process is called Gradient Descent. Backpropagation computes the required derivatives efficiently by applying the chain rule from the end of the network back to its beginning, and gradient descent then uses them to update the weights.
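The update rule behind gradient descent is w ← w - η · dL/dw, where η is the learning rate. The illustrative sketch below applies it to a toy one-parameter loss, L(w) = (w - 3)², whose derivative is 2(w - 3). A real network applies the same rule to every weight, using the derivatives that backpropagation provides.

```java
import java.util.function.DoubleUnaryOperator;

// Illustrative sketch: plain gradient descent on a single parameter.
class GradientDescent {

    public static double minimize(DoubleUnaryOperator derivative, double start, double learningRate, int steps) {
        double w = start;
        for (int s = 0; s < steps; s++) {
            // move against the slope: w <- w - learningRate * dL/dw
            w -= learningRate * derivative.applyAsDouble(w);
        }
        return w;
    }

    public static void main(String[] args) {
        // toy loss L(w) = (w - 3)^2, so dL/dw = 2 * (w - 3); the minimum is at w = 3
        double w = minimize(x -> 2 * (x - 3), 0.0, 0.1, 100);
        System.out.println(w); // very close to 3
    }
}
```

With η = 0.1, each step multiplies the distance to the minimum by 0.8, so the result converges steadily. A much larger η could overshoot and diverge.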

In this tutorial, we'll use the Stochastic Gradient Descent (SGD) optimization algorithm. The main idea is that at each step we randomly choose a batch of training images, compute the gradients on that batch with backpropagation, and update the weights.
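To illustrate the batching idea without a full network, the sketch below runs mini-batch SGD on a toy linear model y = a · x with a mean squared error loss. The model, data, and hyperparameters are our own assumptions for demonstration. The samples are shuffled at every epoch, and each update uses the gradient averaged over one batch.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Illustrative sketch: mini-batch SGD fitting y = a * x with mean squared error.
class MiniBatchSgd {

    public static double fit(double[] xs, double[] ys, int batchSize, double learningRate, int epochs, long seed) {
        Random random = new Random(seed);
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < xs.length; i++) {
            indices.add(i);
        }
        double a = random.nextGaussian(); // random initial weight
        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(indices, random); // new random batches every epoch
            for (int start = 0; start < indices.size(); start += batchSize) {
                int end = Math.min(start + batchSize, indices.size());
                double gradient = 0;
                for (int k = start; k < end; k++) {
                    int i = indices.get(k);
                    // derivative of (a * x - y)^2 with respect to a
                    gradient += 2 * (a * xs[i] - ys[i]) * xs[i];
                }
                a -= learningRate * gradient / (end - start); // step with the batch-average gradient
            }
        }
        return a;
    }

    public static void main(String[] args) {
        double[] xs = new double[100];
        double[] ys = new double[100];
        for (int i = 0; i < 100; i++) {
            xs[i] = i / 100.0;
            ys[i] = 2 * xs[i]; // the true slope is 2
        }
        System.out.println(fit(xs, ys, 10, 0.5, 50, 42L)); // close to 2
    }
}
```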

3.5. Evaluation Metrics

Finally, after training the network, we need to get information about how well our model performs.

The most commonly used metric is accuracy: the ratio of correctly classified images to all images. Meanwhile, recall, precision, and F1 score are very important metrics for image classification as well.
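All four metrics can be derived from a confusion matrix, where `confusion[actual][predicted]` counts images. The illustrative sketch below computes accuracy and the per-class precision, recall, and F1 score. It averages the per-class values over all classes (macro-averaging), which is one of several possible averaging schemes. Treating 0/0 as 0 is our own convention here.

```java
import java.util.function.IntToDoubleFunction;

// Illustrative sketch: classification metrics from a confusion matrix [actual][predicted].
class ClassificationMetrics {

    public static double accuracy(int[][] confusion) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < confusion.length; i++) {
            for (int j = 0; j < confusion.length; j++) {
                total += confusion[i][j];
                if (i == j) {
                    correct += confusion[i][j];
                }
            }
        }
        return (double) correct / total;
    }

    // of all images predicted as cls, the share that really belongs to cls
    public static double precision(int[][] confusion, int cls) {
        int predicted = 0;
        for (int[] row : confusion) {
            predicted += row[cls];
        }
        return predicted == 0 ? 0.0 : (double) confusion[cls][cls] / predicted;
    }

    // of all images that really belong to cls, the share predicted as cls
    public static double recall(int[][] confusion, int cls) {
        int actual = 0;
        for (int count : confusion[cls]) {
            actual += count;
        }
        return actual == 0 ? 0.0 : (double) confusion[cls][cls] / actual;
    }

    // harmonic mean of precision and recall
    public static double f1(int[][] confusion, int cls) {
        double p = precision(confusion, cls);
        double r = recall(confusion, cls);
        return p + r == 0 ? 0.0 : 2 * p * r / (p + r);
    }

    // unweighted average of a per-class metric over all classes
    public static double macro(int[][] confusion, IntToDoubleFunction perClass) {
        double sum = 0;
        for (int cls = 0; cls < confusion.length; cls++) {
            sum += perClass.applyAsDouble(cls);
        }
        return sum / confusion.length;
    }

    public static void main(String[] args) {
        int[][] confusion = {{5, 1}, {2, 2}};
        System.out.println("Accuracy: " + accuracy(confusion)); // 7 correct of 10, prints 0.7
        System.out.println("F1 Score: " + macro(confusion, k -> f1(confusion, k)));
    }
}
```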

4. Dataset Preparation

In this section, we'll prepare the images. We'll use the CIFAR-10 dataset, which Deeplearning4j supports out of the box, and create iterators to access the images:

import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator;

public class CifarDatasetService implements IDataSetService {

    private CifarDataSetIterator trainIterator;
    private CifarDataSetIterator testIterator;

    public CifarDatasetService() {
        // the boolean flag selects the training (true) or test (false) part of CIFAR-10
        trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true);
        testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false);
    }

    // other methods and fields declaration
}


We can choose some parameters on our own. The trainBatch and testBatch parameters are the numbers of images per training and evaluation step, respectively. The trainImagesNum and testImagesNum parameters are the numbers of images for training and testing. One epoch lasts trainImagesNum / trainBatch steps. So, having 2048 training images with a batch size of 32 leads to 2048 / 32 = 64 steps per epoch.
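The steps-per-epoch arithmetic is simple, but a small helper makes the rounding explicit. This sketch reuses the article's parameter names, and the helper itself is hypothetical. It rounds up on the assumption that the iterator also emits a final, smaller batch when the number of images isn't divisible by the batch size. Whether that happens depends on the iterator.

```java
// Hypothetical helper: number of training steps (batches) in one epoch.
class EpochSteps {

    public static int stepsPerEpoch(int trainImagesNum, int trainBatch) {
        // integer ceiling division: counts a trailing partial batch as one more step
        return (trainImagesNum + trainBatch - 1) / trainBatch;
    }

    public static void main(String[] args) {
        System.out.println(stepsPerEpoch(2048, 32)); // prints "64"
    }
}
```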

5. Convolutional Neural Network in Deeplearning4j

5.1. Building the Model

Next, let's build our CNN model from scratch. To do it, we'll use convolutional, subsampling (pooling), and fully connected (dense) layers.

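import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
  // seed, learning rate and updater (optimization algorithm) settings go here
  .list()
  .layer(0, conv5x5())
  .layer(1, pooling2x2Stride2())
  .layer(2, conv3x3Stride1Padding2())
  .layer(3, pooling2x2Stride1())
  .layer(4, conv3x3Stride1Padding1())
  .layer(5, pooling2x2Stride1())
  .layer(6, dense())
  .setInputType(InputType.convolutional(32, 32, 3)) // CIFAR-10: 32x32 RGB images
  .build();

network = new MultiLayerNetwork(configuration);
network.init();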

Here we specify the learning rate, the update algorithm, the input type of our model, and the layered architecture. We can experiment with these configurations. Thus, we can train many models with different architectures and training parameters, compare the results, and choose the best model.

5.2. Training the Model

Then, we'll train the built model. This can be done in a few lines of code:

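public void train() {
    // one fit() pass over the training iterator per epoch; requires java.util.stream.IntStream
    IntStream.range(1, epochsNum + 1).forEach(epoch -> network.fit(dataSetService.trainIterator()));
}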

The number of epochs is a parameter that we can specify ourselves. Since our dataset is small, several hundred epochs should be enough.

5.3. Evaluating the Model

Finally, we can evaluate the now-trained model. The Deeplearning4j library makes this easy:

public Evaluation evaluate() {
   return network.evaluate(dataSetService.testIterator());
}

Evaluation is an object that contains the metrics computed for the trained model: accuracy, precision, recall, and F1 score. Moreover, it has a friendly printable representation:

# of classes: 11
Accuracy: 0,8406
Precision: 0,7303
Recall: 0,6820
F1 Score: 0,6466

6. Conclusion

In this tutorial, we've learned about the architecture of CNN models, optimization techniques, and evaluation metrics. Furthermore, we've implemented the model using the Deeplearning4j library in Java.

As usual, code for this example is available over on GitHub.
