Text Prediction - Heterogeneous Data Types

In many applications, text appears alongside other common data types, such as the numerical and categorical columns typical of tabular data. The TextPrediction task in AutoGluon can train a single neural network that jointly operates on multiple feature types, including text, categorical, and numerical columns. Here we again use the Semantic Textual Similarity (STS) dataset to illustrate this functionality.

import numpy as np
import warnings
warnings.filterwarnings('ignore')
np.random.seed(123)

Load Data

from autogluon.utils.tabular.utils.loaders import load_pd

train_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet')
dev_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet')
train_data.head(10)
Loaded data from: https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet | Columns = 4 / 4 | Rows = 5749 -> 5749
Loaded data from: https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet | Columns = 4 / 4 | Rows = 1500 -> 1500
   sentence1                                      sentence2                                          genre          score
0  A plane is taking off.                         An air plane is taking off.                        main-captions  5.00
1  A man is playing a large flute.                A man is playing a flute.                          main-captions  3.80
2  A man is spreading shreded cheese on a pizza.  A man is spreading shredded cheese on an uncoo...  main-captions  3.80
3  Three men are playing chess.                   Two men are playing chess.                         main-captions  2.60
4  A man is playing the cello.                    A man seated is playing the cello.                 main-captions  4.25
5  Some men are fighting.                         Two men are fighting.                              main-captions  4.25
6  A man is smoking.                              A man is skating.                                  main-captions  0.50
7  The man is playing the piano.                  The man is playing the guitar.                     main-captions  1.60
8  A man is playing on a guitar and singing.      A woman is playing an acoustic guitar and sing...  main-captions  2.20
9  A person is throwing a cat on to the ceiling.  A person throws a cat on the ceiling.              main-captions  5.00

Note that the STS dataset contains two text fields (sentence1 and sentence2), one categorical field (genre), and one numerical field (score). Let’s try to predict the score from the other features: sentence1, sentence2, and genre.
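To make the three column kinds concrete, here is a minimal sketch that assigns a rough type to each column of a few rows copied from the table above. The function `guess_column_type` and its rule (numbers are numerical, strings containing spaces are text, other strings are categorical) are hypothetical and for illustration only; AutoGluon's own column-type inference may work differently.

```python
# A few rows copied from the STS preview shown earlier.
rows = [
    {"sentence1": "A plane is taking off.",
     "sentence2": "An air plane is taking off.",
     "genre": "main-captions", "score": 5.00},
    {"sentence1": "A man is playing a large flute.",
     "sentence2": "A man is playing a flute.",
     "genre": "main-captions", "score": 3.80},
    {"sentence1": "Three men are playing chess.",
     "sentence2": "Two men are playing chess.",
     "genre": "main-captions", "score": 2.60},
    {"sentence1": "A man is smoking.",
     "sentence2": "A man is skating.",
     "genre": "main-captions", "score": 0.50},
]


def guess_column_type(values):
    """Hypothetical heuristic, NOT AutoGluon's actual logic."""
    # All values numeric (bool excluded) -> numerical column.
    if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
        return "numerical"
    # Strings containing whitespace look like free text.
    if any(" " in v for v in values):
        return "text"
    # Remaining strings are treated as category labels.
    return "categorical"


column_types = {
    name: guess_column_type([row[name] for row in rows]) for name in rows[0]
}
print(column_types)
```

With these rows, sentence1 and sentence2 come out as text, genre as categorical, and score as numerical, matching the description above.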

import autogluon as ag
from autogluon import TextPrediction as task

predictor_score = task.fit(train_data, label='score',
                           time_limits=60, ngpus_per_trial=1, seed=123,
                           output_directory='./ag_sts_mixed_score')
NumPy-shape semantics has been activated in your code. This is required for creating and manipulating scalar and zero-size tensors, which were not supported in MXNet before, as in the official NumPy library. Please DO NOT manually deactivate this semantics while using mxnet.numpy and mxnet.numpy_extension modules.
2020-09-19 08:20:20,603 - root - INFO - All Logs will be saved to ./ag_sts_mixed_score/ag_text_prediction.log
2020-09-19 08:20:20,626 - root - INFO - Train Dataset:
2020-09-19 08:20:20,627 - root - INFO - Columns:

- Text(
   name="sentence1"
   #total/missing=4599/0
   length, min/avg/max=16/57.73/340
)
- Text(
   name="sentence2"
   #total/missing=4599/0
   length, min/avg/max=15/57.55/311
)
- Categorical(
   name="genre"
   #total/missing=4599/0
   num_class (total/non_special)=4/3
   categories=['main-captions', 'main-forums', 'main-news']
   freq=[1617, 360, 2622]
)
- Numerical(
   name="score"
   #total/missing=4599/0
   shape=()
)


2020-09-19 08:20:20,627 - root - INFO - Tuning Dataset:
2020-09-19 08:20:20,628 - root - INFO - Columns:

- Text(
   name="sentence1"
   #total/missing=1150/0
   length, min/avg/max=17/57.64/367
)
- Text(
   name="sentence2"
   #total/missing=1150/0
   length, min/avg/max=17/57.45/265
)
- Categorical(
   name="genre"
   #total/missing=1150/0
   num_class (total/non_special)=4/3
   categories=['main-captions', 'main-forums', 'main-news']
   freq=[383, 90, 677]
)
- Numerical(
   name="score"
   #total/missing=1150/0
   shape=()
)


2020-09-19 08:20:20,628 - root - INFO - Label columns=['score'], Feature columns=['sentence1', 'sentence2', 'genre'], Problem types=['regression'], Label shapes=[()]
2020-09-19 08:20:20,629 - root - INFO - Eval Metric=mse, Stop Metric=mse, Log Metrics=['mse', 'rmse', 'mae']
 55%|█████▍    | 314/576 [01:01<00:51,  5.07it/s]
score = predictor_score.evaluate(dev_data, metrics='spearmanr')
print('Spearman Correlation=', score['spearmanr'])
Spearman Correlation= 0.8531228050726722
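For readers unfamiliar with the metric, the following is a minimal pure-Python sketch of Spearman correlation: rank both sequences (ties share their average rank), then take the Pearson correlation of the ranks. It is not the code AutoGluon runs internally. The true scores are taken from the data preview above, while the `predicted` values are made up for illustration.

```python
def average_ranks(values):
    """Return 1-based ranks; tied values share their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j to cover the run of values tied with position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def pearson(x, y):
    """Pearson correlation of two equal-length sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    std_x = sum((a - mean_x) ** 2 for a in x) ** 0.5
    std_y = sum((b - mean_y) ** 2 for b in y) ** 0.5
    return cov / (std_x * std_y)


def spearman(x, y):
    """Spearman correlation = Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))


true_scores = [5.00, 3.80, 2.60, 0.50, 1.60]  # from the preview table
predicted = [4.1, 4.4, 2.0, 1.0, 1.5]         # hypothetical model outputs
rho = spearman(true_scores, predicted)
print('Spearman Correlation=', rho)
```

Because only the ordering matters, a model whose predictions are on a different scale from the labels can still score well on this metric, as long as it ranks the sentence pairs in the right order.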

We can also train a model that predicts the genre, using the other columns (sentence1, sentence2, and score) as features.

predictor_genre = task.fit(train_data, label='genre',
                           time_limits=60, ngpus_per_trial=1, seed=123,
                           output_directory='./ag_sts_mixed_genre')
2020-09-19 08:22:57,661 - root - INFO - All Logs will be saved to ./ag_sts_mixed_genre/ag_text_prediction.log
2020-09-19 08:22:57,687 - root - INFO - Train Dataset:
2020-09-19 08:22:57,687 - root - INFO - Columns:

- Text(
   name="sentence1"
   #total/missing=4599/0
   length, min/avg/max=16/57.69/367
)
- Text(
   name="sentence2"
   #total/missing=4599/0
   length, min/avg/max=15/57.54/311
)
- Categorical(
   name="genre"
   #total/missing=4599/0
   num_class (total/non_special)=3/3
   categories=['main-captions', 'main-forums', 'main-news']
   freq=[1601, 365, 2633]
)
- Numerical(
   name="score"
   #total/missing=4599/0
   shape=()
)


2020-09-19 08:22:57,688 - root - INFO - Tuning Dataset:
2020-09-19 08:22:57,688 - root - INFO - Columns:

- Text(
   name="sentence1"
   #total/missing=1150/0
   length, min/avg/max=16/57.80/267
)
- Text(
   name="sentence2"
   #total/missing=1150/0
   length, min/avg/max=17/57.52/237
)
- Categorical(
   name="genre"
   #total/missing=1150/0
   num_class (total/non_special)=3/3
   categories=['main-captions', 'main-forums', 'main-news']
   freq=[399, 85, 666]
)
- Numerical(
   name="score"
   #total/missing=1150/0
   shape=()
)


2020-09-19 08:22:57,689 - root - INFO - Label columns=['genre'], Feature columns=['sentence1', 'sentence2', 'score'], Problem types=['classification'], Label shapes=[3]
2020-09-19 08:22:57,689 - root - INFO - Eval Metric=acc, Stop Metric=acc, Log Metrics=['acc', 'nll']
 55%|█████▍    | 314/576 [01:00<00:50,  5.15it/s]
score = predictor_genre.evaluate(dev_data, metrics='acc')
print('Genre-prediction Accuracy = {}%'.format(score['acc'] * 100))
Genre-prediction Accuracy = 86.6%
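The accuracy reported here is the fraction of examples whose predicted genre matches the true genre. As a minimal sketch of that computation (not AutoGluon's internal code), using the three genre names from the log and made-up labels and predictions:

```python
def accuracy(true_labels, predicted_labels):
    """Fraction of positions where the prediction equals the true label."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("label lists must have the same length")
    matches = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return matches / len(true_labels)


# Hypothetical labels and predictions, for illustration only.
true_genres = ['main-captions', 'main-news', 'main-forums',
               'main-news', 'main-captions']
predicted_genres = ['main-captions', 'main-news', 'main-news',
                    'main-news', 'main-captions']

acc = accuracy(true_genres, predicted_genres)
print('Genre-prediction Accuracy = {}%'.format(acc * 100))
```

In this toy example four of the five predictions are correct, so the accuracy is 0.8.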