Mirror of https://github.com/twitter/the-algorithm.git (synced 2025-06-10)
Twitter Recommendation Algorithm

Please note we have force-pushed a new initial commit in order to remove some publicly available Twitter user information. This process may need to be repeated in the future.
trust_and_safety_models/nsfw/nsfw_media.py (new file, 466 lines)
@@ -0,0 +1,466 @@
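# Notebook-style training script for the NSFW media model: fits a small Keras
# MLP (tuned with kerastuner) on 256-dimensional media embeddings (CLIP-based,
# judging by the model and file names below), evaluates it on a held-out test
# set plus a "sensitive preview" (sens_prev) test set, and uploads the
# resulting SavedModel and tuner state to GCS.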
import glob
import math
import os
import random

import kerastuner as kt  # legacy package name; published these days as keras_tuner
import numpy as np
import pandas as pd
import sklearn.metrics
import tensorflow as tf

from tqdm import tqdm
from matplotlib import pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from google.cloud import storage

physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)

# Pin all work to the second GPU.
tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU')
print(tf.config.get_visible_devices('GPU'))

def decode_fn_embedding(example_proto):
  feature_description = {
      "embedding": tf.io.FixedLenFeature([256], dtype=tf.float32),
      "labels": tf.io.FixedLenFeature([], dtype=tf.int64),
  }

  example = tf.io.parse_single_example(
      example_proto,
      feature_description
  )

  return example

def preprocess_embedding_example(example_dict, positive_label=1, features_as_dict=False):
  labels = example_dict["labels"]
  label = tf.math.reduce_any(labels == positive_label)
  label = tf.cast(label, tf.int32)
  embedding = example_dict["embedding"]

  if features_as_dict:
    features = {"embedding": embedding}
  else:
    features = embedding

  return features, label

input_root = ...
sens_prev_input_root = ...

use_sens_prev_data = True
has_validation_data = True
positive_label = 1

train_batch_size = 256
test_batch_size = 256
validation_batch_size = 256

do_resample = False

def class_func(features, label):
  return label

resample_fn = tf.data.experimental.rejection_resample(
    class_func, target_dist=[0.5, 0.5], seed=0
)

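# Note: the rejection_resample transformation yields (class, original_element)
# pairs, which is why the do_resample branch below maps `lambda _, b: b` to
# drop the class value and recover the (features, label) elements.
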
train_glob = f"{input_root}/train/tfrecord/*.tfrecord"
train_files = tf.io.gfile.glob(train_glob)

if use_sens_prev_data:
  train_sens_prev_glob = f"{sens_prev_input_root}/train/tfrecord/*.tfrecord"
  train_sens_prev_files = tf.io.gfile.glob(train_sens_prev_glob)
  train_files = train_files + train_sens_prev_files

random.shuffle(train_files)

if not len(train_files):
  raise ValueError(f"Did not find any train files matching {train_glob}")

test_glob = f"{input_root}/test/tfrecord/*.tfrecord"
test_files = tf.io.gfile.glob(test_glob)

if not len(test_files):
  raise ValueError(f"Did not find any test files matching {test_glob}")

test_ds = tf.data.TFRecordDataset(test_files).map(decode_fn_embedding)
test_ds = test_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label)).batch(batch_size=test_batch_size)

if use_sens_prev_data:
  test_sens_prev_glob = f"{sens_prev_input_root}/test/tfrecord/*.tfrecord"
  test_sens_prev_files = tf.io.gfile.glob(test_sens_prev_glob)

  if not len(test_sens_prev_files):
    raise ValueError(f"Did not find any test files matching {test_sens_prev_glob}")

  test_sens_prev_ds = tf.data.TFRecordDataset(test_sens_prev_files).map(decode_fn_embedding)
  test_sens_prev_ds = test_sens_prev_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label)).batch(batch_size=test_batch_size)

train_ds = tf.data.TFRecordDataset(train_files).map(decode_fn_embedding)
train_ds = train_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label))

if do_resample:
  train_ds = train_ds.apply(resample_fn).map(lambda _, b: b)

train_ds = train_ds.batch(batch_size=train_batch_size).shuffle(buffer_size=10)
train_ds = train_ds.repeat()

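# train_ds repeats indefinitely; epoch length is set by steps_per_epoch in the
# tuner.search and model.fit calls below. Also note the shuffle above operates
# on already-batched elements with a small buffer, so it only lightly reorders
# whole batches.
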
if has_validation_data:
  eval_glob = f"{input_root}/validation/tfrecord/*.tfrecord"
  eval_files = tf.io.gfile.glob(eval_glob)

  if use_sens_prev_data:
    eval_sens_prev_glob = f"{sens_prev_input_root}/validation/tfrecord/*.tfrecord"
    eval_sens_prev_files = tf.io.gfile.glob(eval_sens_prev_glob)
    eval_files = eval_files + eval_sens_prev_files

  if not len(eval_files):
    raise ValueError(f"Did not find any validation files matching {eval_glob}")

  eval_ds = tf.data.TFRecordDataset(eval_files).map(decode_fn_embedding)
  eval_ds = eval_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label)).batch(batch_size=validation_batch_size)
else:
  # No dedicated validation split: fall back to the test files for validation.
  eval_ds = tf.data.TFRecordDataset(test_files).map(decode_fn_embedding)
  eval_ds = eval_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label)).batch(batch_size=validation_batch_size)

# Sanity check: count positives in the raw training records.
check_ds = tf.data.TFRecordDataset(train_files).map(decode_fn_embedding)
cnt = 0
pos_cnt = 0
for example in tqdm(check_ds):
  label = example['labels']
  if label == positive_label:
    pos_cnt += 1
  cnt += 1
print(f'{cnt} train entries with {pos_cnt} positive')

metrics = []

metrics.append(
    tf.keras.metrics.PrecisionAtRecall(
        recall=0.9, num_thresholds=200, class_id=None, name=None, dtype=None
    )
)

metrics.append(
    tf.keras.metrics.AUC(
        num_thresholds=200,
        curve="PR",
    )
)

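# With name=None these metrics get their default names ('precision_at_recall'
# and 'auc'), which is what the history-plotting keys below rely on.
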
def build_model(hp):
  model = Sequential()

  optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-08,
      amsgrad=False,
      name="Adam",
  )

  activation = hp.Choice("activation", ["tanh", "gelu"])
  kernel_initializer = hp.Choice("kernel_initializer", ["he_uniform", "glorot_uniform"])
  for i in range(hp.Int("num_layers", 1, 2)):
    model.add(tf.keras.layers.BatchNormalization())

    units = hp.Int("units", min_value=128, max_value=256, step=128)

    if i == 0:
      model.add(
          Dense(
              units=units,
              activation=activation,
              kernel_initializer=kernel_initializer,
              input_shape=(256,)  # per-example embedding size; the batch dim is implicit
          )
      )
    else:
      model.add(
          Dense(
              units=units,
              activation=activation,
              kernel_initializer=kernel_initializer,
          )
      )

  model.add(Dense(1, activation='sigmoid', kernel_initializer=kernel_initializer))
  model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=metrics)

  return model

tuner = kt.tuners.BayesianOptimization(
    build_model,
    objective=kt.Objective('val_loss', direction="min"),
    max_trials=30,
    directory='tuner_dir',
    project_name='with_twitter_clip')

callbacks = [tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=5, verbose=0,
    mode='auto', baseline=None, restore_best_weights=True
)]

steps_per_epoch = 400
tuner.search(train_ds,
             epochs=100,
             batch_size=256,
             steps_per_epoch=steps_per_epoch,
             verbose=2,
             validation_data=eval_ds,
             callbacks=callbacks)

tuner.results_summary()
models = tuner.get_best_models(num_models=2)
best_model = models[0]

best_model.build(input_shape=(None, 256))
best_model.summary()

print(tuner.get_best_hyperparameters()[0].values)

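# Continue training the best model from the search: recompile with a fresh
# optimizer and fit again with a longer early-stopping patience than the
# search used (weights start from the best trial's checkpoint).
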
optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.001,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-08,
    amsgrad=False,
    name="Adam",
)
best_model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=metrics)
best_model.summary()

callbacks = [tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=10, verbose=0,
    mode='auto', baseline=None, restore_best_weights=True
)]
history = best_model.fit(train_ds, epochs=100, validation_data=eval_ds, steps_per_epoch=steps_per_epoch, callbacks=callbacks)

model_name = 'twitter_hypertuned'
model_path = f'models/nsfw_Keras_with_CLIP_{model_name}'
tf.keras.models.save_model(best_model, model_path)

def copy_local_directory_to_gcs(local_path, bucket, gcs_path):
  """Recursively copy a directory of files to GCS.

  local_path should be a directory and not have a trailing slash.
  """
  assert os.path.isdir(local_path)
  for local_file in glob.glob(local_path + '/**'):
    if not os.path.isfile(local_file):
      dir_name = os.path.basename(os.path.normpath(local_file))
      copy_local_directory_to_gcs(local_file, bucket, f"{gcs_path}/{dir_name}")
    else:
      remote_path = os.path.join(gcs_path, local_file[1 + len(local_path):])
      blob = bucket.blob(remote_path)
      blob.upload_from_filename(local_file)

client = storage.Client(project=...)
bucket = client.get_bucket(...)
copy_local_directory_to_gcs(model_path, bucket, model_path)
copy_local_directory_to_gcs('tuner_dir', bucket, 'tuner_dir')
loaded_model = tf.keras.models.load_model(model_path)

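# Quick smoke test of the reloaded model. The random embedding below is
# illustrative only (an assumption, not part of the original pipeline).
dummy_embedding = np.random.rand(1, 256).astype(np.float32)
print(loaded_model.predict(dummy_embedding))  # sigmoid probability of the positive class
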
print(history.history.keys())

plt.figure(figsize=(20, 5))

plt.subplot(1, 3, 1)
plt.plot(history.history['auc'])
plt.plot(history.history['val_auc'])
plt.title('model auc')
plt.ylabel('auc')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')

plt.subplot(1, 3, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')

plt.subplot(1, 3, 3)
plt.plot(history.history['precision_at_recall'])
plt.plot(history.history['val_precision_at_recall'])
plt.title('model precision at 0.9 recall')
plt.ylabel('precision_at_recall')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')

plt.savefig('history_with_twitter_clip.pdf')

test_labels = []
test_preds = []

for batch_features, batch_labels in tqdm(test_ds):
  # Keras models have no predict_proba; the sigmoid output of predict() is
  # already a probability.
  test_preds.extend(loaded_model.predict(batch_features))
  test_labels.extend(batch_labels.numpy())

test_sens_prev_labels = []
test_sens_prev_preds = []

for batch_features, batch_labels in tqdm(test_sens_prev_ds):
  test_sens_prev_preds.extend(loaded_model.predict(batch_features))
  test_sens_prev_labels.extend(batch_labels.numpy())

n_test_pos = 0
n_test_neg = 0
n_test = 0

for label in test_labels:
  n_test += 1
  if label == 1:
    n_test_pos += 1
  else:
    n_test_neg += 1

print(f'n_test = {n_test}, n_pos = {n_test_pos}, n_neg = {n_test_neg}')

n_test_sens_prev_pos = 0
n_test_sens_prev_neg = 0
n_test_sens_prev = 0

for label in test_sens_prev_labels:
  n_test_sens_prev += 1
  if label == 1:
    n_test_sens_prev_pos += 1
  else:
    n_test_sens_prev_neg += 1

print(f'n_test_sens_prev = {n_test_sens_prev}, n_pos_sens_prev = {n_test_sens_prev_pos}, n_neg_sens_prev = {n_test_sens_prev_neg}')

test_weights = np.ones(np.asarray(test_preds).shape)

test_labels = np.asarray(test_labels)
test_preds = np.asarray(test_preds)
test_weights = np.asarray(test_weights)

# precision_recall_curve returns (precision, recall, thresholds).
pr = sklearn.metrics.precision_recall_curve(
    test_labels,
    test_preds)

auc = sklearn.metrics.auc(pr[1], pr[0])
plt.plot(pr[1], pr[0])
plt.title("nsfw (MU test set)")

test_sens_prev_weights = np.ones(np.asarray(test_sens_prev_preds).shape)

test_sens_prev_labels = np.asarray(test_sens_prev_labels)
test_sens_prev_preds = np.asarray(test_sens_prev_preds)
test_sens_prev_weights = np.asarray(test_sens_prev_weights)

pr_sens_prev = sklearn.metrics.precision_recall_curve(
    test_sens_prev_labels,
    test_sens_prev_preds)

auc_sens_prev = sklearn.metrics.auc(pr_sens_prev[1], pr_sens_prev[0])
plt.plot(pr_sens_prev[1], pr_sens_prev[0])
plt.title("nsfw (sens prev test set)")

df = pd.DataFrame(
    {
        "label": test_labels.squeeze(),
        "preds_keras": np.asarray(test_preds).flatten(),
    })
plt.figure(figsize=(15, 10))
df["preds_keras"].hist()
plt.title("Keras predictions", size=20)
plt.xlabel('score')
plt.ylabel("freq")

plt.figure(figsize=(20, 5))
plt.subplot(1, 3, 1)

# thresholds (pr[2]) has one fewer entry than precision/recall, hence the [0:-1].
plt.plot(pr[2], pr[0][0:-1])
plt.xlabel("threshold")
plt.ylabel("precision")

plt.subplot(1, 3, 2)

plt.plot(pr[2], pr[1][0:-1])
plt.xlabel("threshold")
plt.ylabel("recall")
plt.title("Keras", size=20)

plt.subplot(1, 3, 3)

plt.plot(pr[1], pr[0])
plt.xlabel("recall")
plt.ylabel("precision")

plt.savefig('with_twitter_clip.pdf')

def get_point_for_recall(recall_value, recall, precision):
  idx = np.argmin(np.abs(recall - recall_value))
  return (recall[idx], precision[idx])

def get_point_for_precision(precision_value, recall, precision):
  idx = np.argmin(np.abs(precision - precision_value))
  return (recall[idx], precision[idx])

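# Nearest-point lookup on the PR curve: e.g. get_point_for_recall(0.9, recall,
# precision) returns the curve point whose recall is closest to 0.9.
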
precision, recall, thresholds = pr

auc_precision_recall = sklearn.metrics.auc(recall, precision)

print(auc_precision_recall)

plt.figure(figsize=(15, 10))
plt.plot(recall, precision)

plt.xlabel("recall")
plt.ylabel("precision")

ptAt50 = get_point_for_recall(0.5, recall, precision)
print(ptAt50)
plt.plot([ptAt50[0], ptAt50[0]], [0, ptAt50[1]], 'r')
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')

ptAt90 = get_point_for_recall(0.9, recall, precision)
print(ptAt90)
plt.plot([ptAt90[0], ptAt90[0]], [0, ptAt90[1]], 'b')
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')

ptAt50fmt = "%.4f" % ptAt50[1]
ptAt90fmt = "%.4f" % ptAt90[1]
aucFmt = "%.4f" % auc_precision_recall
plt.title(
    f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test} ({n_test_pos} pos)",
    size=20
)
plt.subplots_adjust(top=0.72)
plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf')

precision, recall, thresholds = pr_sens_prev

auc_precision_recall = sklearn.metrics.auc(recall, precision)
print(auc_precision_recall)
plt.figure(figsize=(15, 10))

plt.plot(recall, precision)

plt.xlabel("recall")
plt.ylabel("precision")

ptAt50 = get_point_for_recall(0.5, recall, precision)
print(ptAt50)
plt.plot([ptAt50[0], ptAt50[0]], [0, ptAt50[1]], 'r')
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')

ptAt90 = get_point_for_recall(0.9, recall, precision)
print(ptAt90)
plt.plot([ptAt90[0], ptAt90[0]], [0, ptAt90[1]], 'b')
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')

ptAt50fmt = "%.4f" % ptAt50[1]
ptAt90fmt = "%.4f" % ptAt90[1]
aucFmt = "%.4f" % auc_precision_recall
plt.title(
    f"Keras (nsfw sens prev test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test_sens_prev} ({n_test_sens_prev_pos} pos)",
    size=20
)
plt.subplots_adjust(top=0.72)
plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_sens_prev_test.pdf')

trust_and_safety_models/nsfw/nsfw_text.py (new file, 152 lines)
@@ -0,0 +1,152 @@
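# Notebook-style training script for the NSFW text model: fine-tunes a
# Twitter-internal BERT text encoder (twitter.cuad TextEncoder) on cleaned
# tweet text with a 2-way softmax head, then evaluates precision/recall on a
# held-out validation split.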
from datetime import datetime
from functools import reduce
import os
import re

import pandas as pd
from sklearn.metrics import average_precision_score, classification_report, precision_recall_curve, PrecisionRecallDisplay
from sklearn.model_selection import train_test_split
import tensorflow as tf
import matplotlib.pyplot as plt

from twitter.cuad.representation.models.optimization import create_optimizer
from twitter.cuad.representation.models.text_encoder import TextEncoder

pd.set_option('display.max_colwidth', None)
pd.set_option('display.expand_frame_repr', False)

print(tf.__version__)
print(tf.config.list_physical_devices())

log_path = os.path.join('pnsfwtweettext_model_runs', datetime.now().strftime('%Y-%m-%d_%H.%M.%S'))

tweet_text_feature = 'text'

params = {
    'batch_size': 32,
    'max_seq_lengths': 256,
    'model_type': 'twitter_bert_base_en_uncased_augmented_mlm',
    'trainable_text_encoder': True,
    'lr': 5e-5,
    'epochs': 10,
}

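# model_type and max_seq_lengths are options of the Twitter-internal
# TextEncoder wrapper used below; trainable_text_encoder=True fine-tunes the
# whole BERT encoder rather than only the classification head.
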
REGEX_PATTERNS = [
    r'^RT @[A-Za-z0-9_]+: ',
    r'@[A-Za-z0-9_]+',
    r'https:\/\/t\.co\/[A-Za-z0-9]{10}',
    r'@\?\?\?\?\?',
]

# Single capture group covering the main emoji blocks; used below to pad
# emoji with spaces so the tokenizer sees them as separate tokens.
EMOJI_PATTERN = re.compile(
    "(["
    "\U0001F1E0-\U0001F1FF"  # regional indicators (flags)
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F700-\U0001F77F"  # alchemical symbols
    "\U0001F780-\U0001F7FF"  # geometric shapes extended
    "\U0001F800-\U0001F8FF"  # supplemental arrows-C
    "\U0001F900-\U0001F9FF"  # supplemental symbols & pictographs
    "\U0001FA00-\U0001FA6F"  # chess symbols
    "\U0001FA70-\U0001FAFF"  # symbols & pictographs extended-A
    "\U00002702-\U000027B0"  # dingbats
    "])"
)

def clean_tweet(text):
  for pattern in REGEX_PATTERNS:
    text = re.sub(pattern, '', text)

  text = re.sub(EMOJI_PATTERN, r' \1 ', text)

  text = re.sub(r'\n', ' ', text)

  return text.strip().lower()

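# Roughly: clean_tweet('RT @user: check https://t.co/AbCdEfGhIj 🔥') drops the
# retweet prefix, the @-handle and the t.co link, pads the emoji with spaces,
# and lowercases, leaving something like 'check 🔥'.

# `df` is assumed to be loaded upstream (not shown in this file) with at least
# `text` and `is_nsfw` columns.
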
df['processed_text'] = df['text'].astype(str).map(clean_tweet)
df.sample(10)

X_train, X_val, y_train, y_val = train_test_split(df[['processed_text']], df['is_nsfw'], test_size=0.1, random_state=1)

def df_to_ds(X, y, shuffle=False):
  # One-hot (depth 2) labels to match the 2-unit softmax head below.
  ds = tf.data.Dataset.from_tensor_slices((
      X.values,
      tf.one_hot(tf.cast(y.values, tf.int32), depth=2, axis=-1)
  ))

  if shuffle:
    ds = ds.shuffle(1000, seed=1, reshuffle_each_iteration=True)

  return ds.map(lambda text, label: ({tweet_text_feature: text}, label)).batch(params['batch_size'])

ds_train = df_to_ds(X_train, y_train, shuffle=True)
ds_val = df_to_ds(X_val, y_val)
print(X_train.values)

inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name=tweet_text_feature)
encoder = TextEncoder(
    max_seq_lengths=params['max_seq_lengths'],
    model_type=params['model_type'],
    trainable=params['trainable_text_encoder'],
    local_preprocessor_path='demo-preprocessor'
)
embedding = encoder([inputs])["pooled_output"]
predictions = tf.keras.layers.Dense(2, activation='softmax')(embedding)
model = tf.keras.models.Model(inputs=inputs, outputs=predictions)

model.summary()

optimizer = create_optimizer(
    params['lr'],
    params['epochs'] * len(ds_train),
    0,
    weight_decay_rate=0.01,
    optimizer_type='adamw'
)
bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
pr_auc = tf.keras.metrics.AUC(curve='PR', num_thresholds=1000, from_logits=False)
model.compile(optimizer=optimizer, loss=bce, metrics=[pr_auc])

callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=1,
        restore_best_weights=True
    ),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(log_path, 'checkpoints', '{epoch:02d}'),
        save_freq='epoch'
    ),
    tf.keras.callbacks.TensorBoard(
        log_dir=os.path.join(log_path, 'scalars'),
        update_freq='batch',
        write_graph=False
    )
]

history = model.fit(
    ds_train,
    epochs=params['epochs'],
    callbacks=callbacks,
    validation_data=ds_val,
    steps_per_epoch=len(ds_train)
)

model.predict(["xxx 🍑"])

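# `apply_model` is not defined anywhere in this file. A minimal sketch of what
# it presumably does (an assumption): score one processed tweet and return the
# softmax probability of the NSFW class (index 1).
def apply_model(text):
  return model.predict([text], verbose=0)[0][1]
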
preds = X_val.processed_text.apply(apply_model)
print(classification_report(y_val, preds >= 0.90, digits=4))

precision, recall, thresholds = precision_recall_curve(y_val, preds)

fig = plt.figure(figsize=(15, 10))
# Recall on x, precision on y, matching the axis labels below.
plt.plot(recall, precision, lw=2)
plt.grid()
plt.xlim(0.2, 1)
plt.ylim(0.3, 1)
plt.xlabel("Recall", size=20)
plt.ylabel("Precision", size=20)

print(average_precision_score(y_val, preds))