/*
 * Decompiled with CFR 0.152.
 */
package tech.tablesaw.api.ml.classification;

import it.unimi.dsi.fastutil.ints.IntIterator;
import java.util.TreeSet;
import org.junit.Assert;
import org.junit.Test;
import smile.classification.KNN;
import tech.tablesaw.api.BooleanColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.api.QueryHelper;
import tech.tablesaw.api.Table;
import tech.tablesaw.api.ml.classification.LogisticRegression;
import tech.tablesaw.api.ml.classification.StandardConfusionMatrix;
import tech.tablesaw.columns.Column;
import tech.tablesaw.util.DoubleArrays;

/**
 * Smoke tests that fill a {@link StandardConfusionMatrix} with the predictions of a
 * classifier trained on one half of the KNN example data set and scored on the other half.
 */
public class ConfusionMatrixTest {
    /** Location of the example data set, relative to the working directory of the test run. */
    private static final String EXAMPLE_CSV = "../data/KNN_Example_1.csv";

    /**
     * Trains a KNN classifier (k = 2) on the "X" and "Y" columns of the training split and
     * records every prediction for the test split in a confusion matrix.
     *
     * @throws Exception if reading the data or any later step fails
     */
    @Test
    public void testAsTable() throws Exception {
        Table example = Table.read().csv(EXAMPLE_CSV);
        Table[] splits = example.sampleSplit(0.5);
        Table train = splits[0];
        Table test = splits[1];

        KNN knn = KNN.learn(
                (double[][]) DoubleArrays.to2dArray((NumericColumn) train.nCol("X"), (NumericColumn) train.nCol("Y")),
                (int[]) train.shortColumn(2).toIntArray(),
                2);

        int[] predicted = new int[test.rowCount()];
        // The matrix is keyed on the distinct values of column 2 in the training split.
        TreeSet labelSet = new TreeSet(train.shortColumn(2).asSet());
        StandardConfusionMatrix confusion = new StandardConfusionMatrix(labelSet);

        IntIterator rows = test.iterator();
        while (rows.hasNext()) {
            // nextInt() yields the primitive directly instead of boxing through next().
            int row = rows.nextInt();
            double[] data = new double[]{test.floatColumn(0).getFloat(row), test.floatColumn(1).getFloat(row)};
            predicted[row] = knn.predict((Object) data);
            confusion.increment(Integer.valueOf(test.shortColumn(2).get(row)), Integer.valueOf(predicted[row]));
        }
        // Same minimal check as testWithBooleanColumn, so this test no longer ends without any assertion.
        Assert.assertNotNull(confusion);
    }

    /**
     * Derives a boolean column "bt" from the rows whose "Label" equals 1, appends it to the
     * table, trains a logistic regression on it and records every prediction for the test
     * split in a confusion matrix.
     *
     * @throws Exception if reading the data or any later step fails
     */
    @Test
    public void testWithBooleanColumn() throws Exception {
        Table example = Table.read().csv(EXAMPLE_CSV);
        BooleanColumn booleanTarget = example.selectIntoColumn("bt", QueryHelper.column("Label").isEqualTo(1));
        example.addColumn(new Column[]{booleanTarget});
        Table[] splits = example.sampleSplit(0.5);
        Table train = splits[0];
        Table test = splits[1];

        // The appended boolean column is read back at index 3 and used as the training target.
        LogisticRegression lr = LogisticRegression.learn(
                (BooleanColumn) train.booleanColumn(3),
                new NumericColumn[]{train.nCol("X"), train.nCol("Y")});

        int[] predicted = new int[test.rowCount()];
        // NOTE(review): the matrix labels and the "actual" values below come from column 2,
        // while the model was trained on the boolean column - confirm the encodings line up.
        TreeSet labelSet = new TreeSet(train.shortColumn(2).asSet());
        StandardConfusionMatrix confusion = new StandardConfusionMatrix(labelSet);

        IntIterator rows = test.iterator();
        while (rows.hasNext()) {
            // nextInt() yields the primitive directly instead of boxing through next().
            int row = rows.nextInt();
            double[] data = new double[]{test.floatColumn(0).getFloat(row), test.floatColumn(1).getFloat(row)};
            predicted[row] = lr.predict(data);
            confusion.increment(Integer.valueOf(test.shortColumn(2).get(row)), Integer.valueOf(predicted[row]));
        }
        Assert.assertNotNull(confusion);
    }
}

