Twitter Recommendation Algorithm

Please note that we have force-pushed a new initial commit in order to remove some publicly available Twitter user information. Note that this process may be required again in the future.
twitter-team
2023-03-31 17:36:31 -05:00
commit ef4c5eb65e
5364 changed files with 460239 additions and 0 deletions

twml/twml/layers/__init__.py

@@ -0,0 +1,21 @@
# pylint: disable=wildcard-import
"""
This module contains the ``tf.layers.Layer`` subclasses implemented in twml.
Layers are used to instantiate common subgraphs.
Typically, these layers are used when defining a ``build_graph_fn``
for the ``twml.trainers.Trainer``.
"""
from .batch_prediction_tensor_writer import BatchPredictionTensorWriter # noqa: F401
from .batch_prediction_writer import BatchPredictionWriter # noqa: F401
from .data_record_tensor_writer import DataRecordTensorWriter # noqa: F401
from .full_dense import full_dense, FullDense # noqa: F401
from .full_sparse import full_sparse, FullSparse # noqa: F401
from .isotonic import Isotonic # noqa: F401
from .layer import Layer # noqa: F401
from .mdl import MDL # noqa: F401
from .partition import Partition # noqa: F401
from .percentile_discretizer import PercentileDiscretizer # noqa: F401
from .sequential import Sequential # noqa: F401
from .sparse_max_norm import MaxNorm, sparse_max_norm, SparseMaxNorm # noqa: F401
from .stitch import Stitch # noqa: F401
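
As a rough orientation for the rest of this commit, the functional wrappers re-exported above are typically chained inside a model-building function. The sketch below is illustrative only (the function name and feature sizes are made up), and assumes the twml package and its native ops are installed:

import tensorflow.compat.v1 as tf
from twml.layers import full_sparse, full_dense

def build_logits(sparse_features):
  # sparse input -> sparse affine layer -> dense output head
  hidden = full_sparse(sparse_features, output_size=200, activation=tf.nn.relu)
  return full_dense(hidden, output_size=1, name="logits")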

twml/twml/layers/batch_prediction_tensor_writer.py

@@ -0,0 +1,51 @@
# pylint: disable=no-member, invalid-name
"""
Implementing Writer Layer
"""
from .layer import Layer
import libtwml
class BatchPredictionTensorWriter(Layer):
"""
A layer that packages keys and dense tensors into a BatchPredictionResponse.
  Typically used at the output of an exported model for use in the PredictionEngine
  (that is, in production) when model predictions are dense tensors.
Arguments:
keys:
keys to hashmap
Output:
output:
a BatchPredictionResponse serialized using Thrift into a uint8 tensor.
"""
def __init__(self, keys, **kwargs): # pylint: disable=useless-super-delegation
super(BatchPredictionTensorWriter, self).__init__(**kwargs)
self.keys = keys
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
    Raises NotImplementedError.
"""
raise NotImplementedError
def call(self, values, **kwargs): # pylint: disable=unused-argument, arguments-differ
"""The logic of the layer lives here.
Arguments:
values:
dense tensors corresponding to keys in hashmap
Returns:
The output from the layer
"""
write_op = libtwml.ops.batch_prediction_tensor_response_writer(self.keys, values)
return write_op
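
A hedged usage sketch for the writer layers (this writer and the two that follow share the same call pattern). It assumes the libtwml native ops are available at graph-build time; the key and prediction tensors below are purely illustrative:

import tensorflow.compat.v1 as tf
from twml.layers import BatchPredictionTensorWriter

tf.disable_eager_execution()
keys = tf.constant([0], dtype=tf.int64)                    # response keys (illustrative)
predictions = tf.placeholder(tf.float32, shape=[None, 1])  # dense model outputs
writer = BatchPredictionTensorWriter(keys)
serialized = writer(predictions)  # uint8 tensor holding the Thrift-serialized response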

twml/twml/layers/batch_prediction_writer.py

@@ -0,0 +1,51 @@
# pylint: disable=no-member, invalid-name
"""
Implementing Writer Layer
"""
from .layer import Layer
import libtwml
class BatchPredictionWriter(Layer):
"""
A layer that packages keys and values into a BatchPredictionResponse.
  Typically used at the output of an exported model for use in the PredictionEngine
  (that is, in production).
Arguments:
keys:
keys to hashmap
Output:
output:
a BatchPredictionResponse serialized using Thrift into a uint8 tensor.
"""
def __init__(self, keys, **kwargs): # pylint: disable=useless-super-delegation
super(BatchPredictionWriter, self).__init__(**kwargs)
self.keys = keys
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
    Raises NotImplementedError.
"""
raise NotImplementedError
def call(self, values, **kwargs): # pylint: disable=unused-argument, arguments-differ
"""The logic of the layer lives here.
Arguments:
values:
values corresponding to keys in hashmap
Returns:
The output from the layer
"""
write_op = libtwml.ops.batch_prediction_response_writer(self.keys, values)
return write_op

twml/twml/layers/data_record_tensor_writer.py

@@ -0,0 +1,50 @@
# pylint: disable=no-member, invalid-name
"""
Implementing Writer Layer
"""
from .layer import Layer
import libtwml
class DataRecordTensorWriter(Layer):
"""
A layer that packages keys and dense tensors into a DataRecord.
This layer was initially added to support exporting user embeddings as tensors.
Arguments:
keys:
keys to hashmap
Output:
output:
a DataRecord serialized using Thrift into a uint8 tensor
"""
def __init__(self, keys, **kwargs): # pylint: disable=useless-super-delegation
super(DataRecordTensorWriter, self).__init__(**kwargs)
self.keys = keys
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
Raises NotImplementedError.
"""
raise NotImplementedError
def call(self, values, **kwargs): # pylint: disable=unused-argument, arguments-differ
"""The logic of the layer lives here.
Arguments:
values:
dense tensors corresponding to keys in hashmap
Returns:
The output from the layer
"""
write_op = libtwml.ops.data_record_tensor_writer(self.keys, values)
return write_op

twml/twml/layers/full_dense.py

@@ -0,0 +1,259 @@
# pylint: disable=no-member,arguments-differ, attribute-defined-outside-init
"""
Implementing Full Dense Layer
"""
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import init_ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.base_layer import InputSpec
import tensorflow.compat.v1 as tf
class FullDense(core_layers.Dense):
"""
Densely-connected layer class.
This is wrapping tensorflow.python.layers.core.Dense
This layer implements the operation:
.. code-block:: python
    outputs = activation(matmul(inputs, weight) + bias)
Where ``activation`` is the activation function passed as the ``activation``
argument (if not ``None``), ``weight`` is a weights matrix created by the layer,
and ``bias`` is a bias vector created by the layer.
Arguments:
output_size:
Integer or Long, dimensionality of the output space.
activation:
Activation function (callable). Set it to None to maintain a linear activation.
weight_initializer:
Initializer function for the weight matrix.
bias_initializer:
Initializer function for the bias.
    weight_regularizer:
      Regularizer function for the weight matrix.
      Be sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    bias_regularizer:
      Regularizer function for the bias.
      Be sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
activity_regularizer:
Regularizer function for the output.
weight_constraint:
An optional projection function to be applied to the
weight after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The function
must take as input the unprojected variable and must return the
projected variable (which must have the same shape). Constraints are
not safe to use when doing asynchronous distributed training.
bias_constraint:
An optional projection function to be applied to the
bias after being updated by an `Optimizer`.
trainable:
Boolean, if `True` also add variables to the graph collection
``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable
<https://www.tensorflow.org/versions/master/api_docs/python/tf/Variable>`_).
name:
String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require ``reuse=True`` in such cases.
Properties:
output_size:
Python integer, dimensionality of the output space.
activation:
Activation function (callable).
weight_initializer:
Initializer instance (or name) for the weight matrix.
bias_initializer:
Initializer instance (or name) for the bias.
    weight:
      Weight matrix (TensorFlow variable or tensor).
    bias:
      Bias vector, if applicable (TensorFlow variable or tensor).
    weight_regularizer:
      Regularizer instance for the weight matrix (callable).
    bias_regularizer:
      Regularizer instance for the bias (callable).
    activity_regularizer:
      Regularizer instance for the output (callable).
weight_constraint:
Constraint function for the weight matrix.
bias_constraint:
Constraint function for the bias.
"""
def __init__(self, output_size,
weight_initializer=None,
weight_regularizer=None,
weight_constraint=None,
bias_constraint=None,
num_partitions=None,
**kwargs):
super(FullDense, self).__init__(units=output_size,
kernel_initializer=weight_initializer,
kernel_regularizer=weight_regularizer,
kernel_constraint=weight_constraint,
**kwargs)
self._num_partitions = num_partitions
def build(self, input_shape):
'''
code adapted from TF 1.12 Keras Dense layer:
https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/layers/core.py#L930-L956
'''
input_shape = tensor_shape.TensorShape(input_shape)
if input_shape[-1] is None:
raise ValueError('The last dimension of the inputs to `Dense` '
'should be defined. Found `None`.')
self.input_spec = InputSpec(min_ndim=2,
axes={-1: input_shape[-1]})
partitioner = None
if self._num_partitions:
partitioner = tf.fixed_size_partitioner(self._num_partitions)
self.kernel = self.add_weight(
'kernel',
shape=[input_shape[-1], self.units],
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
dtype=self.dtype,
partitioner=partitioner,
trainable=True)
if self.use_bias:
self.bias = self.add_weight(
'bias',
shape=[self.units, ],
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
dtype=self.dtype,
trainable=True)
else:
self.bias = None
self.built = True
@property
def output_size(self):
"""
Returns output_size
"""
return self.units
@property
def weight(self):
"""
Returns weight
"""
return self.kernel
@property
def weight_regularizer(self):
"""
Returns weight_regularizer
"""
return self.kernel_regularizer
@property
def weight_initializer(self):
"""
Returns weight_initializer
"""
return self.kernel_initializer
@property
def weight_constraint(self):
"""
Returns weight_constraint
"""
return self.kernel_constraint
def full_dense(inputs, output_size,
activation=None,
use_bias=True,
weight_initializer=None,
bias_initializer=init_ops.zeros_initializer(),
weight_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
weight_constraint=None,
bias_constraint=None,
trainable=True,
name=None,
num_partitions=None,
reuse=None):
"""Functional interface for the densely-connected layer.
This layer implements the operation:
  `outputs = activation(matmul(inputs, weight) + bias)`
Where `activation` is the activation function passed as the `activation`
argument (if not `None`), `weight` is a weights matrix created by the layer,
and `bias` is a bias vector created by the layer
(only if `use_bias` is `True`).
  Arguments:
    inputs: Tensor input.
    output_size: Integer or Long, dimensionality of the output space.
    activation: Activation function (callable). Set it to None to maintain a
      linear activation.
use_bias: Boolean, whether the layer uses a bias.
weight_initializer: Initializer function for the weight matrix.
If `None` (default), weights are initialized using the default
initializer used by `tf.get_variable`.
bias_initializer:
Initializer function for the bias.
    weight_regularizer:
      Regularizer function for the weight matrix.
      Be sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    bias_regularizer:
      Regularizer function for the bias.
      Be sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
activity_regularizer:
Regularizer function for the output.
weight_constraint:
An optional projection function to be applied to the
weight after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The function
must take as input the unprojected variable and must return the
projected variable (which must have the same shape). Constraints are
not safe to use when doing asynchronous distributed training.
bias_constraint:
An optional projection function to be applied to the
bias after being updated by an `Optimizer`.
trainable:
Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
name:
String, the name of the layer.
reuse:
Boolean, whether to reuse the weights of a previous layer
by the same name.
  Returns:
    Output tensor with the same shape as `inputs` except the last dimension is of
    size `output_size`.
Raises:
ValueError: if eager execution is enabled.
"""
layer = FullDense(output_size,
activation=activation,
use_bias=use_bias,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
weight_regularizer=weight_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
weight_constraint=weight_constraint,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
dtype=inputs.dtype.base_dtype,
num_partitions=num_partitions,
_scope=name,
_reuse=reuse)
return layer.apply(inputs)
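
A minimal usage sketch for the functional interface defined above (TF1 graph mode; the shapes and layer names are illustrative):

import tensorflow.compat.v1 as tf
from twml.layers import full_dense

tf.disable_eager_execution()
inputs = tf.placeholder(tf.float32, shape=[None, 128])
hidden = full_dense(inputs, output_size=64, activation=tf.nn.relu, name="hidden")
logits = full_dense(hidden, output_size=1, name="logits")
# If weight_regularizer/bias_regularizer are passed, remember to add
# tf.losses.get_regularization_loss() to the training loss.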

twml/twml/layers/full_sparse.py

@@ -0,0 +1,370 @@
# pylint: disable=no-member, arguments-differ, attribute-defined-outside-init, unused-argument
"""
Implementing Full Sparse Layer
"""
import math
from twitter.deepbird.sparse import sparse_dense_matmul
from .layer import Layer
import tensorflow.compat.v1 as tf
import twml
class FullSparse(Layer):
"""Fully-sparse layer class.
This layer implements the operation:
.. code-block:: python
    outputs = activation(matmul(inputs, weight) + bias)
Arguments:
output_size:
Long or Integer, dimensionality of the output space.
input_size:
The number of input units. (Deprecated)
weight_initializer:
Initializer function for the weight matrix.
This argument defaults to zeros_initializer().
This is valid when the FullSparse is the first layer of
parameters but should be changed otherwise.
    weight_regularizer:
      Regularizer function for the weight matrix.
      Be sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    bias_regularizer:
      Regularizer function for the bias.
      Be sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
activation:
Activation function (callable). Set it to None to maintain a linear activation.
bias_initializer:
Initializer function for the bias.
This argument defaults to tf.constant_initializer(1/output_size)
trainable:
Boolean, if `True` also add variables to the graph collection
``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable
<https://www.tensorflow.org/versions/master/api_docs/python/tf/Variable>`_).
name:
String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require ``reuse=True`` in such cases.
    use_sparse_grads:
      Boolean. If `True`, do the sparse matmul with `embedding_lookup_sparse`, which makes
      the gradients to the weight matrix sparse in the backward pass. This can lead to a
      non-trivial speedup at training time when input_size is large and the optimizer handles
      sparse gradients correctly (e.g. with SGD or LazyAdamOptimizer). If the weight matrix is
      small, it's recommended to set this flag to `False`; for most use cases of FullSparse,
      however, the weight matrix will be large, so it's better to set it to `True`.
num_partitions:
Number of partitions to use for the weight variable. Defaults to 1.
    partition_axis:
      If num_partitions is specified, the partition axis for the weight variable.
      Defaults to 0 (partition by row).
      Must be 0 (row) or 1 (column).
    use_binary_values:
      Assume all non-zero values are 1. Defaults to False.
      This can improve training if used in conjunction with MDL.
      This parameter can also be a list of binary values if the `inputs` passed to `call` is a list.
use_compression:
Default False. Set True to enable data compression techniques for
optimization of network traffic for distributed training.
    use_binary_sparse_dense_matmul:
      Whether to use the binary sparse-dense matmul op. It is only enabled if
      `use_binary_values` is set to True. It should only be used for inference; best practice
      is to set `use_binary_sparse_dense_matmul = not is_training`.
"""
def __init__(self,
output_size,
input_size=None,
weight_initializer=None,
activation=None,
bias_initializer=None,
trainable=True,
name=None,
use_sparse_grads=True,
num_partitions=None,
partition_axis=0,
use_binary_values=False,
bias_regularizer=None,
weight_regularizer=None,
use_compression=False,
use_binary_sparse_dense_matmul=False,
**kwargs):
super(FullSparse, self).__init__(trainable=trainable, name=name, **kwargs)
# TODO - remove input_size warning.
if input_size:
raise ValueError('input_size is deprecated - it is now automatically \
inferred from your input.')
# The bias initialization and weights initialization is set to match v1's implementation.
if bias_initializer is None:
bias_initializer = tf.constant_initializer(1 / output_size)
# Weights initialization is set to 0s. This is safe for full sparse layers because
# you are supposed to learn your embedding from the label.
if weight_initializer is None:
weight_initializer = tf.zeros_initializer()
self.weight_initializer = weight_initializer
self.bias_initializer = bias_initializer
self.output_size = output_size
self.activation = activation
self.use_sparse_grads = use_sparse_grads
self.num_partitions = num_partitions
if partition_axis != 0 and partition_axis != 1:
raise ValueError('partition_axis must be 0 or 1')
self.partition_axis = partition_axis
self.use_binary_values = use_binary_values
self.weight_regularizer = weight_regularizer
self.bias_regularizer = bias_regularizer
self._use_compression = use_compression
self._cast_indices_dtype = tf.int32 if self._use_compression else None
self.use_binary_sparse_dense_matmul = use_binary_sparse_dense_matmul
def _make_weight_var(self, shape, partitioner):
self.weight = self.add_variable(
'weight',
initializer=self.weight_initializer,
regularizer=self.weight_regularizer,
shape=shape,
dtype=self.dtype,
trainable=True,
partitioner=partitioner,
)
def build(self, input_shapes):
"""
creates the ``bias`` and ``weight`` Variables
of shape ``[output_size]`` and ``[input_size, output_size]`` respectively.
"""
if isinstance(input_shapes, (list, tuple)):
input_shape = input_shapes[0]
is_compatible = True
for other_shape in input_shapes[1:]:
is_compatible &= input_shape.is_compatible_with(other_shape)
if not is_compatible:
raise ValueError("Input shapes %s are not compatible." % input_shapes)
else:
input_shape = input_shapes
self.bias = self.add_variable(
'bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
shape=[self.output_size, ],
dtype=self.dtype,
trainable=True
)
partitioner = None
shape = [input_shape[1], self.output_size]
# There is a 2gb limitation for each tensor because of protobuf.
# 2**30 is 1GB. 2 * (2**30) is 2GB.
dtype = tf.as_dtype(self.dtype)
num_partitions = 1 if self.num_partitions is None else self.num_partitions
in_shape = input_shape[1]
out_shape = self.output_size
# when v2 behavior is disabled, in_shape is tf.Dimension. otherwise it is int.
if isinstance(in_shape, tf.Dimension):
in_shape = in_shape.value
if in_shape is None:
raise ValueError("Input tensor should have shape."
" You can set it using twml.util.limit_sparse_tensor_size")
(split_dim, other_dim) = (in_shape, out_shape) if self.partition_axis == 0 else (out_shape, in_shape)
requested_size = math.ceil(float(split_dim) / num_partitions) * other_dim * dtype.size
if (requested_size >= 2**31):
raise ValueError("Weight tensor partitions cannot be larger than 2GB.\n"
"Requested Dimensions(%d, %d) of type %s (%d bytes total) over %d partitions.\n"
"Possible solutions:\n"
"- reduce the params.output_size_bits\n"
"- reduce the output_size of the sparse_layer\n"
"- specify a larger num_partitions argument\n"
"- reduce input_size_bits" %
(in_shape, self.output_size, dtype.name, requested_size, num_partitions))
if self.num_partitions:
partition_axis = int(self.partition_axis)
partitioner = tf.fixed_size_partitioner(self.num_partitions, axis=partition_axis)
else:
# Regular variables do not like it when you pass both constant tensors and shape
if not callable(self.weight_initializer):
shape = None
self._make_weight_var(shape, partitioner)
self.built = True
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
Raises NotImplementedError.
"""
raise NotImplementedError
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""The logic of the layer lives here.
Arguments:
inputs:
A SparseTensor or a list of SparseTensors.
If `inputs` is a list, all tensors must have same `dense_shape`.
Returns:
      - If `inputs` is a `SparseTensor`, then returns `bias + inputs * dense_b`.
      - If `inputs` is a `list[SparseTensor]`, then returns
        `bias + add_n([sp_a * dense_b for sp_a in inputs])`.
"""
if isinstance(inputs, (list, tuple)):
if isinstance(self.use_binary_values, (list, tuple)):
use_binary_values = self.use_binary_values
else:
use_binary_values = [self.use_binary_values] * len(inputs)
num_inputs = len(inputs)
if num_inputs != len(use_binary_values):
raise ValueError("#inputs is %d while #use_binary_values is %d"
% (num_inputs, len(use_binary_values)))
outputs = []
for n in range(num_inputs):
outputs.append(sparse_dense_matmul(inputs[n], self.weight,
self.use_sparse_grads,
use_binary_values[n],
name='sparse_mm_' + str(n),
partition_axis=self.partition_axis,
num_partitions=self.num_partitions,
compress_ids=self._use_compression,
cast_indices_dtype=self._cast_indices_dtype,
use_binary_sparse_dense_matmul=self.use_binary_sparse_dense_matmul))
outputs = tf.accumulate_n(outputs)
else:
if isinstance(self.use_binary_values, (list, tuple)):
raise ValueError("use_binary_values can not be %s when inputs is %s" %
(type(self.use_binary_values), type(inputs)))
outputs = sparse_dense_matmul(inputs, self.weight,
self.use_sparse_grads,
self.use_binary_values,
name='sparse_mm',
partition_axis=self.partition_axis,
num_partitions=self.num_partitions,
compress_ids=self._use_compression,
cast_indices_dtype=self._cast_indices_dtype,
use_binary_sparse_dense_matmul=self.use_binary_sparse_dense_matmul)
if self.bias is not None:
outputs = tf.nn.bias_add(outputs, self.bias)
if self.activation is not None:
return self.activation(outputs) # pylint: disable=not-callable
return outputs
def full_sparse(
inputs, output_size,
input_size=None,
activation=None,
bias_regularizer=None,
weight_regularizer=None,
bias_initializer=None,
weight_initializer=None,
trainable=True,
name=None,
reuse=None,
use_sparse_grads=True,
num_partitions=None,
partition_axis=0,
use_binary_values=False,
use_compression=False):
"""Functional interface for the sparsely-connected layer.
Arguments:
inputs:
A sparse tensor (can be twml.SparseTensor or tf.SparseTensor)
output_size:
Long or Integer, dimensionality of the output space.
weight_initializer:
Initializer function for the weight matrix.
activation:
Activation function (callable). Set it to None to maintain a linear activation.
bias_initializer:
Initializer function for the bias.
    weight_regularizer:
      Regularizer function for the weight matrix.
      Be sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    bias_regularizer:
      Regularizer function for the bias.
      Be sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
trainable:
Boolean, if `True` also add variables to the graph collection
``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable
<https://www.tensorflow.org/versions/master/api_docs/python/tf/Variable>`_).
name:
String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require ``reuse=True`` in such cases.
    use_sparse_grads:
      Boolean. If `True`, do the sparse matmul with `embedding_lookup_sparse`, which makes
      the gradients to the weight matrix sparse in the backward pass. This can lead to a
      non-trivial speedup at training time when input_size is large and the optimizer handles
      sparse gradients correctly (e.g. with SGD or LazyAdamOptimizer). If the weight matrix is
      small, it's recommended to set this flag to `False`; for most use cases of FullSparse,
      however, the weight matrix will be large, so it's better to set it to `True`.
num_partitions:
Number of partitions to use for the weight variable. Defaults to 1.
    partition_axis:
      If num_partitions is specified, the partition axis for the weight variable.
      Defaults to 0 (partition by row).
      Must be 0 (row) or 1 (column).
use_binary_values:
Assume all non zero values are 1. Defaults to False.
This can improve training if used in conjunction with MDL.
use_compression:
Default False. Set True to enable data compression techniques for
optimization of network traffic for distributed training.
Returns:
Outputs a ``tf.Tensor`` of size ``[batch_size x output_size]``.
"""
# TODO - remove input_size warning.
if input_size:
raise ValueError('input_size is deprecated - it is now \
automatically inferred from your input.')
dtype = None
if isinstance(inputs, twml.SparseTensor):
inputs = inputs.to_tf()
dtype = inputs.dtype.base_dtype
if isinstance(inputs, (list, tuple)):
inputs = [inp.to_tf() if isinstance(inp, twml.SparseTensor) else inp for inp in inputs]
dtype = inputs[0].dtype.base_dtype
layer = FullSparse(output_size=output_size,
activation=activation,
trainable=trainable,
name=name,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
weight_regularizer=weight_regularizer,
bias_regularizer=bias_regularizer,
dtype=dtype,
_scope=name,
_reuse=reuse,
use_sparse_grads=use_sparse_grads,
num_partitions=num_partitions,
partition_axis=partition_axis,
use_compression=use_compression,
use_binary_values=use_binary_values)
return layer(inputs)
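
A usage sketch for the functional interface above, fed with a tf.SparseTensor (the feature-space size is illustrative, and the native twml sparse matmul op must be available):

import tensorflow.compat.v1 as tf
from twml.layers import full_sparse

tf.disable_eager_execution()
sparse_input = tf.sparse_placeholder(tf.float32, shape=[None, 1 << 22])
logits = full_sparse(sparse_input,
                     output_size=1,
                     use_sparse_grads=True,   # pair sparse gradients with SGD or LazyAdamOptimizer
                     use_binary_values=False)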

twml/twml/layers/isotonic.py

@@ -0,0 +1,76 @@
# pylint: disable=no-member, invalid-name, attribute-defined-outside-init
"""
Contains the Isotonic Layer
"""
from .layer import Layer
import libtwml
import numpy as np
class Isotonic(Layer):
"""
This layer is created by the IsotonicCalibrator.
  Typically it is used instead of a sigmoid activation on the output unit.
Arguments:
n_unit:
number of input units to the layer (same as number of output units).
n_bin:
number of bins used for isotonic calibration.
      More bins means a more precise isotonic function.
      Fewer bins means a more regularized isotonic function.
xs_input:
A tensor containing the boundaries of the bins.
ys_input:
A tensor containing calibrated values for the corresponding bins.
Output:
output:
A layer containing calibrated probabilities with same shape and size as input.
Expected Sizes:
xs_input, ys_input:
[n_unit, n_bin].
Expected Types:
xs_input, ys_input:
same as input.
"""
def __init__(self, n_unit, n_bin, xs_input=None, ys_input=None, **kwargs):
super(Isotonic, self).__init__(**kwargs)
self._n_unit = n_unit
self._n_bin = n_bin
self.xs_input = np.empty([n_unit, n_bin], dtype=np.float32) if xs_input is None else xs_input
self.ys_input = np.empty([n_unit, n_bin], dtype=np.float32) if ys_input is None else ys_input
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
Raises NotImplementedError.
"""
raise NotImplementedError
def build(self, input_shape): # pylint: disable=unused-argument
"""Creates the variables of the layer."""
self.built = True
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""The logic of the layer lives here.
Arguments:
inputs: input tensor(s).
Returns:
The output from the layer
"""
calibrate_op = libtwml.ops.isotonic_calibration(inputs, self.xs_input, self.ys_input)
return calibrate_op
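
An illustrative sketch of calibrating raw scores with Isotonic. The bin boundaries and calibrated values below are made up (in practice they come from an IsotonicCalibrator), and the libtwml isotonic_calibration op must be available:

import numpy as np
import tensorflow.compat.v1 as tf
from twml.layers import Isotonic

tf.disable_eager_execution()
n_unit, n_bin = 1, 10
xs = np.linspace(0.0, 1.0, n_bin, dtype=np.float32).reshape(n_unit, n_bin)  # bin boundaries
ys = np.sqrt(xs).astype(np.float32)                                         # calibrated value per bin
scores = tf.placeholder(tf.float32, shape=[None, n_unit])
calibrated = Isotonic(n_unit=n_unit, n_bin=n_bin, xs_input=xs, ys_input=ys)(scores)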

twml/twml/layers/layer.py

@@ -0,0 +1,50 @@
# pylint: disable=no-member
"""
Implementing a base layer for twml
"""
import tensorflow.compat.v1 as tf
from tensorflow.python.layers import base
class Layer(base.Layer):
"""
Base Layer implementation for twml.
  Extends `tf.layers.Layer
  <https://www.tensorflow.org/versions/master/api_docs/python/tf/layers/Layer>`_
  from TensorFlow and adds a couple of custom methods.
"""
@property
def init(self):
"""
Return initializer ops. By default returns tf.no_op().
This method is overwritten by classes like twml.layers.MDL, which
uses a HashTable internally, that must be initialized with its own op.
"""
return tf.no_op()
def call(self, inputs, **kwargs):
"""The logic of the layer lives here.
Arguments:
inputs:
input tensor(s).
**kwargs:
additional keyword arguments.
Returns:
Output tensor(s).
"""
raise NotImplementedError
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
    Raises NotImplementedError.
"""
raise NotImplementedError
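
A minimal sketch of subclassing this base Layer: only call() needs to be overridden for a stateless layer, and the init property already returns tf.no_op() unless a subclass (like MDL below) needs its own initialization op. The toy layer here is purely illustrative:

import tensorflow.compat.v1 as tf
from twml.layers import Layer

class ScaleByTwo(Layer):
  """Toy layer that doubles its input (illustrative only)."""

  def call(self, inputs, **kwargs):
    return 2.0 * inputs

  def compute_output_shape(self, input_shape):
    return input_shape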

twml/twml/layers/mdl.py

@@ -0,0 +1,256 @@
# pylint: disable=no-member, attribute-defined-outside-init, too-many-instance-attributes
"""
Implementing MDL Layer
"""
from .layer import Layer
from .partition import Partition
from .stitch import Stitch
import libtwml
import numpy as np
import tensorflow.compat.v1 as tf
import twml
class MDL(Layer): # noqa: T000
"""
MDL layer is constructed by MDLCalibrator after accumulating data
and performing minimum description length (MDL) calibration.
  MDL takes sparse continuous features and converts them to sparse
  binary features. Each binary output feature is associated with an MDL bin.
  Each MDL input feature is converted to n_bin bins.
  Each MDL calibration tries to find bin delimiters such that the number of feature values
  per bin is roughly equal (for each given MDL feature).
  Note that if an input feature is rarely used, its associated output bins/features will be rarely used as well.
"""
def __init__(
self,
n_feature, n_bin, out_bits,
bin_values=None, hash_keys=None, hash_values=None,
bin_ids=None, feature_offsets=None, **kwargs):
"""
Creates a non-initialized `MDL` object.
Before using the table you will have to initialize it. After initialization
the table will be immutable.
Parent class args:
see [tf.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/layers/Layer)
for documentation of parent class arguments.
Required args:
n_feature:
number of unique features accumulated during MDL calibration.
This is the number of features in the hash map.
Used to initialize bin_values, hash_keys, hash_values,
bin_ids, bin_values and feature_offsets.
n_bin:
number of MDL bins used for MDL calibration.
Used to initialize bin_values, hash_keys, hash_values,
bin_ids, bin_values and feature_offsets.
out_bits:
Determines the maximum value for output feature IDs.
The dense_shape of the SparseTensor returned by lookup(x)
will be [x.shape[0], 1 << output_bits].
Optional args:
hash_keys:
contains the features ID that MDL discretizes and knows about.
The hash map (hash_keys->hash_values) is used for two reasons:
1. divide inputs into two feature spaces: MDL vs non-MDL
        2. translate the MDL features into a hash_feature ID that MDL understands.
The hash_map is expected to contain n_feature items.
hash_values:
translates the feature IDs into hash_feature IDs for MDL.
bin_ids:
a 1D Tensor of size n_feature * n_bin + 1 which contains
unique IDs to which the MDL features will be translated to.
For example, tf.Tensor(np.arange(n_feature * n_bin)) would produce
the most efficient output space.
bin_values:
        a 1D Tensor aligned with bin_ids.
        For a given hash_feature ID j, its value bins are indexed between
        `j*n_bin` and `j*n_bin + n_bin-1`.
        As such, bin_ids[j*n_bin+i] is translated from a hash_feature ID of j
        and an input value between
        `bin_values[j*n_bin + i]` and `bin_values[j*n_bin+i+1]`.
feature_offsets:
a 1D Tensor specifying the starting location of bins for a given feature id.
For example, tf.Tensor(np.arange(0, bin_values.size, n_bin, dtype='int64')).
"""
super(MDL, self).__init__(**kwargs)
tf.logging.warning("MDL will be deprecated. Please use PercentileDiscretizer instead")
max_mdl_feature = n_feature * (n_bin + 1)
self._n_feature = n_feature
self._n_bin = n_bin
self._hash_keys_initializer = tf.constant_initializer(
hash_keys if hash_keys is not None
else np.empty(n_feature, dtype=np.int64),
dtype=np.int64
)
self._hash_values_initializer = tf.constant_initializer(
hash_values if hash_values is not None
else np.empty(n_feature, dtype=np.int64),
dtype=np.int64
)
self._bin_ids_initializer = tf.constant_initializer(
bin_ids if bin_ids is not None
else np.empty(max_mdl_feature, dtype=np.int64),
dtype=np.int64
)
self._bin_values_initializer = tf.constant_initializer(
bin_values if bin_values is not None
else np.empty(max_mdl_feature, dtype=np.float32),
dtype=np.float32
)
self._feature_offsets_initializer = tf.constant_initializer(
feature_offsets if feature_offsets is not None
else np.empty(n_feature, dtype=np.int64),
dtype=np.int64
)
# note that calling build here is an exception as typically __call__ would call build().
# We call it here because we need to initialize hash_map.
# Also note that the variable_scope is set by add_variable in build()
if not self.built:
self.build(input_shape=None)
self.output_size = tf.convert_to_tensor(1 << out_bits, tf.int64)
def build(self, input_shape): # pylint: disable=unused-argument
"""
Creates the variables of the layer:
hash_keys, hash_values, bin_ids, bin_values, feature_offsets and self.output_size.
"""
# build layers
self.partition = Partition()
self.stitch = Stitch()
# build variables
hash_keys = self.add_variable(
'hash_keys',
initializer=self._hash_keys_initializer,
shape=[self._n_feature],
dtype=tf.int64,
trainable=False)
hash_values = self.add_variable(
'hash_values',
initializer=self._hash_values_initializer,
shape=[self._n_feature],
dtype=tf.int64,
trainable=False)
# hashmap converts known features into range [0, n_feature)
initializer = tf.lookup.KeyValueTensorInitializer(hash_keys, hash_values)
self.hash_map = tf.lookup.StaticHashTable(initializer, -1)
self.bin_ids = self.add_variable(
'bin_ids',
initializer=self._bin_ids_initializer,
shape=[self._n_feature * (self._n_bin + 1)],
dtype=tf.int64,
trainable=False)
self.bin_values = self.add_variable(
'bin_values',
initializer=self._bin_values_initializer,
shape=[self._n_feature * (self._n_bin + 1)],
dtype=tf.float32,
trainable=False)
self.feature_offsets = self.add_variable(
'feature_offsets',
initializer=self._feature_offsets_initializer,
shape=[self._n_feature],
dtype=tf.int64,
trainable=False)
# make sure this is last
self.built = True
def call(self, inputs, **kwargs):
"""Looks up `keys` in a table, outputs the corresponding values.
Implements MDL inference where inputs are intersected with a hash_map.
Part of the inputs are discretized using twml.mdl to produce a mdl_output SparseTensor.
This SparseTensor is then joined with the original inputs SparseTensor,
but only for the inputs keys that did not get discretized.
Args:
inputs: A 2D SparseTensor that is input to MDL for discretization.
It has a dense_shape of [batch_size, input_size]
name: A name for the operation (optional).
Returns:
A `SparseTensor` of the same type as `inputs`.
Its dense_shape is [shape_input.dense_shape[0], 1 << output_bits].
"""
if isinstance(inputs, tf.SparseTensor):
inputs = twml.SparseTensor.from_tf(inputs)
assert(isinstance(inputs, twml.SparseTensor))
# sparse column indices
ids = inputs.ids
# sparse row indices
keys = inputs.indices
# sparse values
vals = inputs.values
# get intersect(keys, hash_map)
hashed_keys = self.hash_map.lookup(keys)
found = tf.not_equal(hashed_keys, tf.constant(-1, tf.int64))
partition_ids = tf.cast(found, tf.int32)
vals, key, indices = self.partition(partition_ids, vals, tf.where(found, hashed_keys, keys))
non_mdl_keys, mdl_in_keys = key
non_mdl_vals, mdl_in_vals = vals
self.non_mdl_keys = non_mdl_keys
# run MDL on the keys/values it knows about
mdl_keys, mdl_vals = libtwml.ops.mdl(mdl_in_keys, mdl_in_vals, self.bin_ids, self.bin_values,
self.feature_offsets)
# handle output ID conflicts
mdl_size = tf.size(self.bin_ids, out_type=tf.int64)
non_mdl_size = tf.subtract(self.output_size, mdl_size)
non_mdl_keys = tf.add(tf.floormod(non_mdl_keys, non_mdl_size), mdl_size)
# Stitch the keys and values from mdl and non mdl indices back, with help
# of the Stitch Layer
# out for inference checking
self.mdl_out_keys = mdl_keys
concat_data = self.stitch([non_mdl_vals, mdl_vals],
[non_mdl_keys, mdl_keys],
indices)
concat_vals, concat_keys = concat_data
# Generate output shape using _compute_output_shape
batch_size = tf.to_int64(inputs.dense_shape[0])
output_shape = [batch_size, self.output_size]
return twml.SparseTensor(ids, concat_keys, concat_vals, output_shape).to_tf()
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
Raises NotImplementedError.
"""
raise NotImplementedError

twml/twml/layers/partition.py

@@ -0,0 +1,74 @@
"""
Implementing partition Layer
"""
from .layer import Layer
import tensorflow.compat.v1 as tf
class Partition(Layer):
"""
This layer implements:
.. code-block:: python
tf.dynamic_partition(input_vals, partition_ids, self.partitions)
Input:
    partitions:
      the number of partitions into which we will divide the hashmap keys/values
Output:
A layer that performs partitioning
"""
def __init__(self, partitions=2, **kwargs):
self.partitions = partitions
super(Partition, self).__init__(**kwargs)
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
Raises NotImplementedError.
"""
raise NotImplementedError
def call(self, partition_ids, input_vals, input_keys, **kwargs):
"""This layer is responsible for partitioning the values/keys of a hashmap
Arguments:
      partition_ids:
        Tensor that is equivalent to boolean (int32).
      input_vals:
        Tensor that represents the values of the hashmap (float).
      input_keys:
        Tensor that represents the keys of the hashmap (float).
Returns:
The output of the partition layer, which is a list of lists which looks
something like:
.. code-block:: python
[[vals_0, vals_1], [keys_0, keys_1], [indices_0, indices_1]]
where:
vals_x:
values of the hashmap for partition x
keys_x:
keys of the hashmap for partition x
indices_x:
indices of the hashmap for partition x
"""
partioned_val = tf.dynamic_partition(input_vals, partition_ids, self.partitions)
partioned_keys = tf.dynamic_partition(input_keys, partition_ids, self.partitions)
partioned_indices = tf.dynamic_partition(tf.range(tf.shape(partition_ids)[0]),
tf.cast(partition_ids, tf.int32), self.partitions)
return [partioned_val, partioned_keys, partioned_indices]
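
An illustrative sketch of the partitioning performed here, mirroring how the MDL layer above splits hashmap entries into "unknown" (0) and "known" (1) partitions:

import tensorflow.compat.v1 as tf
from twml.layers import Partition

tf.disable_eager_execution()
partition_ids = tf.constant([1, 0, 1, 0], dtype=tf.int32)  # 1 = key found in the hashmap
input_vals = tf.constant([0.5, 1.5, 2.5, 3.5])
input_keys = tf.constant([7, 8, 9, 10], dtype=tf.int64)
vals, keys, indices = Partition()(partition_ids, input_vals, input_keys)
# vals[1] and keys[1] hold the entries whose partition_id was 1;
# indices[*] record the original positions, which Stitch uses to reassemble them.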

twml/twml/layers/percentile_discretizer.py

@@ -0,0 +1,209 @@
# pylint: disable=no-member, attribute-defined-outside-init, too-many-instance-attributes
"""
Implementing PercentileDiscretizer Layer
"""
import libtwml
import numpy as np
import tensorflow.compat.v1 as tf
import twml
from twml.layers import Layer
class PercentileDiscretizer(Layer):
"""
PercentileDiscretizer layer is constructed by PercentileDiscretizerCalibrator after
accumulating data and performing percentile bucket calibration.
  PercentileDiscretizer takes sparse continuous features and converts them to sparse
  binary features. Each binary output feature is associated with a PercentileDiscretizer bin.
  Each PercentileDiscretizer input feature is converted to n_bin bins.
  Each PercentileDiscretizer calibration tries to find bin delimiters such
  that the number of feature values per bin is roughly equal (for
  each given PercentileDiscretizer feature). In other words, bins are calibrated to be approx.
  equiprobable, according to the given calibration data.
  Note that if an input feature is rarely used, its associated output bins/features will be rarely used as well.
"""
def __init__(
self,
n_feature, n_bin, out_bits,
bin_values=None, hash_keys=None, hash_values=None,
bin_ids=None, feature_offsets=None, num_parts=1, cost_per_unit=100, **kwargs):
"""
Creates a non-initialized `PercentileDiscretizer` object.
Before using the table you will have to initialize it. After initialization
the table will be immutable.
    If there are no calibrated features, then the discretizer will only apply
    twml.util.limit_bits to the feature keys (aka "feature_ids"). Essentially,
    the discretizer will be a "no-operation", other than obeying `out_bits`.
Parent class args:
see [tf.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/layers/Layer)
for documentation of parent class arguments.
Required args:
n_feature:
number of unique features accumulated during PercentileDiscretizer calibration.
This is the number of features in the hash map.
Used to initialize bin_values, hash_keys, hash_values,
bin_ids, bin_values and feature_offsets.
n_bin:
number of PercentileDiscretizer bins used for PercentileDiscretizer calibration.
Used to initialize bin_values, hash_keys, hash_values,
bin_ids, bin_values and feature_offsets.
out_bits:
Determines the maximum value for output feature IDs.
The dense_shape of the SparseTensor returned by lookup(x)
will be [x.shape[0], 1 << output_bits].
Optional args:
hash_keys:
contains the features ID that PercentileDiscretizer discretizes and knows about.
The hash map (hash_keys->hash_values) is used for two reasons:
1. divide inputs into two feature spaces:
PercentileDiscretizer vs non-PercentileDiscretizer
        2. translate the PercentileDiscretizer features into a hash_feature ID that
PercentileDiscretizer understands.
The hash_map is expected to contain n_feature items.
hash_values:
translates the feature IDs into hash_feature IDs for PercentileDiscretizer.
bin_ids:
a 1D Tensor of size n_feature * n_bin + 1 which contains
unique IDs to which the PercentileDiscretizer features will be translated to.
For example, tf.Tensor(np.arange(n_feature * n_bin)) would produce
the most efficient output space.
bin_values:
        a 1D Tensor aligned with bin_ids.
        For a given hash_feature ID j, its value bins are indexed between
        `j*n_bin` and `j*n_bin + n_bin-1`.
        As such, bin_ids[j*n_bin+i] is translated from a hash_feature ID of j
        and an input value between
        `bin_values[j*n_bin + i]` and `bin_values[j*n_bin+i+1]`.
feature_offsets:
a 1D Tensor specifying the starting location of bins for a given feature id.
For example, tf.Tensor(np.arange(0, bin_values.size, n_bin, dtype='int64')).
"""
super(PercentileDiscretizer, self).__init__(**kwargs)
if not self.built:
self.build(input_shape=None)
max_discretizer_feature = n_feature * (n_bin + 1)
self._n_feature = n_feature
self._n_bin = n_bin
# build variables
self._out_bits = out_bits
self._output_size = tf.convert_to_tensor(1 << out_bits, tf.int64)
self._hash_keys = (hash_keys if hash_keys is not None else
np.empty(n_feature, dtype=np.int64))
self._hash_values = (hash_values if hash_values is not None else
np.empty(n_feature, dtype=np.int64))
self._bin_ids = (bin_ids if bin_ids is not None else
np.empty(max_discretizer_feature, dtype=np.int64))
self._bin_values = (bin_values if bin_values is not None else
np.empty(max_discretizer_feature, dtype=np.float32))
self._feature_offsets = (feature_offsets if feature_offsets is not None else
np.empty(n_feature, dtype=np.int64))
self.num_parts = num_parts
self.cost_per_unit = cost_per_unit
def build(self, input_shape): # pylint: disable=unused-argument
"""
Creates the variables of the layer
"""
self.built = True
def call(self, inputs, keep_inputs=False, **kwargs):
"""Looks up `keys` in a table, outputs the corresponding values.
Implements PercentileDiscretizer inference where inputs are intersected with a hash_map.
    Input features that were not calibrated have their feature IDs truncated, so as
    to be less than 1<<output_bits, but their values remain untouched (not discretized).
    If there are no calibrated features, then the discretizer will only apply
    twml.util.limit_bits to the feature keys (aka "feature_ids"). Essentially,
    the discretizer will be a "no-operation", other than obeying `out_bits`.
Args:
inputs: A 2D SparseTensor that is input to PercentileDiscretizer for discretization.
It has a dense_shape of [batch_size, input_size]
keep_inputs:
Include the original inputs in the output.
Note - if True, undiscretized features will be passed through, but will have
their values doubled (unless there are no calibrated features to discretize).
name: A name for the operation (optional).
Returns:
A `SparseTensor` of the same type as `inputs`.
Its dense_shape is [shape_input.dense_shape[0], 1 << output_bits].
"""
if isinstance(inputs, tf.SparseTensor):
inputs = twml.SparseTensor.from_tf(inputs)
assert(isinstance(inputs, twml.SparseTensor))
# sparse column indices
ids = inputs.ids
# sparse row indices
keys = inputs.indices
# sparse values
vals = inputs.values
if self._n_feature > 0:
discretizer_keys, discretizer_vals = libtwml.ops.percentile_discretizer_v2(
input_ids=keys, # inc key assigned to feature_id, or -1
input_vals=vals, # the observed feature values
bin_ids=self._bin_ids, # n_feat X (n_bin+1) 2D arange
bin_vals=self._bin_values, # bin boundaries
feature_offsets=self._feature_offsets, # 0 : nbin_1 : max_feat
output_bits=self._out_bits,
feature_ids=tf.make_tensor_proto(self._hash_keys), # feature ids to build internal hash map
feature_indices=tf.make_tensor_proto(self._hash_values), # keys associated w/ feat. indices
start_compute=tf.constant(0, shape=[], dtype=tf.int64),
end_compute=tf.constant(-1, shape=[], dtype=tf.int64),
cost_per_unit=self.cost_per_unit
)
else:
discretizer_keys = twml.util.limit_bits(keys, self._out_bits)
discretizer_vals = vals
# don't 2x the input.
keep_inputs = False
batch_size = tf.to_int64(inputs.dense_shape[0])
output_shape = [batch_size, self._output_size]
output = twml.SparseTensor(ids, discretizer_keys, discretizer_vals, output_shape).to_tf()
if keep_inputs:
# Note the non-discretized features will end up doubled,
# since these are already in `output`
# handle output ID conflicts
mdl_size = self._n_feature * (self._n_bin + 1)
non_mdl_size = tf.subtract(self._output_size, mdl_size)
input_keys = tf.add(tf.floormod(keys, non_mdl_size), mdl_size)
new_input = twml.SparseTensor(
ids=ids, indices=input_keys, values=vals, dense_shape=output_shape).to_tf()
# concatenate discretizer output with original input
sparse_add = tf.sparse_add(new_input, output)
output = tf.SparseTensor(sparse_add.indices, sparse_add.values, output_shape)
return output
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
Raises NotImplementedError.
"""
raise NotImplementedError
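
A hedged construction sketch (the same pattern applies to the deprecated MDL layer above). The calibration arrays below are made-up placeholders with the documented shapes; in practice they come from a PercentileDiscretizerCalibrator, and calling the layer requires the libtwml discretizer op:

import numpy as np
from twml.layers import PercentileDiscretizer

n_feature, n_bin, out_bits = 2, 4, 22
discretizer = PercentileDiscretizer(
  n_feature=n_feature, n_bin=n_bin, out_bits=out_bits,
  hash_keys=np.array([101, 202], dtype=np.int64),              # raw feature IDs (illustrative)
  hash_values=np.arange(n_feature, dtype=np.int64),            # hashed feature IDs
  bin_ids=np.arange(n_feature * (n_bin + 1), dtype=np.int64),  # output feature IDs per bin
  bin_values=np.tile(np.linspace(0.0, 1.0, n_bin + 1), n_feature).astype(np.float32),  # bin boundaries
  feature_offsets=np.arange(0, n_feature * (n_bin + 1), n_bin + 1, dtype=np.int64),
)
# discretizer(sparse_inputs) returns a tf.SparseTensor of binarized features
# with dense_shape [batch_size, 1 << out_bits].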

twml/twml/layers/sequential.py

@@ -0,0 +1,160 @@
"""
Implementing Sequential Layer container
"""
from .layer import Layer
from tensorflow import keras
from tensorflow.python.layers import base
class Sequential(Layer):
"""
A sequential stack of layers.
Arguments:
layers: list of layers to add to the model.
Output:
the output of the sequential layers
"""
def __init__(self, layers=None, **kwargs):
self._layers = [] # Stack of layers.
self._layer_names = [] # Stack of layers names
self._layer_outputs = []
# Add to the model any layers passed to the constructor.
if layers:
for layer in layers:
self.add(layer)
super(Sequential, self).__init__(**kwargs)
def add(self, layer):
"""Adds a layer instance on top of the layer stack.
Arguments:
layer:
layer instance.
Raises:
TypeError:
if the layer argument is not instance of base.Layer
"""
if not isinstance(layer, base.Layer) and not isinstance(layer, keras.layers.Layer):
raise TypeError('The added layer must be an instance of class Layer')
if layer.name in self._layer_names:
raise ValueError('Layer with name %s already exists in sequential layer' % layer.name)
self._layers.append(layer)
self._layer_names.append(layer.name)
def pop(self):
"""Removes the last layer in the model.
Raises:
TypeError:
if there are no layers in the model.
"""
if not self._layers or not self._layer_names:
raise TypeError('There are no layers in the model.')
self._layers.pop()
self._layer_names.pop()
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""The logic of the layer lives here.
Arguments:
inputs:
input tensor(s).
Returns:
The output of the sequential layers
"""
self._layer_outputs = []
for layer in self._layers:
# don't use layer.call because you want to build individual layers
inputs = layer(inputs) # overwrites the current input after it has been processed
self._layer_outputs.append(inputs)
return inputs
@property
def layers(self):
""" Return the layers in the sequential layer """
return self._layers
@property
def layer_names(self):
""" Return the layer names in the sequential layer """
return self._layer_names
@property
def layer_outputs(self):
""" Return the layer outputs in the sequential layer """
return self._layer_outputs
def get(self, key):
"""Retrieves the n-th layer.
Arguments:
key:
index of the layer
Output:
The n-th layer where n is equal to the key.
"""
return self._layers[key]
def get_output(self, key):
"""Retrieves the n-th layer output.
Arguments:
key:
index of the layer
Output:
The intermediary output equivalent to the nth layer, where n is equal to the key.
"""
return self._layer_outputs[key]
def get_layer_by_name(self, name):
"""Retrieves the layer corresponding to the name.
Arguments:
name:
name of the layer
Output:
list of layers that have the name desired
"""
return self._layers[self._layer_names.index(name)]
def get_layer_output_by_name(self, name):
"""Retrieves the layer output corresponding to the name.
Arguments:
name:
name of the layer
Output:
list of the output of the layers that have the desired name
"""
return self._layer_outputs[self._layer_names.index(name)]
@property
def init(self):
""" returns a list of initialization ops (one per layer) """
return [layer.init for layer in self._layers]
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
    Raises NotImplementedError.
"""
raise NotImplementedError
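
An illustrative sketch of stacking layers with the Sequential container (sizes are arbitrary); intermediate outputs can be retrieved afterwards via get_output:

import tensorflow.compat.v1 as tf
from twml.layers import FullDense, Sequential

tf.disable_eager_execution()
tower = Sequential([
  FullDense(64, activation=tf.nn.relu, name="hidden"),
  FullDense(1, name="logits"),
])
inputs = tf.placeholder(tf.float32, shape=[None, 32])
logits = tower(inputs)
hidden_out = tower.get_output(0)  # output of the first layer in the stack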

twml/twml/layers/sparse_max_norm.py

@@ -0,0 +1,221 @@
# pylint: disable=no-member, attribute-defined-outside-init, duplicate-code
"""
Contains the twml.layers.SparseMaxNorm layer.
"""
from .layer import Layer
from libtwml import OPLIB
import tensorflow.compat.v1 as tf
import twml
class SparseMaxNorm(Layer):
"""
  Computes a max-normalization and adds bias to the sparse_input; the normalized
  output is typically forwarded through a sparse affine transform followed
  by a non-linear activation on the resulting dense representation.
This layer has two parameters, one of which learns through gradient descent:
bias_x (optional):
vector of shape [input_size]. Learned through gradient descent.
max_x:
vector of shape [input_size]. Holds the maximas of input ``x`` for normalization.
Either calibrated through SparseMaxNorm calibrator, or calibrated online, or both.
The pseudo-code for this layer looks like:
.. code-block:: python
abs_x = abs(x)
normed_x = clip_by_value(x / max_x, -1, 1)
biased_x = normed_x + bias_x
    return biased_x
Args:
max_x_initializer:
initializer vector of shape [input_size] used by variable `max_x`
bias_x_initializer:
initializer vector of shape [input_size] used by parameter `bias_x`
    is_training:
      Are we training the layer to learn the normalization maxima?
      If set to True, max_x will be able to learn. This is independent of bias_x.
epsilon:
The minimum value used for max_x. Defaults to 1E-5.
use_bias:
Default True. Set to False to not use a bias term.
Returns:
A layer representing the output of the sparse_max_norm transformation.
"""
def __init__(
self,
input_size=None,
max_x_initializer=None,
bias_x_initializer=None,
is_training=True,
epsilon=1E-5,
use_bias=True,
**kwargs):
super(SparseMaxNorm, self).__init__(**kwargs)
if input_size:
raise ValueError('input_size is deprecated - it is now automatically \
inferred from your input.')
if max_x_initializer is None:
max_x_initializer = tf.zeros_initializer()
self.max_x_initializer = max_x_initializer
self._use_bias = use_bias
if use_bias:
if bias_x_initializer is None:
bias_x_initializer = tf.zeros_initializer()
self.bias_x_initializer = bias_x_initializer
self.epsilon = epsilon
self.is_training = is_training
def build(self, input_shape): # pylint: disable=unused-argument
"""Creates the max_x and bias_x tf.Variables of the layer."""
self.max_x = self.add_variable(
'max_x',
initializer=self.max_x_initializer,
shape=[input_shape[1]],
dtype=tf.float32,
trainable=False)
if self._use_bias:
self.bias_x = self.add_variable(
'bias_x',
initializer=self.bias_x_initializer,
shape=[input_shape[1]],
dtype=tf.float32,
trainable=True)
self.built = True
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
Raises NotImplementedError.
"""
raise NotImplementedError
def _call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""
The forward propagation logic of the layer lives here.
Arguments:
sparse_input:
A 2D ``tf.SparseTensor`` of dense_shape ``[batch_size, input_size]``
Returns:
A ``tf.SparseTensor`` representing the output of the max_norm transformation, this can
be fed into twml.layers.FullSparse in order to be transformed into a ``tf.Tensor``.
"""
if isinstance(inputs, twml.SparseTensor):
inputs = inputs.to_tf()
elif not isinstance(inputs, tf.SparseTensor):
raise TypeError("The inputs must be of type tf.SparseTensor or twml.SparseTensor")
indices_x = inputs.indices[:, 1]
values_x = inputs.values
if self.is_training is False:
normalized_x = OPLIB.sparse_max_norm_inference(self.max_x,
indices_x,
values_x,
self.epsilon)
update_op = tf.no_op()
else:
max_x, normalized_x = OPLIB.sparse_max_norm_training(self.max_x,
indices_x,
values_x,
self.epsilon)
update_op = tf.assign(self.max_x, max_x)
with tf.control_dependencies([update_op]):
normalized_x = tf.stop_gradient(normalized_x)
# add input bias
if self._use_bias:
normalized_x = normalized_x + tf.gather(self.bias_x, indices_x)
# convert back to sparse tensor
return tf.SparseTensor(inputs.indices, normalized_x, inputs.dense_shape)
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""
The forward propagation logic of the layer lives here.
Arguments:
sparse_input:
A 2D ``tf.SparseTensor`` of dense_shape ``[batch_size, input_size]``
Returns:
A ``tf.SparseTensor`` representing the output of the max_norm transformation, this can
be fed into twml.layers.FullSparse in order to be transformed into a ``tf.Tensor``.
"""
with tf.device(self.max_x.device):
return self._call(inputs, **kwargs)
# For backwards compatibility and also because I don't want to change all the tests.
MaxNorm = SparseMaxNorm
def sparse_max_norm(inputs,
input_size=None,
max_x_initializer=None,
bias_x_initializer=None,
is_training=True,
epsilon=1E-5,
use_bias=True,
name=None,
reuse=None):
"""
  Functional interface to SparseMaxNorm.
Args:
inputs:
A sparse tensor (can be twml.SparseTensor or tf.SparseTensor)
input_size:
number of input units
max_x_initializer:
initializer vector of shape [input_size] used by variable `max_x`
bias_x_initializer:
initializer vector of shape [input_size] used by parameter `bias_x`
    is_training:
      Are we training the layer to learn the normalization maxima?
      If set to True, max_x will be able to learn. This is independent of bias_x.
epsilon:
The minimum value used for max_x. Defaults to 1E-5.
use_bias:
Default True. Set to False to not use a bias term.
Returns:
Output after normalizing with the max value.
"""
if input_size:
raise ValueError('input_size is deprecated - it is now automatically \
inferred from your input.')
if isinstance(inputs, twml.SparseTensor):
inputs = inputs.to_tf()
layer = SparseMaxNorm(max_x_initializer=max_x_initializer,
bias_x_initializer=bias_x_initializer,
is_training=is_training,
epsilon=epsilon,
use_bias=use_bias,
name=name,
_scope=name,
_reuse=reuse)
return layer(inputs)
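
A hedged sketch chaining sparse_max_norm into full_sparse, as the docstrings above suggest; the feature-space size is illustrative and the libtwml max-norm ops must be available:

import tensorflow.compat.v1 as tf
from twml.layers import full_sparse, sparse_max_norm

tf.disable_eager_execution()
sparse_input = tf.sparse_placeholder(tf.float32, shape=[None, 1 << 20])
normalized = sparse_max_norm(sparse_input, is_training=True)  # returns a tf.SparseTensor
logits = full_sparse(normalized, output_size=1)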

twml/twml/layers/stitch.py

@@ -0,0 +1,54 @@
# pylint: disable=useless-super-delegation
"""
Implementing Stitch Layer
"""
from .layer import Layer
import tensorflow.compat.v1 as tf
class Stitch(Layer):
"""
  This layer is responsible for stitching a partitioned layer back together.
Output:
A layer that performs stitching
"""
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer given the input shape.
Args:
input_shape: A (possibly nested tuple of) `TensorShape`. It need not
be fully defined (e.g. the batch size may be unknown).
Raises NotImplementedError.
"""
raise NotImplementedError
def call(self, partioned_val, partioned_keys,
partioned_indices, **kwargs): # pylint: disable=unused-argument, arguments-differ
"""
    This layer is responsible for stitching a partitioned layer back together.
    Input:
      partioned_val:
        a list of partitioned Tensors which represent the vals of the hashmap
      partioned_keys:
        a list of partitioned Tensors which represent the keys of the hashmap
      partioned_indices:
        a list of partitioned Tensors which represent the indices of the hashmap
Output:
List which contains: [output_vals, output_keys]
output_vals:
Values of the HashMap (float)
output_keys:
Keys of HashMap (float)
"""
indices = [tf.to_int32(index) for index in partioned_indices]
concat_keys = tf.dynamic_stitch(indices, partioned_keys)
concat_vals = tf.dynamic_stitch(indices, partioned_val)
return [concat_vals, concat_keys]
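
An illustrative round trip with the Partition layer defined earlier: dynamic_stitch uses the partitioned indices to restore the original ordering of the keys and values.

import tensorflow.compat.v1 as tf
from twml.layers import Partition, Stitch

tf.disable_eager_execution()
partition_ids = tf.constant([1, 0, 1, 0], dtype=tf.int32)
input_vals = tf.constant([0.5, 1.5, 2.5, 3.5])
input_keys = tf.constant([7, 8, 9, 10], dtype=tf.int64)
vals, keys, indices = Partition()(partition_ids, input_vals, input_keys)
stitched_vals, stitched_keys = Stitch()(vals, keys, indices)
# stitched_vals/stitched_keys match the original input_vals/input_keys order.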