跳至主要內容

Java开发者玩转机器学习的利器: Tribuo

DD编辑部原创JavaJava大约 7 分钟

Java开发者玩转机器学习的利器: Tribuo

1. 概述

机器学习(ML)和人工智能(AI)正在推动软件开发的变革,使系统能够通过数据学习并做出智能预测。

作为一名Java开发者,如果要训练自己的预测模型,是不是第一想到的还是把Python拿起来?其实不一定非要拿起Python,在Java领域也有自己的生产级机器学习工具,它支持分类、回归、聚类等常见任务,还能无缝对接 TensorFlow 等框架,用 Java 就能直接训模型、做预测!

它就是:Tribuoopen in new window,是 Oracle 推出的面向生产环境的开源机器学习库,极大简化了健壮 ML 模型的构建与部署。与 Weka 和 Deeplearning4j 类似,Tribuo 支持多种机器学习任务,并能轻松集成到 Java 应用中。

接下来本文我们将了解 Tribuo 支持的多种机器学习算法,并以 UCI 红葡萄酒质量数据集为例,构建一个用于预测葡萄酒质量的回归模型。

2. 什么是 Tribuo?

Tribuo 是一个以 Java 为核心的机器学习库,支持:

  • 监督学习:如回归、分类等
  • 无监督学习:如聚类

此外,Tribuo 拥有强类型特性,能够强制输入输出类型一致,有效防止运行时错误,确保模型开发过程的规范性。

它支持以 ONNX(开放神经网络交换)格式导入和导出模型,便于与 TensorFlow、PyTorch 等主流 ML 框架集成。

另一个亮点是 provenance(溯源)追踪功能,可记录数据集、模型参数和训练配置等元数据,提升透明度和可复现性。

随着 AI 在企业级 Java 应用中的普及,Tribuo 为在 Java 系统中直接嵌入智能行为提供了实用工具包。

3. 支持的机器学习算法

Tribuo 支持多种机器学习任务,包括:

  • 分类:预测离散类别或标签。例如,预测一支足球队是否会获胜,或根据质量阈值将葡萄酒分为好或坏。
  • 回归:预测连续值,如葡萄酒质量分数或患者胆固醇水平。
  • 聚类:在无标签数据中识别分组。例如,可以根据酸度、酒精含量等化学属性对葡萄酒进行分组,而无需知道其质量分数。

4. 搭建 Tribuo 项目

我们将通过构建一个葡萄酒质量回归预测模型,体验 Tribuo 的实际应用。

首先,在 pom.xml 中添加 Tribuo 依赖open in new window

<dependency>
    <groupId>org.tribuo</groupId>
    <artifactId>tribuo-all</artifactId>
    <version>4.3.2</version>
</dependency>

tribuo-all 依赖提供了加载和训练数据集所需的相关类。

然后,下载 UCI 红葡萄酒质量数据集open in new window,并放置到 src/main/resources/dataset 目录下。该数据集包含 11 个理化特征,如酸度和酒精含量:

quality 列是一个适合回归任务的连续数值。

最后,创建一个名为 WineQualityRegression 的类:

public class WineQualityRegression {
}

后续章节将在该类中实现训练和保存模型的相关逻辑。

5. 类级变量

接下来,定义如下类级变量:

public static final String DATASET_PATH = "src/main/resources/dataset/winequality-red.csv";
public static final String MODEL_PATH = "src/main/resources/model/winequality-red-regressor.ser";
public Model<Regressor> model;
public Trainer<Regressor> trainer;
public Dataset<Regressor> trainSet;
public Dataset<Regressor> testSet;

上述代码中,我们定义了数据集路径和训练模型保存/加载路径。

随后,定义了四个变量,分别代表:

  • Model —— 存储预测模型的类
  • Trainer —— 可训练预测模型的接口
  • Dataset —— 用于训练的数据集类

此外,我们显式指定了模型输出类型为 Regressor

6. 加载与划分数据集

定义一个方法用于加载并划分数据集:

void createDatasets() throws Exception {
    RegressionFactory regressionFactory = new RegressionFactory();
    CSVLoader<Regressor> csvLoader = new CSVLoader<>(';', CSVIterator.QUOTE, regressionFactory);
    DataSource<Regressor> dataSource = csvLoader.loadDataSource(Paths.get(DATASET_PATH), "quality");

    TrainTestSplitter<Regressor> dataSplitter = new TrainTestSplitter<>(dataSource, 0.7, 1L);
    trainSet = new MutableDataset<>(dataSplitter.getTrain());
    testSet = new MutableDataset<>(dataSplitter.getTest());
}

这里,我们用 CSVLoader 解析分号分隔的 CSV 文件并为回归任务做准备。RegressionFactory 用于创建回归输出,指定目标变量 quality 为连续变量。DataSource<Regressor> 保存解析后的数据。

随后,为了评估模型的泛化能力和表现,使用 TrainTestSplitter 将数据集按 7:3 划分为训练集和测试集。

7. 训练回归模型

由于葡萄酒质量分数为数值型,我们采用分类与回归树(CART)作为基学习器进行训练:

void createTrainer() {
    CARTRegressionTrainer subsamplingTree = new CARTRegressionTrainer(
      Integer.MAX_VALUE,
      AbstractCARTTrainer.MIN_EXAMPLES,
      0.001f,
      0.7f,
      new MeanSquaredError(),
      Trainer.DEFAULT_SEED
    );

    trainer = new RandomForestTrainer<>(subsamplingTree, new AveragingCombiner(), 10);
    model = trainer.train(trainSet); 
}

