Example code for an XOR solver built with the multilayer perceptron classifier in Apache Spark MLlib.

Training data:

//text in libsvm format: <label> <index>:<value> ..., feature indices are 1-based
String sampleText = 
"0 1:0 2:0\n"+
"1 1:0 2:1\n"+
"1 1:1 2:0\n"+
"0 1:1 2:1";

//SparkSession reads data from a file, so save the text to one first
//(saveSampleTextToFile is a small helper, not part of Spark)
String fileName = "somefilename.txt";
saveSampleTextToFile(sampleText, fileName);

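//load into a Dataset<Row> (org.apache.spark.sql.Dataset / Row);
//anySparkSessionHere stands for your existing SparkSession
Dataset<Row> samples = anySparkSessionHere.read()
    .format("libsvm").load(fileName);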
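The post leaves two pieces implicit: the `saveSampleTextToFile` helper and how a libsvm line maps to a label plus features. Below is a minimal sketch of both, assuming the helper takes the text and the file name and writes UTF-8 with `java.nio`. The class `TrainingDataHelpers` and the method `parseLibsvmLine` are hypothetical and only illustrate the format; Spark does its own parsing inside `format("libsvm")`.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;

// Hypothetical helpers for the training-data step; neither is part of Spark.
final class TrainingDataHelpers {

    // Writes the sample text to disk as UTF-8, creating or overwriting the file.
    static void saveSampleTextToFile(String text, String fileName) throws IOException {
        Files.write(Paths.get(fileName), text.getBytes(StandardCharsets.UTF_8));
    }

    // Parses one libsvm line "<label> <index>:<value> ..." into a dense array:
    // element 0 holds the label, element i holds feature i (libsvm indices are 1-based).
    // Features missing from the line stay 0, because libsvm is a sparse format.
    static double[] parseLibsvmLine(String line, int numFeatures) {
        String[] tokens = line.trim().split("\\s+");
        double[] result = new double[numFeatures + 1];
        result[0] = Double.parseDouble(tokens[0]);
        for (int i = 1; i < tokens.length; i++) {
            String[] pair = tokens[i].split(":");
            result[Integer.parseInt(pair[0])] = Double.parseDouble(pair[1]);
        }
        return result;
    }
}
```

For example, `parseLibsvmLine("1 1:0 2:1", 2)` yields `{1.0, 0.0, 1.0}`: label 1, x1 = 0, x2 = 1. Since the format is sparse, `"1 2:1"` describes the same row; the sample data simply writes its zeros out explicitly.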

Create the network:

int   inputCount = 2; //2 inputs
int   classCount = 2; //0 and 1
int[] layers     = {inputCount, 2, classCount}; //input, one hidden layer of 2 neurons, output

//org.apache.spark.ml.classification.MultilayerPerceptronClassifier
MultilayerPerceptronClassifier ann = new MultilayerPerceptronClassifier()
    .setLayers(layers)
    .setMaxIter(1000);
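To get a feel for how small this network is, one can count its trainable parameters. A minimal sketch, assuming every layer is fully connected to the next and each non-input neuron has one bias; the class `LayerMath` is hypothetical.

```java
// Counts trainable parameters of a fully connected feed-forward network,
// assuming one weight per connection plus one bias per non-input neuron.
final class LayerMath {
    static int parameterCount(int[] layers) {
        int total = 0;
        for (int i = 0; i + 1 < layers.length; i++) {
            // weights from layers[i] into layers[i + 1], plus the biases of layer i + 1
            total += layers[i] * layers[i + 1] + layers[i + 1];
        }
        return total;
    }
}
```

For `{2, 2, 2}` this gives (2·2 + 2) + (2·2 + 2) = 12 parameters, which is why 1000 iterations on four rows finish quickly.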

Train the network:

MultilayerPerceptronClassificationModel model =
ann.fit(samples);
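What `fit` searches for is a set of weights that makes the 2-2-2 network compute XOR. Spark's documentation describes the hidden layers as using the sigmoid function and the output layer as using softmax. The sketch below wires such a network by hand to show that a solution exists; the weights are hand-picked for illustration, not values Spark would learn, and the class `XorByHand` is hypothetical.

```java
// Hand-weighted 2-2-2 network that computes XOR, for illustration only.
// Hidden layer: two sigmoid units approximating OR and NAND.
// Output layer: softmax over two class scores; the larger score wins.
final class XorByHand {
    private static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Returns softmax probabilities {P(class 0), P(class 1)}.
    static double[] probabilities(double x1, double x2) {
        double orUnit   = sigmoid(20 * x1 + 20 * x2 - 10);  // ~1 unless both inputs are 0
        double nandUnit = sigmoid(-20 * x1 - 20 * x2 + 30); // ~1 unless both inputs are 1
        double score0 = 0.0;                                // fixed reference score for class 0
        double score1 = 20 * orUnit + 20 * nandUnit - 30;   // > 0 only when both units fire
        double max = Math.max(score0, score1);              // subtract max for numerical stability
        double e0 = Math.exp(score0 - max);
        double e1 = Math.exp(score1 - max);
        return new double[] {e0 / (e0 + e1), e1 / (e0 + e1)};
    }

    static int predict(double x1, double x2) {
        double[] p = probabilities(x1, x2);
        return p[1] > p[0] ? 1 : 0;
    }
}
```

Because two hidden sigmoid units can represent XOR, the `{2, 2, 2}` layout is able to fit the data in principle. Gradient-based training can still settle in a poor local minimum, so if the predictions come out wrong, trying another `setSeed(...)` value or a wider hidden layer may help.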

Use the model to compute predictions:

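//inputs are loaded in libsvm format too; here the four XOR rows are reused
//(each line needs a label, but transform() only reads the features column)
Dataset<Row> inputs = anySparkSessionHere.read()
    .format("libsvm").load(fileName);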

Dataset<Row> outputs = model.transform(inputs);
List<Row>    list    = outputs.collectAsList();

for (Row row: list) {
  int index = row.fieldIndex("prediction");
  System.out.println(row.getDouble(index));
}
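The loop only prints predictions. To check them against the expected XOR labels, one can compare the label and prediction of each row. A minimal plain-Java sketch, assuming both columns have already been copied into `double` arrays; the class `Accuracy` is hypothetical. On a `Dataset`, Spark's `MulticlassClassificationEvaluator` with metric name `"accuracy"` serves the same purpose.

```java
// Fraction of rows whose predicted class equals the expected label.
final class Accuracy {
    static double of(double[] labels, double[] predictions) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("arrays must be non-empty and equal in length");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] == predictions[i]) { // class ids are exact values such as 0.0 and 1.0
                correct++;
            }
        }
        return (double) correct / labels.length;
    }
}
```

A model that has learned XOR scores 1.0 on the four training rows; one wrong row out of four gives 0.75.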

 
