Train a classification or regression model

This page shows you how to train a classification or regression model from a tabular dataset by using the Google Cloud console or the Vertex AI API.

Before you begin

Before you train a model, make sure you have completed the required setup steps, including creating the tabular dataset that you will train from.

Train a model

Google Cloud console

  1. In the Google Cloud console, in the Vertex AI section, go to the Datasets page.

    Go to Datasets

  2. Click the name of the dataset you want to use to train your model to open its details page.

  3. If your data type uses annotation sets, select the annotation set you want to use for this model.

  4. Click Train new model.

  5. Select Other.

  6. On the Train new model page, complete the following steps:

    1. Select the model training method.

      • AutoML is a good choice for a wide range of use cases.

      Click Continue.

    2. Enter the display name for your new model.

    3. Select your target column.

      The target column is the value that the model will predict.

      Learn more about target column requirements.

    4. Optional: To export your test dataset to BigQuery, check Export test dataset to BigQuery and provide the name of the table.

    5. Optional: To choose how to split your data between the training, test, and validation sets, open the Advanced options. You can choose between the following data split options:

      • Random (Default): Vertex AI randomly selects the rows associated with each of the data sets. By default, Vertex AI selects 80% of your data rows for the training set, 10% for the validation set, and 10% for the test set.
      • Manual: Vertex AI selects rows for each of the data sets based on the values in a data split column. Provide the name of the data split column.
      • Chronological: Vertex AI splits the data based on the timestamps in a time column. Provide the name of the time column.

      Learn more about data splits.

    6. Click Continue.

    7. Optional: Click Generate statistics. Generating the statistics populates the Transformation drop-down menus.

    8. On the Training options page, review your column list and exclude from training any columns that should not be used to train the model.

    9. Review the transformations selected for your included features, along with whether invalid data is allowed, and make any required updates.

      Learn more about transformations and invalid data.

    10. If you want to specify a weight column, or change your default optimization objective, open the Advanced options and make your selections.

      Learn more about weight columns and optimization objectives.

    11. Click Continue.

    12. In the Compute and pricing window, configure as follows:

      Enter the maximum number of hours you want your model to train for.

      This setting helps you put a cap on your training costs. The actual time elapsed can be longer than this value, because creating a new model involves other operations.

      Suggested training time is related to the size of your training data. The table below shows suggested training time ranges by row count; a large number of columns will also increase the training time required.

      Rows                    | Suggested training time
      Less than 100,000       | 1-3 hours
      100,000 - 1,000,000     | 1-6 hours
      1,000,000 - 10,000,000  | 1-12 hours
      More than 10,000,000    | 3-24 hours

      For information about training pricing, see the pricing page.

    13. Click Start training.

      Model training can take many hours, depending on the size and complexity of your data and your training budget, if you specified one. You can close this tab and return to it later. You will receive an email when your model has completed training.

API

Select an objective for your tabular data type.

Classification

Select a tab for your language or environment:

REST

Use the trainingPipelines.create command to train a model.

Before using any of the request data, make the following replacements:

  • LOCATION: Your region.
  • PROJECT: Your project ID.
  • TRAININGPIPELINE_DISPLAY_NAME: The display name for the training pipeline created for this operation.
  • TARGET_COLUMN: The column (value) you want this model to predict.
  • WEIGHT_COLUMN: (Optional) The weight column. Learn more.
  • TRAINING_BUDGET: The maximum amount of time you want the model to train for, in milli node hours (1,000 milli node hours equals one node hour; for example, a budget of 8,000 milli node hours is eight node hours).
  • OPTIMIZATION_OBJECTIVE: Required only if you do not want the default optimization objective for your prediction type. Learn more.
  • TRANSFORMATION_TYPE: The transformation type, provided for each column used to train the model. Learn more.
  • COLUMN_NAME: The name of the column with the specified transformation type. Every column used to train the model must be specified.
  • MODEL_DISPLAY_NAME: The display name for the newly trained model.
  • DATASET_ID: The ID of the training dataset.
  • You can provide a Split object to control your data split. For information about controlling the data split, see "Control the data split using REST".
  • PROJECT_NUMBER: Your project's automatically generated project number.

HTTP method and URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Request JSON body:

 {     "displayName": "TRAININGPIPELINE_DISPLAY_NAME",     "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",     "trainingTaskInputs": {         "targetColumn": "TARGET_COLUMN",         "weightColumn": "WEIGHT_COLUMN",         "predictionType": "classification",         "trainBudgetMilliNodeHours": TRAINING_BUDGET,         "optimizationObjective": "OPTIMIZATION_OBJECTIVE",         "transformations": [             {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },             {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },             ...     },     "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},     "inputDataConfig": {       "datasetId": "DATASET_ID",     } } 

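To send the request, you can use curl. A minimal sketch, assuming the request body above is saved as request.json and that you are authenticated with the gcloud CLI:

curl -X POST \
  -H "Authorization: Bearer $(gcloud auth print-access-token)" \
  -H "Content-Type: application/json; charset=utf-8" \
  -d @request.json \
  "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"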

You should receive a JSON response similar to the following:

 {   "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",   "displayName": "myModelName",   "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",   "modelToUpload": {     "displayName": "myModelName"   },   "state": "PIPELINE_STATE_PENDING",   "createTime": "2020-08-18T01:22:57.479336Z",   "updateTime": "2020-08-18T01:22:57.479336Z" } 

Java

Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see "Set up authentication for a local development environment".

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableClassification(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableClassification(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      Transformation transformation1 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build())
              .build();
      Transformation transformation2 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build())
              .build();
      Transformation transformation3 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build())
              .build();
      Transformation transformation4 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build())
              .build();

      ArrayList<Transformation> transformationArrayList = new ArrayList<>();
      transformationArrayList.add(transformation1);
      transformationArrayList.add(transformation2);
      transformationArrayList.add(transformation3);
      transformationArrayList.add(transformation4);

      AutoMlTablesInputs autoMlTablesInputs =
          AutoMlTablesInputs.newBuilder()
              .setTargetColumn(targetColumn)
              .setPredictionType("classification")
              .addAllTransformations(transformationArrayList)
              .setTrainBudgetMilliNodeHours(8000)
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlTablesInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Classification Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLabels: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Before trying this sample, follow the Node.js setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Node.js API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see "Set up authentication for a local development environment".

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'sepal_width'}},
    {auto: {column_name: 'sepal_length'}},
    {auto: {column_name: 'petal_length'}},
    {auto: {column_name: 'petal_width'}},
  ];
  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    targetColumn: targetColumn,
    predictionType: 'classification',
    transformations: transformations,
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-log-loss',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline tabular classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesClassification();

Python

To learn how to install or update the Vertex AI SDK for Python, see "Install the Vertex AI SDK for Python". For more information, see the Python API reference documentation.

from typing import Optional

from google.cloud import aiplatform


def create_training_pipeline_tabular_classification_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: Optional[str] = None,
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_classification_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="classification"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_classification_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model
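As a usage sketch, you might invoke the sample like this; the project ID, dataset ID, and column name below are placeholders, not values defined by this guide:

# Hypothetical invocation; replace the placeholders with your own values.
model = create_training_pipeline_tabular_classification_sample(
    project="YOUR_PROJECT_ID",
    display_name="my-tabular-classification-pipeline",
    dataset_id="YOUR_DATASET_ID",
    target_column="TARGET_COLUMN",
)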

Regression

Select a tab for your language or environment:

REST

Use the trainingPipelines.create command to train a model.

Before using any of the request data, make the following replacements:

  • LOCATION: Your region.
  • PROJECT: Your project ID.
  • TRAININGPIPELINE_DISPLAY_NAME: The display name for the training pipeline created for this operation.
  • TARGET_COLUMN: The column (value) you want this model to predict.
  • WEIGHT_COLUMN: (Optional) The weight column. Learn more.
  • TRAINING_BUDGET: The maximum amount of time you want the model to train for, in milli node hours (1,000 milli node hours equals one node hour; for example, a budget of 8,000 milli node hours is eight node hours).
  • OPTIMIZATION_OBJECTIVE: Required only if you do not want the default optimization objective for your prediction type. Learn more.
  • TRANSFORMATION_TYPE: The transformation type, provided for each column used to train the model. Learn more.
  • COLUMN_NAME: The name of the column with the specified transformation type. Every column used to train the model must be specified.
  • MODEL_DISPLAY_NAME: The display name for the newly trained model.
  • DATASET_ID: The ID of the training dataset.
  • You can provide a Split object to control your data split. For information about controlling the data split, see "Control the data split using REST".
  • PROJECT_NUMBER: Your project's automatically generated project number.

HTTP method and URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Request JSON body:

 {     "displayName": "TRAININGPIPELINE_DISPLAY_NAME",     "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",     "trainingTaskInputs": {         "targetColumn": "TARGET_COLUMN",         "weightColumn": "WEIGHT_COLUMN",         "predictionType": "regression",         "trainBudgetMilliNodeHours": TRAINING_BUDGET,         "optimizationObjective": "OPTIMIZATION_OBJECTIVE",         "transformations": [             {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },             {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },             ...     },     "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},     "inputDataConfig": {       "datasetId": "DATASET_ID",     } } 

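As in the classification example, you can send the request with curl; a sketch, assuming the request body above is saved as request.json and that you are authenticated with the gcloud CLI:

curl -X POST \
  -H "Authorization: Bearer $(gcloud auth print-access-token)" \
  -H "Content-Type: application/json; charset=utf-8" \
  -d @request.json \
  "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"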

You should receive a JSON response similar to the following:

 {   "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",   "displayName": "myModelName",   "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",   "modelToUpload": {     "displayName": "myModelName"   },   "state": "PIPELINE_STATE_PENDING",   "createTime": "2020-08-18T01:22:57.479336Z",   "updateTime": "2020-08-18T01:22:57.479336Z" } 

Java

Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see "Set up authentication for a local development environment".

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularRegressionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableRegression(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      ArrayList<Transformation> transformations = new ArrayList<>();
      transformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("TIMESTAMP_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("DATETIME_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE"))
              .build());
      transformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE"))
              .build());

      AutoMlTablesInputs trainingTaskInputs =
          AutoMlTablesInputs.newBuilder()
              .addAllTransformations(transformations)
              .setTargetColumn(targetColumn)
              .setPredictionType("regression")
              .setTrainBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              // supported regression optimisation objectives: minimize-rmse,
              // minimize-mae, minimize-rmsle
              .setOptimizationObjective("minimize-rmse")
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Regression Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLabels: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Before trying this sample, follow the Node.js setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Node.js API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see "Set up authentication for a local development environment".

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesRegression() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'BOOLEAN_2unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'TIMESTAMP_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'DATE_1unique_NULLABLE'}},
    {auto: {column_name: 'TIME_1unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'DATETIME_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'STRUCT_NULLABLE.STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.BOOLEAN_2unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE'}},
  ];

  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    transformations,
    targetColumn,
    predictionType: 'regression',
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-rmse',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline tabular regression response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesRegression();

Python

To learn how to install or update the Vertex AI SDK for Python, see "Install the Vertex AI SDK for Python". For more information, see the Python API reference documentation.

from google.cloud import aiplatform


def create_training_pipeline_tabular_regression_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = "my_model",
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_regression_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="regression"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_regression_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Control the data split using REST

You can control how your training data is split between the training, validation, and test sets. When you use the Vertex AI API, use a Split object to determine your data split. A Split object can be included in the inputDataConfig object as one of several object types, each of which provides a different way to split the training data.

The methods you can use to split your data depend on your data type:

  • FractionSplit

    • TRAINING_FRACTION: The fraction of the training data to use for the training set.
    • VALIDATION_FRACTION: The fraction of the training data to use for the validation set.
    • TEST_FRACTION: The fraction of the training data to use for the test set.

    If you specify any of the fractions, you must specify all of them. The fractions must add up to 1.0. Learn more.

    "fractionSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
    },

  • PredefinedSplit

    • DATA_SPLIT_COLUMN: The column containing the data split values (TRAIN, VALIDATION, or TEST).

    Manually specify the data split for each row by using a split column; see the sketch after this list. Learn more.

    "predefinedSplit": {
      "key": DATA_SPLIT_COLUMN
    },
  • TimestampSplit

    • TRAINING_FRACTION: The fraction of the training data to use for the training set. The default value is 0.80.
    • VALIDATION_FRACTION: The fraction of the training data to use for the validation set. The default value is 0.10.
    • TEST_FRACTION: The fraction of the training data to use for the test set. The default value is 0.10.
    • TIME_COLUMN: The column containing the timestamps.

    If you specify any of the fractions, you must specify all of them. The fractions must add up to 1.0. Learn more.

    "timestampSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION,
      "key": TIME_COLUMN
    }

Optimization objectives for classification or regression models

When you train a model, Vertex AI selects a default optimization objective based on your model type and the data type of your target column. If the default is not the best fit for your problem, you can set the objective explicitly; a sketch follows the tables below.

Classification models are best for:

Optimization objective | API value | Use this objective if you want to...
AUC ROC | maximize-au-roc | Maximize the area under the receiver operating characteristic (ROC) curve. Distinguishes between classes. Default value for binary classification.
Log loss | minimize-log-loss | Keep inference probabilities as accurate as possible. The only supported objective for multi-class classification.
AUC PR | maximize-au-prc | Maximize the area under the precision-recall curve. Optimizes results for inferences on the less common class.
Precision at recall | maximize-precision-at-recall | Maximize precision at a specific recall value.
Recall at precision | maximize-recall-at-precision | Maximize recall at a specific precision value.
Regression models are best for:

Optimization objective | API value | Use this objective if you want to...
RMSE | minimize-rmse | Minimize root-mean-squared error (RMSE). Captures more extreme values accurately. Default value.
MAE | minimize-mae | Minimize mean-absolute error (MAE). Views extreme values as outliers with less impact on the model.
RMSLE | minimize-rmsle | Minimize root-mean-squared log error (RMSLE). Penalizes error on relative size rather than absolute value. Especially helpful when both predicted and actual values can be quite large.
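With the Vertex AI SDK for Python, the objective is passed when you construct the training job. A minimal sketch, assuming a binary classification problem where the less common class matters most (the display name is a placeholder):

from google.cloud import aiplatform

# Override the default objective for binary classification
# (maximize-au-roc) to favor inferences on the less common class.
job = aiplatform.AutoMLTabularTrainingJob(
    display_name="my-training-job",
    optimization_prediction_type="classification",
    optimization_objective="maximize-au-prc",
)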

What's next