上述方法中,CARTRegressionTrainer 配置了无最大深度、每次分裂最少 6 个样本、以均方误差为分裂标准。随后,RandomForestTrainer 结合 10 棵 CART 决策树,并用 AveragingCombiner 平均预测结果。

train() 方法在 trainSet 数据集上训练模型,生成用于预测葡萄酒质量分数的 Model<Regressor>

8. 评估

接下来,使用 RegressionEvaluator 评估模型在数据集上的表现,计算相关指标:

void evaluate(Model<Regressor> model, String datasetName, Dataset<Regressor> dataset) {
    RegressionEvaluator evaluator = new RegressionEvaluator();
    RegressionEvaluation evaluation = evaluator.evaluate(model, dataset);
    Regressor dimension0 = new Regressor("DIM-0", Double.NaN);

    log.info("MAE: " + evaluation.mae(dimension0));
    log.info("RMSE: " + evaluation.rmse(dimension0));
    log.info("R^2: " + evaluation.r2(dimension0));
}

RegressionEvaluator 用于评估模型在数据集上的表现。我们将 MAE(平均绝对误差)、RMSE(均方根误差)和 R^2(决定系数)输出到控制台。

随后,调用 evaluate() 方法评估模型和数据集:

void evaluateModels() throws Exception {
    log.info("Training model");
    evaluate(model, "trainSet", trainSet);

    log.info("Testing model");
    evaluate(model, "testSet", testSet);
}

执行程序后,训练集和测试集的评估结果如下:

07:10:14.405 [main] INFO  tribuo.WineQualityRegression - Training model
07:10:14.406 [main] INFO  tribuo.WineQualityRegression - Results for trainSet---------------------
07:10:14.537 [main] INFO  tribuo.WineQualityRegression - MAE: 0.25025410332970005
07:10:14.537 [main] INFO  tribuo.WineQualityRegression - RMSE: 0.3422557198486092
07:10:14.538 [main] INFO  tribuo.WineQualityRegression - R^2: 0.8190947891297661
07:10:14.538 [main] INFO  tribuo.WineQualityRegression - Testing model
07:10:14.540 [main] INFO  tribuo.WineQualityRegression - Results for testSet---------------------
07:10:14.565 [main] INFO  tribuo.WineQualityRegression - MAE: 0.48711029366796743
07:10:14.565 [main] INFO  tribuo.WineQualityRegression - RMSE: 0.6584973595553575
07:10:14.565 [main] INFO  tribuo.WineQualityRegression - R^2: 0.3444460580874339

MAE 表示预测值与实际值的绝对差异,RMSE 表示预测值与实际值的平方差均值的平方根,R^2 表示模型对训练和测试数据方差的解释能力。

更低的 MAERMSE,以及更高的 R^2,意味着模型预测性能更优。

9. 保存模型

最后,将模型保存为文件以便后续复用:

void saveModel() throws Exception {
    File modelFile = new File(MODEL_PATH);
    try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(modelFile))) {
        objectOutputStream.writeObject(model);
    }
}

上述代码通过 ObjectOutputStream 类将训练好的模型序列化保存到文件。这样,我们可以在后续预测中直接复用模型,无需重新训练

10. 方法调用

现在,在 main() 方法中调用前面创建的方法:

public static void main(String[] args) throws Exception {
    WineQualityRegression wineQualityRegression = new WineQualityRegression();

    wineQualityRegression.createDatasets();
    wineQualityRegression.createTrainer();
    wineQualityRegression.evaluateModels();
    wineQualityRegression.saveModel();
}

编译代码后,模型会被保存到指定目录。

11. 使用模型

新建一个 WinePredictor 类,在 main() 方法中加载已保存的模型:

class WineQualityPredictor {
    private static final Logger log = LoggerFactory.getLogger(WineQualityPredictor.class);

    public static void main(String[] args) throws IOException, ClassNotFoundException {
        File modelFile = new File("src/main/resources/model/winequality-red-regressor.ser");
        Model<Regressor> loadedModel = null;

        try (ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(modelFile))) {
            loadedModel = (Model<Regressor>) objectInputStream.readObject();
        }
}

如前所述,Tribuo 对类型敏感,因此我们指定模型类型为 Regressor

通过创建 ObjectInputStream 并传入模型路径来加载模型。

然后,创建一个 ArrayExample 对象,表示单个葡萄酒样本:

ArrayExample<Regressor> wineAttribute = new ArrayExample<Regressor>(new Regressor("quality", Double.NaN));
wineAttribute.add("fixed acidity", 7.4f);
wineAttribute.add("volatile acidity", 0.7f);
wineAttribute.add("citric acid", 0.47f);
wineAttribute.add("residual sugar", 1.9f);
wineAttribute.add("chlorides", 0.076f);
wineAttribute.add("free sulfur dioxide", 11.0f);
wineAttribute.add("total sulfur dioxide", 34.0f);
wineAttribute.add("density", 0.9978f);
wineAttribute.add("pH", 3.51f);
wineAttribute.add("sulphates", 0.56f);
wineAttribute.add("alcohol", 9.4f);

最后,使用 Prediction 类进行预测:

Prediction<Regressor> prediction = loadedModel.predict(wineAttribute);
double predictQuality = prediction.getOutput().getValues()[0];
log.info("Predicted wine quality: " + predictQuality);

预测结果如下:

07:31:05.772 [main] INFO  tribuo.WineQualityPredictor - Predicted wine quality: 5.028163673540464

12. 总结

在本文中,我们学习了 Tribuo 及其特性,了解了其支持的部分机器学习算法,并通过回归算法训练了葡萄酒质量预测模型。

上次编辑于:
贡献者: didi