训练模型
在上一篇文章中,我们已经通过LearningPipeline训练好了一个“鸢尾花瓣预测”模型,
var model = pipeline.Train<IrisData, IrisPrediction>();
现在就可以让模型对一条没有人工标注结果的数据进行分析,返回一个预测结果。
var prediction = model.Predict(new IrisData()
            {
                SepalLength = 3.3f,
                SepalWidth = 1.6f,
                PetalLength = 0.2f,
                PetalWidth = 5.1f,
            });
 
Console.WriteLine($"Predicted flower type is: {prediction.PredictedLabels}");或者一次预测一批数据
var inputs = new[]{                new IrisData()
                {
                    SepalLength = 3.3f,
                    SepalWidth = 1.6f,
                    PetalLength = 0.2f,
                    PetalWidth = 5.1f,
                }
                ,new IrisData()
                {
                    SepalLength = 5.2f,
                    SepalWidth = 3.5f,
                    PetalLength = 1.5f,
                    PetalWidth = 0.2f,
                }
            };var predictions = model.Predict(inputs);保存模型
但是大多数时候,已经训练好的模型以后还需要继续可以使用,因此需要把它持久化,写入到zip文件中。
await model.WriteAsync("IrisPredictionModel.zip");使用模型
一旦建立了机器学习模型,就可以部署它,利用它进行预测。我们可以通过REST API,接受来自客户端的数据输入,并返回预测结果。
- 创建API项目
dotnet new webapi -o myApi
- 安装依赖项
cd myApi dotnet add package Microsoft.ML dotnet restore
- 引用模型
要在API中引用我们前面保存的模型,只需将IrisPredictionModel.zip复制到API项目目录中即可。
- 创建数据结构
我们的模型使用数据结构IrisData和IrisPrediction来定义特征和预测属性。因此,当使用我们的模型通过API进行预测时,它也需要引用这些数据结构。因此,我们需要在API项目中定义IrisData和IrisPrediction。类的内容与上一篇文章中创建模型项目中的内容相同。
using Microsoft.ML.Runtime.Api;namespace myApi
{    public class IrisData
    {
        [Column("0")]        public float SepalLength;
        [Column("1")]        public float SepalWidth;
        [Column("2")]        public float PetalLength;
        [Column("3")]        public float PetalWidth;
        [Column("4")]
        [ColumnName("Label")]        public string Label;
    }
}
using Microsoft.ML.Runtime.Api;namespace myApi
{    public class IrisPrediction
    {
        [ColumnName("PredictedLabel")]        public string PredictedLabels;
    }
}- 创建Controller
现在,在API项目的Controllers目录中,创建PredictController类,用于处理来自客户端的预测请求,它包含一个POST方法
using System;using System.Collections.Generic;using System.Linq;using System.Threading.Tasks;using Microsoft.AspNetCore.Mvc;using Microsoft.ML;namespace myApi.Controllers
{
    [Route("api/[controller]")]
    [ApiController]    public class PredictController : ControllerBase
    {        // POST api/predict        [HttpPost]        public async Task<string> Post([FromBody] IrisData value)
        {            
            var model = await PredictionModel.ReadAsync<IrisData,IrisPrediction>("IrisPredictionModel.zip");            var prediction = model.Predict(value);            return prediction.PredictedLabels;
        } 
    }
}- 测试API
使用如下命令行运行程序:
dotnet run
然后,使用POSTMAN或其他工具向http://localhost:5000/api/predict发送POST请求,请求数据类似:
{    "SepalLength": 3.3,    "SepalWidth": 1.6,    "PetalLength": 0.2,    "PetalWidth": 5.1,
}如果成功,将会返回"Iris-virginica"。