In questo articolo implementiamo una semplice Convolutional Neural Network per la classificazione di immagini utilizzando il dataset CIFAR-10.
Il dataset CIFAR-10
Il primo passo è quello di comprendere il dataset con cui addestrare e valutare la rete. CIFAR-10 consiste di 60000 immagini a colori 32x32. Ciascuna immagine contiene oggetti appartenenti ad una di 10 possibili categorie.
Per ogni categoria sono definite 6000 immagini. Di tutto il dataset 50000 immagini sono utilizzate per il training mentre 1000 per il testing. Al seguente link è possibile visionare alcune immagini di esempio e avere maggiori dettagli sul dataset.
Download del dataset
Costruiamo una classe Cnn per la realizzazione della rete,videnziamo quindi gli import necessari:
import org.datavec.image.loader.CifarLoader;
import org.deeplearning4j.datasets.fetchers.DataSetType;
import org.deeplearning4j.datasets.iterator.impl.Cifar10DataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.lossfunctions.LossFunctions;
Definiamo alcune costanti necessarie per il training:
public class Cnn {
private final static int height = 32;
private final static int width = 32;
private final static int channels = 3;
private final static int numLabels = CifarLoader.NUM_LABELS;
private final static int batchSize = 96;
private final static long seed = 123L;
private final static int epochs = 5;
public static void main(String[] args) throws Exception {}
Dove dal CifarLoader otteniamo il numero di label per le immagini, mentre con le altre costanti definiamo la risoluzione delle stesse, la dimensione del batch e il numero di epoche. Proseguiamo con l'inserimento del codice per il download del dataset:
public static void main(String[] args) throws Exception {
Cnn cnn = new Cnn();
Cifar10DataSetIterator cifar = new Cifar10DataSetIterator(batchSize, new int[]{height, width}, DataSetType.TRAIN, null, seed);
Cifar10DataSetIterator cifarEval = new Cifar10DataSetIterator(batchSize, new int[]{height, width}, DataSetType.TEST, null, seed);
}
Una volta ottenuto il dataset iniziamo la costruzione del nostro modello e definiamo un metodo a parte.
Modello della rete
Per costruire il modello della rete Cnn, definiamo una cascata di layer di convoluzione alcuni dei quali seguiti da un max pooling o average pooling:
public MultiLayerNetwork getModel() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.updater(new Sgd(0.01))
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.weightInit(WeightInit.XAVIER)
.list()
.layer(new ConvolutionLayer.Builder().kernelSize(3,3).stride(1,1).padding(1,1).activation(Activation.LEAKYRELU)
.nIn(channels).nOut(32).build())
.layer(new BatchNormalization())
.layer(new SubsamplingLayer.Builder().kernelSize(2,2).stride(2,2).poolingType(SubsamplingLayer.PoolingType.MAX).build())
.layer(new ConvolutionLayer.Builder().kernelSize(1,1).stride(1,1).padding(1,1).activation(Activation.LEAKYRELU)
.nOut(16).build())
.layer(new BatchNormalization())
.layer(new ConvolutionLayer.Builder().kernelSize(3,3).stride(1,1).padding(1,1).activation(Activation.LEAKYRELU)
.nOut(64).build())
.layer(new BatchNormalization())
.layer(new SubsamplingLayer.Builder().kernelSize(2,2).stride(2,2).poolingType(SubsamplingLayer.PoolingType.MAX).build())
.layer(new ConvolutionLayer.Builder().kernelSize(1,1).stride(1,1).padding(1,1).activation(Activation.LEAKYRELU)
.nOut(32).build())
.layer(new BatchNormalization())
.layer(new ConvolutionLayer.Builder().kernelSize(3,3).stride(1,1).padding(1,1).activation(Activation.LEAKYRELU)
.nOut(128).build())
.layer(new BatchNormalization())
.layer(new ConvolutionLayer.Builder().kernelSize(1,1).stride(1,1).padding(1,1).activation(Activation.LEAKYRELU)
.nOut(64).build())
.layer(new BatchNormalization())
.layer(new ConvolutionLayer.Builder().kernelSize(1,1).stride(1,1).padding(1,1).activation(Activation.LEAKYRELU)
.nOut(numLabels).build())
.layer(new BatchNormalization())
.layer(new SubsamplingLayer.Builder().kernelSize(2,2).stride(2,2).poolingType(SubsamplingLayer.PoolingType.AVG).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.name("output")
.nOut(numLabels)
.dropOut(0.8)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutional(height, width, channels))
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
return model;
}
Per il training utilizziamo un learning rate di 0.01. La costruzione della rete inizia da un layer di convoluzione che acquisisce in input le immagini, successivamente si applicano layer per l'estrazione di aspetti significativi delle immagini.
Facciamo attenzione ad applicare una batch normalization per ottenere un buon training, con essa facciamo in modo che i dati abbiano sempre valore medio nullo e deviazione standard 1. Ogni layer utilizza una activation function di tipo RELU per evitare di avere valori negativi.
Questa funzione è definita come max(0,x)
quindi calcola il massimo tra il valore x e lo zero.
L'ultimo layer è di classificazione, fornisce in output la distribuzione delle probabilità su ciascuna classe. La classificazione dell'immagine sarà ottenuta prendendo l'output corrispondente alla classe con probabilità maggiore.
Training
Per il training, completiamo il codice della classe main
come segue:
public static void main(String[] args) throws Exception {
Cnn cnn = new Cnn();
Cifar10DataSetIterator cifar = new Cifar10DataSetIterator(batchSize, new int[]{height, width}, DataSetType.TRAIN, null, seed);
Cifar10DataSetIterator cifarEval = new Cifar10DataSetIterator(batchSize, new int[]{height, width}, DataSetType.TEST, null, seed);
MultiLayerNetwork model = cnn.getModel();
model.setListeners(new ScoreIterationListener(50), new EvaluativeListener(cifarEval, 1, InvocationType.EPOCH_END));
model.fit(cifar, epochs);
}
Avviamo l'applicativo e attendiamo i risultati.