Twitter Recommendation Algorithm
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required again in the future.
twml/twml/layers/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
# pylint: disable=wildcard-import
"""
This module contains the ``tf.layers.Layer`` subclasses implemented in twml.
Layers are used to instantiate common subgraphs.
Typically, these layers are used when defining a ``build_graph_fn``
for the ``twml.trainers.Trainer``.
"""

from .batch_prediction_tensor_writer import BatchPredictionTensorWriter  # noqa: F401
from .batch_prediction_writer import BatchPredictionWriter  # noqa: F401
from .data_record_tensor_writer import DataRecordTensorWriter  # noqa: F401
from .full_dense import full_dense, FullDense  # noqa: F401
from .full_sparse import full_sparse, FullSparse  # noqa: F401
from .isotonic import Isotonic  # noqa: F401
from .layer import Layer  # noqa: F401
from .mdl import MDL  # noqa: F401
from .partition import Partition  # noqa: F401
from .percentile_discretizer import PercentileDiscretizer  # noqa: F401
from .sequential import Sequential  # noqa: F401
from .sparse_max_norm import MaxNorm, sparse_max_norm, SparseMaxNorm  # noqa: F401
from .stitch import Stitch  # noqa: F401
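Usage sketch (illustrative, not part of the repository files): these layers are normally composed inside a build_graph_fn handed to twml.trainers.Trainer, as the module docstring says. The exact Trainer contract, the feature key name, and the layer sizes below are assumptions.

# Hypothetical sketch of a build_graph_fn that stacks twml layers.
# Assumes twml/libtwml and TF 1.x compat are importable; names are made up.
import tensorflow.compat.v1 as tf
import twml

def build_graph_fn(features, label, mode, params, config=None):
  # wide sparse input projected to 50 units, then a dense output layer
  sparse_out = twml.layers.full_sparse(features['sparse_input'], output_size=50)
  logits = twml.layers.full_dense(tf.nn.relu(sparse_out), output_size=1)
  graph = {'output': tf.nn.sigmoid(logits)}
  if mode != 'infer':
    graph['loss'] = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits))
  return graph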
twml/twml/layers/batch_prediction_tensor_writer.py (new file, 51 lines)
@@ -0,0 +1,51 @@
# pylint: disable=no-member, invalid-name
"""
Implementing Writer Layer
"""
from .layer import Layer

import libtwml


class BatchPredictionTensorWriter(Layer):
  """
  A layer that packages keys and dense tensors into a BatchPredictionResponse.
  Typically used at the output of an exported model for use in the PredictionEngine
  (that is, in production) when model predictions are dense tensors.

  Arguments:
    keys:
      keys to hashmap
  Output:
    output:
      a BatchPredictionResponse serialized using Thrift into a uint8 tensor.
  """

  def __init__(self, keys, **kwargs):  # pylint: disable=useless-super-delegation
    super(BatchPredictionTensorWriter, self).__init__(**kwargs)
    self.keys = keys

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError

  def call(self, values, **kwargs):  # pylint: disable=unused-argument, arguments-differ
    """The logic of the layer lives here.

    Arguments:
      values:
        dense tensors corresponding to keys in hashmap

    Returns:
      The output from the layer
    """
    write_op = libtwml.ops.batch_prediction_tensor_response_writer(self.keys, values)
    return write_op
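Usage sketch (illustrative, not part of the repository files): wiring the writer at export time. The key dtype and tensor shapes below are assumptions, and a working libtwml build is required.

# Hypothetical sketch: serializing dense predictions into a BatchPredictionResponse.
import tensorflow.compat.v1 as tf
from twml.layers import BatchPredictionTensorWriter

keys = tf.constant([1, 2], dtype=tf.int64)       # response keys (dtype is an assumption)
values = [tf.zeros([4, 1]), tf.zeros([4, 32])]   # dense prediction tensors per key
writer = BatchPredictionTensorWriter(keys)
serialized = writer(values)  # uint8 tensor holding the Thrift-serialized response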
twml/twml/layers/batch_prediction_writer.py (new file, 51 lines)
@@ -0,0 +1,51 @@
# pylint: disable=no-member, invalid-name
"""
Implementing Writer Layer
"""
from .layer import Layer

import libtwml


class BatchPredictionWriter(Layer):
  """
  A layer that packages keys and values into a BatchPredictionResponse.
  Typically used at the output of an exported model for use in the PredictionEngine
  (that is, in production).

  Arguments:
    keys:
      keys to hashmap
  Output:
    output:
      a BatchPredictionResponse serialized using Thrift into a uint8 tensor.
  """

  def __init__(self, keys, **kwargs):  # pylint: disable=useless-super-delegation
    super(BatchPredictionWriter, self).__init__(**kwargs)
    self.keys = keys

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError

  def call(self, values, **kwargs):  # pylint: disable=unused-argument, arguments-differ
    """The logic of the layer lives here.

    Arguments:
      values:
        values corresponding to keys in hashmap

    Returns:
      The output from the layer
    """
    write_op = libtwml.ops.batch_prediction_response_writer(self.keys, values)
    return write_op
twml/twml/layers/data_record_tensor_writer.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# pylint: disable=no-member, invalid-name
"""
Implementing Writer Layer
"""
from .layer import Layer

import libtwml


class DataRecordTensorWriter(Layer):
  """
  A layer that packages keys and dense tensors into a DataRecord.
  This layer was initially added to support exporting user embeddings as tensors.

  Arguments:
    keys:
      keys to hashmap
  Output:
    output:
      a DataRecord serialized using Thrift into a uint8 tensor
  """

  def __init__(self, keys, **kwargs):  # pylint: disable=useless-super-delegation
    super(DataRecordTensorWriter, self).__init__(**kwargs)
    self.keys = keys

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError

  def call(self, values, **kwargs):  # pylint: disable=unused-argument, arguments-differ
    """The logic of the layer lives here.

    Arguments:
      values:
        dense tensors corresponding to keys in hashmap

    Returns:
      The output from the layer
    """
    write_op = libtwml.ops.data_record_tensor_writer(self.keys, values)
    return write_op
twml/twml/layers/full_dense.py (new file, 259 lines)
@@ -0,0 +1,259 @@
# pylint: disable=no-member,arguments-differ, attribute-defined-outside-init
"""
Implementing Full Dense Layer
"""
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import init_ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.base_layer import InputSpec
import tensorflow.compat.v1 as tf


class FullDense(core_layers.Dense):
  """
  Densely-connected layer class.
  This is wrapping tensorflow.python.layers.core.Dense
  This layer implements the operation:

  .. code-block:: python

    outputs = activation(inputs.weight + bias)

  Where ``activation`` is the activation function passed as the ``activation``
  argument (if not ``None``), ``weight`` is a weights matrix created by the layer,
  and ``bias`` is a bias vector created by the layer.

  Arguments:
    output_size:
      Integer or Long, dimensionality of the output space.
    activation:
      Activation function (callable). Set it to None to maintain a linear activation.
    weight_initializer:
      Initializer function for the weight matrix.
    bias_initializer:
      Initializer function for the bias.
    weight_regularizer:
      Regularizer function for the weight matrix.
      Make sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    bias_regularizer:
      Regularizer function for the bias.
      Make sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    activity_regularizer:
      Regularizer function for the output.
    weight_constraint:
      An optional projection function to be applied to the
      weight after being updated by an `Optimizer` (e.g. used to implement
      norm constraints or value constraints for layer weights). The function
      must take as input the unprojected variable and must return the
      projected variable (which must have the same shape). Constraints are
      not safe to use when doing asynchronous distributed training.
    bias_constraint:
      An optional projection function to be applied to the
      bias after being updated by an `Optimizer`.
    trainable:
      Boolean, if `True` also add variables to the graph collection
      ``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable
      <https://www.tensorflow.org/versions/master/api_docs/python/tf/Variable>`_).
    name:
      String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require ``reuse=True`` in such cases.

  Properties:
    output_size:
      Python integer, dimensionality of the output space.
    activation:
      Activation function (callable).
    weight_initializer:
      Initializer instance (or name) for the weight matrix.
    bias_initializer:
      Initializer instance (or name) for the bias.
    weight:
      Weight matrix (TensorFlow variable or tensor).
    bias:
      Bias vector, if applicable (TensorFlow variable or tensor).
    weight_regularizer:
      Regularizer instance for the weight matrix (callable).
    bias_regularizer:
      Regularizer instance for the bias (callable).
    activity_regularizer:
      Regularizer instance for the output (callable).
    weight_constraint:
      Constraint function for the weight matrix.
    bias_constraint:
      Constraint function for the bias.

  """

  def __init__(self, output_size,
               weight_initializer=None,
               weight_regularizer=None,
               weight_constraint=None,
               bias_constraint=None,
               num_partitions=None,
               **kwargs):
    super(FullDense, self).__init__(units=output_size,
                                    kernel_initializer=weight_initializer,
                                    kernel_regularizer=weight_regularizer,
                                    kernel_constraint=weight_constraint,
                                    **kwargs)
    self._num_partitions = num_partitions

  def build(self, input_shape):
    '''
    code adapted from TF 1.12 Keras Dense layer:
    https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/layers/core.py#L930-L956
    '''
    input_shape = tensor_shape.TensorShape(input_shape)
    if input_shape[-1] is None:
      raise ValueError('The last dimension of the inputs to `Dense` '
                       'should be defined. Found `None`.')
    self.input_spec = InputSpec(min_ndim=2,
                                axes={-1: input_shape[-1]})

    partitioner = None
    if self._num_partitions:
      partitioner = tf.fixed_size_partitioner(self._num_partitions)

    self.kernel = self.add_weight(
        'kernel',
        shape=[input_shape[-1], self.units],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        partitioner=partitioner,
        trainable=True)

    if self.use_bias:
      self.bias = self.add_weight(
          'bias',
          shape=[self.units, ],
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None
    self.built = True

  @property
  def output_size(self):
    """
    Returns output_size
    """
    return self.units

  @property
  def weight(self):
    """
    Returns weight
    """
    return self.kernel

  @property
  def weight_regularizer(self):
    """
    Returns weight_regularizer
    """
    return self.kernel_regularizer

  @property
  def weight_initializer(self):
    """
    Returns weight_initializer
    """
    return self.kernel_initializer

  @property
  def weight_constraint(self):
    """
    Returns weight_constraint
    """
    return self.kernel_constraint


def full_dense(inputs, output_size,
               activation=None,
               use_bias=True,
               weight_initializer=None,
               bias_initializer=init_ops.zeros_initializer(),
               weight_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               weight_constraint=None,
               bias_constraint=None,
               trainable=True,
               name=None,
               num_partitions=None,
               reuse=None):
  """Functional interface for the densely-connected layer.
  This layer implements the operation:
  `outputs = activation(inputs.weight + bias)`
  Where `activation` is the activation function passed as the `activation`
  argument (if not `None`), `weight` is a weights matrix created by the layer,
  and `bias` is a bias vector created by the layer
  (only if `use_bias` is `True`).

  Arguments:
    inputs: Tensor input.
    output_size: Integer or Long, dimensionality of the output space.
    activation: Activation function (callable). Set it to None to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    weight_initializer: Initializer function for the weight matrix.
      If `None` (default), weights are initialized using the default
      initializer used by `tf.get_variable`.
    bias_initializer:
      Initializer function for the bias.
    weight_regularizer:
      Regularizer function for the weight matrix.
      Make sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    bias_regularizer:
      Regularizer function for the bias.
      Make sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    activity_regularizer:
      Regularizer function for the output.
    weight_constraint:
      An optional projection function to be applied to the
      weight after being updated by an `Optimizer` (e.g. used to implement
      norm constraints or value constraints for layer weights). The function
      must take as input the unprojected variable and must return the
      projected variable (which must have the same shape). Constraints are
      not safe to use when doing asynchronous distributed training.
    bias_constraint:
      An optional projection function to be applied to the
      bias after being updated by an `Optimizer`.
    trainable:
      Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name:
      String, the name of the layer.
    reuse:
      Boolean, whether to reuse the weights of a previous layer
      by the same name.

  Returns:
    Output tensor the same shape as `inputs` except the last dimension is of
    size `output_size`.

  Raises:
    ValueError: if eager execution is enabled.
  """
  layer = FullDense(output_size,
                    activation=activation,
                    use_bias=use_bias,
                    weight_initializer=weight_initializer,
                    bias_initializer=bias_initializer,
                    weight_regularizer=weight_regularizer,
                    bias_regularizer=bias_regularizer,
                    activity_regularizer=activity_regularizer,
                    weight_constraint=weight_constraint,
                    bias_constraint=bias_constraint,
                    trainable=trainable,
                    name=name,
                    dtype=inputs.dtype.base_dtype,
                    num_partitions=num_partitions,
                    _scope=name,
                    _reuse=reuse)
  return layer.apply(inputs)
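Usage sketch (illustrative, not part of the repository files): the functional `full_dense` interface in TF 1.x graph mode. Layer sizes and placeholder shapes below are made up.

# Minimal sketch of full_dense, assuming TF 1.x compat and an importable twml.
import tensorflow.compat.v1 as tf
from twml.layers import full_dense

x = tf.placeholder(tf.float32, shape=[None, 128])
labels = tf.placeholder(tf.float32, shape=[None, 1])
hidden = full_dense(x, output_size=64, activation=tf.nn.relu, name="hidden")
logits = full_dense(hidden, output_size=1, name="logits")
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
# as the docstring notes, regularizer losses must be added to the loss explicitly
loss += tf.losses.get_regularization_loss()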
twml/twml/layers/full_sparse.py (new file, 370 lines)
@@ -0,0 +1,370 @@
# pylint: disable=no-member, arguments-differ, attribute-defined-outside-init, unused-argument
"""
Implementing Full Sparse Layer
"""

import math

from twitter.deepbird.sparse import sparse_dense_matmul

from .layer import Layer

import tensorflow.compat.v1 as tf
import twml


class FullSparse(Layer):
  """Fully-sparse layer class.
  This layer implements the operation:

  .. code-block:: python

    outputs = activation(inputs.weight + bias)

  Arguments:
    output_size:
      Long or Integer, dimensionality of the output space.
    input_size:
      The number of input units. (Deprecated)
    weight_initializer:
      Initializer function for the weight matrix.
      This argument defaults to zeros_initializer().
      This is valid when the FullSparse is the first layer of
      parameters but should be changed otherwise.
    weight_regularizer:
      Regularizer function for the weight matrix.
      Make sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    bias_regularizer:
      Regularizer function for the bias.
      Make sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    activation:
      Activation function (callable). Set it to None to maintain a linear activation.
    bias_initializer:
      Initializer function for the bias.
      This argument defaults to tf.constant_initializer(1/output_size)
    trainable:
      Boolean, if `True` also add variables to the graph collection
      ``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable
      <https://www.tensorflow.org/versions/master/api_docs/python/tf/Variable>`_).
    name:
      String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require ``reuse=True`` in such cases.
    use_sparse_grads:
      Boolean, if `True` do sparse mat mul with `embedding_lookup_sparse`, which will
      make gradients to the weight matrix also sparse in the backward pass. This can lead to
      a non-trivial speed up at training time when input_size is large and the optimizer
      handles sparse gradients correctly (e.g. with SGD or LazyAdamOptimizer). If the weight
      matrix is small, it is recommended to set this flag to `False`; for most use cases of
      FullSparse, however, the weight matrix will be large, so it is better to set it to `True`.
    num_partitions:
      Number of partitions to use for the weight variable. Defaults to 1.
    partition_axis:
      If num_partitions is specified, the partition axis for the weight variable.
      Defaults to 0 (partition by row).
      Must be 0 (row) or 1 (column).
    use_binary_values:
      Assume all non-zero values are 1. Defaults to False.
      This can improve training if used in conjunction with MDL.
      This parameter can also be a list of binary values if the `inputs` passed to `call` is a list.
    use_compression:
      Default False. Set True to enable data compression techniques for
      optimization of network traffic for distributed training.
    use_binary_sparse_dense_matmul:
      If the binary sparse dense matmul op is to be used. It will only be enabled if
      `use_binary_values` is set true. It should only be used for inference; best practice is
      to set `use_binary_sparse_dense_matmul = not is_training`.
  """

  def __init__(self,
               output_size,
               input_size=None,
               weight_initializer=None,
               activation=None,
               bias_initializer=None,
               trainable=True,
               name=None,
               use_sparse_grads=True,
               num_partitions=None,
               partition_axis=0,
               use_binary_values=False,
               bias_regularizer=None,
               weight_regularizer=None,
               use_compression=False,
               use_binary_sparse_dense_matmul=False,
               **kwargs):
    super(FullSparse, self).__init__(trainable=trainable, name=name, **kwargs)
    # TODO - remove input_size warning.
    if input_size:
      raise ValueError('input_size is deprecated - it is now automatically \
                        inferred from your input.')

    # The bias initialization and weights initialization is set to match v1's implementation.
    if bias_initializer is None:
      bias_initializer = tf.constant_initializer(1 / output_size)
    # Weights initialization is set to 0s. This is safe for full sparse layers because
    # you are supposed to learn your embedding from the label.
    if weight_initializer is None:
      weight_initializer = tf.zeros_initializer()
    self.weight_initializer = weight_initializer
    self.bias_initializer = bias_initializer
    self.output_size = output_size
    self.activation = activation
    self.use_sparse_grads = use_sparse_grads
    self.num_partitions = num_partitions
    if partition_axis != 0 and partition_axis != 1:
      raise ValueError('partition_axis must be 0 or 1')
    self.partition_axis = partition_axis
    self.use_binary_values = use_binary_values
    self.weight_regularizer = weight_regularizer
    self.bias_regularizer = bias_regularizer
    self._use_compression = use_compression
    self._cast_indices_dtype = tf.int32 if self._use_compression else None
    self.use_binary_sparse_dense_matmul = use_binary_sparse_dense_matmul

  def _make_weight_var(self, shape, partitioner):
    self.weight = self.add_variable(
        'weight',
        initializer=self.weight_initializer,
        regularizer=self.weight_regularizer,
        shape=shape,
        dtype=self.dtype,
        trainable=True,
        partitioner=partitioner,
    )

  def build(self, input_shapes):
    """
    creates the ``bias`` and ``weight`` Variables
    of shape ``[output_size]`` and ``[input_size, output_size]`` respectively.
    """

    if isinstance(input_shapes, (list, tuple)):
      input_shape = input_shapes[0]
      is_compatible = True
      for other_shape in input_shapes[1:]:
        is_compatible &= input_shape.is_compatible_with(other_shape)
      if not is_compatible:
        raise ValueError("Input shapes %s are not compatible." % input_shapes)
    else:
      input_shape = input_shapes

    self.bias = self.add_variable(
        'bias',
        initializer=self.bias_initializer,
        regularizer=self.bias_regularizer,
        shape=[self.output_size, ],
        dtype=self.dtype,
        trainable=True
    )

    partitioner = None
    shape = [input_shape[1], self.output_size]

    # There is a 2gb limitation for each tensor because of protobuf.
    # 2**30 is 1GB. 2 * (2**30) is 2GB.
    dtype = tf.as_dtype(self.dtype)
    num_partitions = 1 if self.num_partitions is None else self.num_partitions
    in_shape = input_shape[1]
    out_shape = self.output_size

    # when v2 behavior is disabled, in_shape is tf.Dimension. otherwise it is int.
    if isinstance(in_shape, tf.Dimension):
      in_shape = in_shape.value

    if in_shape is None:
      raise ValueError("Input tensor should have shape."
                       " You can set it using twml.util.limit_sparse_tensor_size")

    (split_dim, other_dim) = (in_shape, out_shape) if self.partition_axis == 0 else (out_shape, in_shape)
    requested_size = math.ceil(float(split_dim) / num_partitions) * other_dim * dtype.size
    if (requested_size >= 2**31):
      raise ValueError("Weight tensor partitions cannot be larger than 2GB.\n"
                       "Requested Dimensions(%d, %d) of type %s (%d bytes total) over %d partitions.\n"
                       "Possible solutions:\n"
                       "- reduce the params.output_size_bits\n"
                       "- reduce the output_size of the sparse_layer\n"
                       "- specify a larger num_partitions argument\n"
                       "- reduce input_size_bits" %
                       (in_shape, self.output_size, dtype.name, requested_size, num_partitions))

    if self.num_partitions:
      partition_axis = int(self.partition_axis)
      partitioner = tf.fixed_size_partitioner(self.num_partitions, axis=partition_axis)
    else:
      # Regular variables do not like it when you pass both constant tensors and shape
      if not callable(self.weight_initializer):
        shape = None

    self._make_weight_var(shape, partitioner)

    self.built = True

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError

  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """The logic of the layer lives here.

    Arguments:
      inputs:
        A SparseTensor or a list of SparseTensors.
        If `inputs` is a list, all tensors must have the same `dense_shape`.

    Returns:
      - If `inputs` is a `SparseTensor`, then returns `bias + inputs * dense_b`.
      - If `inputs` is a `list[SparseTensor]`, then returns
        `bias + add_n([sp_a * dense_b for sp_a in inputs])`.

    """
    if isinstance(inputs, (list, tuple)):

      if isinstance(self.use_binary_values, (list, tuple)):
        use_binary_values = self.use_binary_values
      else:
        use_binary_values = [self.use_binary_values] * len(inputs)

      num_inputs = len(inputs)
      if num_inputs != len(use_binary_values):
        raise ValueError("#inputs is %d while #use_binary_values is %d"
                         % (num_inputs, len(use_binary_values)))

      outputs = []
      for n in range(num_inputs):
        outputs.append(sparse_dense_matmul(inputs[n], self.weight,
                                           self.use_sparse_grads,
                                           use_binary_values[n],
                                           name='sparse_mm_' + str(n),
                                           partition_axis=self.partition_axis,
                                           num_partitions=self.num_partitions,
                                           compress_ids=self._use_compression,
                                           cast_indices_dtype=self._cast_indices_dtype,
                                           use_binary_sparse_dense_matmul=self.use_binary_sparse_dense_matmul))
      outputs = tf.accumulate_n(outputs)
    else:

      if isinstance(self.use_binary_values, (list, tuple)):
        raise ValueError("use_binary_values can not be %s when inputs is %s" %
                         (type(self.use_binary_values), type(inputs)))

      outputs = sparse_dense_matmul(inputs, self.weight,
                                    self.use_sparse_grads,
                                    self.use_binary_values,
                                    name='sparse_mm',
                                    partition_axis=self.partition_axis,
                                    num_partitions=self.num_partitions,
                                    compress_ids=self._use_compression,
                                    cast_indices_dtype=self._cast_indices_dtype,
                                    use_binary_sparse_dense_matmul=self.use_binary_sparse_dense_matmul)

    if self.bias is not None:
      outputs = tf.nn.bias_add(outputs, self.bias)

    if self.activation is not None:
      return self.activation(outputs)  # pylint: disable=not-callable
    return outputs


def full_sparse(
    inputs, output_size,
    input_size=None,
    activation=None,
    bias_regularizer=None,
    weight_regularizer=None,
    bias_initializer=None,
    weight_initializer=None,
    trainable=True,
    name=None,
    reuse=None,
    use_sparse_grads=True,
    num_partitions=None,
    partition_axis=0,
    use_binary_values=False,
    use_compression=False):
  """Functional interface for the sparsely-connected layer.

  Arguments:
    inputs:
      A sparse tensor (can be twml.SparseTensor or tf.SparseTensor)
    output_size:
      Long or Integer, dimensionality of the output space.
    weight_initializer:
      Initializer function for the weight matrix.
    activation:
      Activation function (callable). Set it to None to maintain a linear activation.
    bias_initializer:
      Initializer function for the bias.
    weight_regularizer:
      Regularizer function for the weight matrix.
      Make sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    bias_regularizer:
      Regularizer function for the bias.
      Make sure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
    trainable:
      Boolean, if `True` also add variables to the graph collection
      ``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable
      <https://www.tensorflow.org/versions/master/api_docs/python/tf/Variable>`_).
    name:
      String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require ``reuse=True`` in such cases.
    use_sparse_grads:
      Boolean, if `True` do sparse mat mul with `embedding_lookup_sparse`, which will
      make gradients to the weight matrix also sparse in the backward pass. This can lead to
      a non-trivial speed up at training time when input_size is large and the optimizer
      handles sparse gradients correctly (e.g. with SGD or LazyAdamOptimizer). If the weight
      matrix is small, it is recommended to set this flag to `False`; for most use cases of
      FullSparse, however, the weight matrix will be large, so it is better to set it to `True`.
    num_partitions:
      Number of partitions to use for the weight variable. Defaults to 1.
    partition_axis:
      If num_partitions is specified, the partition axis for the weight variable.
      Defaults to 0 (partition by row).
      Must be 0 (row) or 1 (column).
    use_binary_values:
      Assume all non-zero values are 1. Defaults to False.
      This can improve training if used in conjunction with MDL.
    use_compression:
      Default False. Set True to enable data compression techniques for
      optimization of network traffic for distributed training.
  Returns:
    Outputs a ``tf.Tensor`` of size ``[batch_size x output_size]``.
  """
  # TODO - remove input_size warning.
  if input_size:
    raise ValueError('input_size is deprecated - it is now \
                      automatically inferred from your input.')

  dtype = None
  if isinstance(inputs, twml.SparseTensor):
    inputs = inputs.to_tf()
    dtype = inputs.dtype.base_dtype

  if isinstance(inputs, (list, tuple)):
    inputs = [inp.to_tf() if isinstance(inp, twml.SparseTensor) else inp for inp in inputs]
    dtype = inputs[0].dtype.base_dtype

  layer = FullSparse(output_size=output_size,
                     activation=activation,
                     trainable=trainable,
                     name=name,
                     weight_initializer=weight_initializer,
                     bias_initializer=bias_initializer,
                     weight_regularizer=weight_regularizer,
                     bias_regularizer=bias_regularizer,
                     dtype=dtype,
                     _scope=name,
                     _reuse=reuse,
                     use_sparse_grads=use_sparse_grads,
                     num_partitions=num_partitions,
                     partition_axis=partition_axis,
                     use_compression=use_compression,
                     use_binary_values=use_binary_values)
  return layer(inputs)
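Usage sketch (illustrative, not part of the repository files): feeding a sparse input through `full_sparse`. The feature-space size and indices below are made up; in real pipelines the feature dimension is usually made statically known (e.g. via twml.util.limit_sparse_tensor_size, as the error message above mentions), and twml/libtwml must be built.

# Minimal sketch of full_sparse with a tf.SparseTensor input (TF 1.x graph mode).
import tensorflow.compat.v1 as tf
from twml.layers import full_sparse

# a [2, 2**18] sparse input with three non-zero entries; the constant dense_shape
# gives the layer a statically known feature dimension
sp = tf.SparseTensor(indices=[[0, 7], [0, 42], [1, 99]],
                     values=[1.0, 0.5, 2.0],
                     dense_shape=[2, 1 << 18])
out = full_sparse(sp, output_size=50, use_sparse_grads=True)  # dense [2, 50] output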
twml/twml/layers/isotonic.py (new file, 76 lines)
@@ -0,0 +1,76 @@
# pylint: disable=no-member, invalid-name, attribute-defined-outside-init
"""
Contains the Isotonic Layer
"""

from .layer import Layer

import libtwml
import numpy as np


class Isotonic(Layer):
  """
  This layer is created by the IsotonicCalibrator.
  Typically it is used instead of a sigmoid activation on the output unit.

  Arguments:
    n_unit:
      number of input units to the layer (same as number of output units).
    n_bin:
      number of bins used for isotonic calibration.
      More bins means a more precise isotonic function.
      Fewer bins means a more regularized isotonic function.
    xs_input:
      A tensor containing the boundaries of the bins.
    ys_input:
      A tensor containing calibrated values for the corresponding bins.

  Output:
    output:
      A layer containing calibrated probabilities with the same shape and size as the input.
  Expected Sizes:
    xs_input, ys_input:
      [n_unit, n_bin].
  Expected Types:
    xs_input, ys_input:
      same as input.
  """

  def __init__(self, n_unit, n_bin, xs_input=None, ys_input=None, **kwargs):
    super(Isotonic, self).__init__(**kwargs)

    self._n_unit = n_unit
    self._n_bin = n_bin

    self.xs_input = np.empty([n_unit, n_bin], dtype=np.float32) if xs_input is None else xs_input
    self.ys_input = np.empty([n_unit, n_bin], dtype=np.float32) if ys_input is None else ys_input

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError

  def build(self, input_shape):  # pylint: disable=unused-argument
    """Creates the variables of the layer."""

    self.built = True

  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """The logic of the layer lives here.

    Arguments:
      inputs: input tensor(s).

    Returns:
      The output from the layer
    """
    calibrate_op = libtwml.ops.isotonic_calibration(inputs, self.xs_input, self.ys_input)
    return calibrate_op
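Usage sketch (illustrative, not part of the repository files): constructing the layer from calibrator output. The bin boundaries and calibrated values below are made up, and applying the layer requires a libtwml build.

# Hypothetical sketch: an Isotonic layer for a single output unit with 4 bins.
import numpy as np
from twml.layers import Isotonic

n_unit, n_bin = 1, 4
xs = np.array([[0.0, 0.25, 0.5, 0.75]], dtype=np.float32)  # bin boundaries per unit
ys = np.array([[0.1, 0.3, 0.6, 0.9]], dtype=np.float32)    # calibrated value per bin
calibrate = Isotonic(n_unit, n_bin, xs_input=xs, ys_input=ys)
# probabilities = calibrate(model_scores)  # same shape as the input scores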
twml/twml/layers/layer.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# pylint: disable=no-member
"""
Implementing a base layer for twml
"""
import tensorflow.compat.v1 as tf
from tensorflow.python.layers import base


class Layer(base.Layer):
  """
  Base Layer implementation for twml.
  Overloads `tf.layers.Layer
  <https://www.tensorflow.org/versions/master/api_docs/python/tf/layers/Layer>`_
  from tensorflow and adds a couple of custom methods.
  """

  @property
  def init(self):
    """
    Return initializer ops. By default returns tf.no_op().
    This method is overridden by classes like twml.layers.MDL, which
    use a HashTable internally that must be initialized with its own op.
    """
    return tf.no_op()

  def call(self, inputs, **kwargs):
    """The logic of the layer lives here.

    Arguments:
      inputs:
        input tensor(s).
      **kwargs:
        additional keyword arguments.

    Returns:
      Output tensor(s).
    """
    raise NotImplementedError

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError
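Usage sketch (illustrative, not part of the repository files): the subclass contract this base class defines, shown with a toy layer that has no variables. Anything beyond the `call`/`build` override pattern is made up.

# Minimal sketch of subclassing twml.layers.Layer: implement call(), and build()
# if the layer owns variables; compute_output_shape stays NotImplementedError here.
import tensorflow.compat.v1 as tf
from twml.layers import Layer

class Scale(Layer):
  def __init__(self, factor, **kwargs):
    super(Scale, self).__init__(**kwargs)
    self._factor = factor

  def build(self, input_shape):  # no variables needed for this toy layer
    self.built = True

  def call(self, inputs, **kwargs):
    return inputs * self._factor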
twml/twml/layers/mdl.py (new file, 256 lines)
@@ -0,0 +1,256 @@
# pylint: disable=no-member, attribute-defined-outside-init, too-many-instance-attributes
"""
Implementing MDL Layer
"""


from .layer import Layer
from .partition import Partition
from .stitch import Stitch

import libtwml
import numpy as np
import tensorflow.compat.v1 as tf
import twml


class MDL(Layer):  # noqa: T000
  """
  MDL layer is constructed by MDLCalibrator after accumulating data
  and performing minimum description length (MDL) calibration.

  MDL takes sparse continuous features and converts them to sparse
  binary features. Each binary output feature is associated to an MDL bin.
  Each MDL input feature is converted to n_bin bins.
  Each MDL calibration tries to find bin delimiters such that the number of feature values
  per bin is roughly equal (for each given MDL feature).
  Note that if an input feature is rarely used, so will its associated output bin/features.
  """

  def __init__(
          self,
          n_feature, n_bin, out_bits,
          bin_values=None, hash_keys=None, hash_values=None,
          bin_ids=None, feature_offsets=None, **kwargs):
    """
    Creates a non-initialized `MDL` object.
    Before using the table you will have to initialize it. After initialization
    the table will be immutable.

    Parent class args:
      see [tf.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/layers/Layer)
      for documentation of parent class arguments.

    Required args:
      n_feature:
        number of unique features accumulated during MDL calibration.
        This is the number of features in the hash map.
        Used to initialize bin_values, hash_keys, hash_values,
        bin_ids, bin_values and feature_offsets.
      n_bin:
        number of MDL bins used for MDL calibration.
        Used to initialize bin_values, hash_keys, hash_values,
        bin_ids, bin_values and feature_offsets.
      out_bits:
        Determines the maximum value for output feature IDs.
        The dense_shape of the SparseTensor returned by lookup(x)
        will be [x.shape[0], 1 << output_bits].

    Optional args:
      hash_keys:
        contains the feature IDs that MDL discretizes and knows about.
        The hash map (hash_keys->hash_values) is used for two reasons:
        1. divide inputs into two feature spaces: MDL vs non-MDL
        2. translate the MDL features into a hash_feature ID that MDL understands.
        The hash_map is expected to contain n_feature items.
      hash_values:
        translates the feature IDs into hash_feature IDs for MDL.
      bin_ids:
        a 1D Tensor of size n_feature * n_bin + 1 which contains
        unique IDs to which the MDL features will be translated to.
        For example, tf.Tensor(np.arange(n_feature * n_bin)) would produce
        the most efficient output space.
      bin_values:
        a 1D Tensor aligned with bin_ids.
        For a given hash_feature ID j, its value bins are indexed between
        `j*n_bin` and `j*n_bin + n_bin-1`.
        As such, bin_ids[j*n_bin+i] is translated from a hash_feature ID of j
        and an input value between
        `bin_values[j*n_bin + i]` and `bin_values[j*n_bin+i+1]`.
      feature_offsets:
        a 1D Tensor specifying the starting location of bins for a given feature id.
        For example, tf.Tensor(np.arange(0, bin_values.size, n_bin, dtype='int64')).
    """
    super(MDL, self).__init__(**kwargs)
    tf.logging.warning("MDL will be deprecated. Please use PercentileDiscretizer instead")

    max_mdl_feature = n_feature * (n_bin + 1)
    self._n_feature = n_feature
    self._n_bin = n_bin

    self._hash_keys_initializer = tf.constant_initializer(
        hash_keys if hash_keys is not None
        else np.empty(n_feature, dtype=np.int64),
        dtype=np.int64
    )
    self._hash_values_initializer = tf.constant_initializer(
        hash_values if hash_values is not None
        else np.empty(n_feature, dtype=np.int64),
        dtype=np.int64
    )
    self._bin_ids_initializer = tf.constant_initializer(
        bin_ids if bin_ids is not None
        else np.empty(max_mdl_feature, dtype=np.int64),
        dtype=np.int64
    )
    self._bin_values_initializer = tf.constant_initializer(
        bin_values if bin_values is not None
        else np.empty(max_mdl_feature, dtype=np.float32),
        dtype=np.float32
    )
    self._feature_offsets_initializer = tf.constant_initializer(
        feature_offsets if feature_offsets is not None
        else np.empty(n_feature, dtype=np.int64),
        dtype=np.int64
    )

    # note that calling build here is an exception as typically __call__ would call build().
    # We call it here because we need to initialize hash_map.
    # Also note that the variable_scope is set by add_variable in build()
    if not self.built:
      self.build(input_shape=None)

    self.output_size = tf.convert_to_tensor(1 << out_bits, tf.int64)

  def build(self, input_shape):  # pylint: disable=unused-argument
    """
    Creates the variables of the layer:
    hash_keys, hash_values, bin_ids, bin_values, feature_offsets and self.output_size.
    """

    # build layers
    self.partition = Partition()
    self.stitch = Stitch()

    # build variables

    hash_keys = self.add_variable(
        'hash_keys',
        initializer=self._hash_keys_initializer,
        shape=[self._n_feature],
        dtype=tf.int64,
        trainable=False)

    hash_values = self.add_variable(
        'hash_values',
        initializer=self._hash_values_initializer,
        shape=[self._n_feature],
        dtype=tf.int64,
        trainable=False)

    # hashmap converts known features into range [0, n_feature)
    initializer = tf.lookup.KeyValueTensorInitializer(hash_keys, hash_values)
    self.hash_map = tf.lookup.StaticHashTable(initializer, -1)

    self.bin_ids = self.add_variable(
        'bin_ids',
        initializer=self._bin_ids_initializer,
        shape=[self._n_feature * (self._n_bin + 1)],
        dtype=tf.int64,
        trainable=False)

    self.bin_values = self.add_variable(
        'bin_values',
        initializer=self._bin_values_initializer,
        shape=[self._n_feature * (self._n_bin + 1)],
        dtype=tf.float32,
        trainable=False)

    self.feature_offsets = self.add_variable(
        'feature_offsets',
        initializer=self._feature_offsets_initializer,
        shape=[self._n_feature],
        dtype=tf.int64,
        trainable=False)

    # make sure this is last
    self.built = True

  def call(self, inputs, **kwargs):
    """Looks up `keys` in a table, outputs the corresponding values.

    Implements MDL inference where inputs are intersected with a hash_map.
    Part of the inputs are discretized using twml.mdl to produce an mdl_output SparseTensor.
    This SparseTensor is then joined with the original inputs SparseTensor,
    but only for the inputs keys that did not get discretized.

    Args:
      inputs: A 2D SparseTensor that is input to MDL for discretization.
        It has a dense_shape of [batch_size, input_size]
      name: A name for the operation (optional).
    Returns:
      A `SparseTensor` of the same type as `inputs`.
      Its dense_shape is [shape_input.dense_shape[0], 1 << output_bits].
    """
    if isinstance(inputs, tf.SparseTensor):
      inputs = twml.SparseTensor.from_tf(inputs)

    assert(isinstance(inputs, twml.SparseTensor))

    # sparse column indices
    ids = inputs.ids
    # sparse row indices
    keys = inputs.indices
    # sparse values
    vals = inputs.values

    # get intersect(keys, hash_map)
    hashed_keys = self.hash_map.lookup(keys)

    found = tf.not_equal(hashed_keys, tf.constant(-1, tf.int64))
    partition_ids = tf.cast(found, tf.int32)

    vals, key, indices = self.partition(partition_ids, vals, tf.where(found, hashed_keys, keys))
    non_mdl_keys, mdl_in_keys = key
    non_mdl_vals, mdl_in_vals = vals

    self.non_mdl_keys = non_mdl_keys

    # run MDL on the keys/values it knows about
    mdl_keys, mdl_vals = libtwml.ops.mdl(mdl_in_keys, mdl_in_vals, self.bin_ids, self.bin_values,
                                         self.feature_offsets)

    # handle output ID conflicts
    mdl_size = tf.size(self.bin_ids, out_type=tf.int64)
    non_mdl_size = tf.subtract(self.output_size, mdl_size)
    non_mdl_keys = tf.add(tf.floormod(non_mdl_keys, non_mdl_size), mdl_size)

    # Stitch the keys and values from mdl and non mdl indices back, with help
    # of the Stitch Layer

    # out for inference checking
    self.mdl_out_keys = mdl_keys

    concat_data = self.stitch([non_mdl_vals, mdl_vals],
                              [non_mdl_keys, mdl_keys],
                              indices)

    concat_vals, concat_keys = concat_data

    # Generate output shape using _compute_output_shape

    batch_size = tf.to_int64(inputs.dense_shape[0])
    output_shape = [batch_size, self.output_size]
    return twml.SparseTensor(ids, concat_keys, concat_vals, output_shape).to_tf()

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError
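Illustration (not part of the repository files): a rough numpy approximation of the per-feature bucketization the docstring describes, i.e. mapping a (hash_feature ID, value) pair to a binary output feature ID via the bin boundaries. This is an assumption about the semantics and is NOT the libtwml.ops.mdl kernel.

# Rough, hypothetical sketch of the bin lookup described in the MDL docstring.
import numpy as np

def bucketize(hash_feature_id, value, bin_values, bin_ids, n_bin):
  # bins for feature j live in bin_values[j*n_bin : (j+1)*n_bin] (assumed layout)
  start = hash_feature_id * n_bin
  boundaries = bin_values[start:start + n_bin]
  i = max(int(np.searchsorted(boundaries, value, side='right')) - 1, 0)
  return bin_ids[start + i]  # binary output feature ID for this (feature, value)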
twml/twml/layers/partition.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""
Implementing partition Layer
"""


from .layer import Layer

import tensorflow.compat.v1 as tf


class Partition(Layer):
  """
  This layer implements:

  .. code-block:: python

    tf.dynamic_partition(input_vals, partition_ids, self.partitions)

  Input:
    partitions:
      the number of partitions into which the hashmap keys/values will be divided

  Output:
    A layer that performs partitioning
  """

  def __init__(self, partitions=2, **kwargs):
    self.partitions = partitions
    super(Partition, self).__init__(**kwargs)

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError

  def call(self, partition_ids, input_vals, input_keys, **kwargs):
    """This layer is responsible for partitioning the values/keys of a hashmap

    Arguments:
      partition_ids:
        Tensor that is equivalent to boolean (int32).
      input_vals:
        Tensor that represents the values of the hashmap (float).
      input_keys:
        Tensor that represents the keys of the hashmap (float).

    Returns:
      The output of the partition layer, which is a list of lists which looks
      something like:

      .. code-block:: python

        [[vals_0, vals_1], [keys_0, keys_1], [indices_0, indices_1]]

      where:
        vals_x:
          values of the hashmap for partition x
        keys_x:
          keys of the hashmap for partition x
        indices_x:
          indices of the hashmap for partition x
    """
    partitioned_val = tf.dynamic_partition(input_vals, partition_ids, self.partitions)
    partitioned_keys = tf.dynamic_partition(input_keys, partition_ids, self.partitions)
    partitioned_indices = tf.dynamic_partition(tf.range(tf.shape(partition_ids)[0]),
                                               tf.cast(partition_ids, tf.int32), self.partitions)
    return [partitioned_val, partitioned_keys, partitioned_indices]
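Usage sketch (illustrative, not part of the repository files): the two-way split the layer performs, exactly the way MDL.call above invokes it. The values are made up.

# Small sketch of the tf.dynamic_partition behaviour this layer wraps (TF 1.x).
import tensorflow.compat.v1 as tf
from twml.layers import Partition

partition_ids = tf.constant([0, 1, 0, 1], dtype=tf.int32)  # 0 = unknown key, 1 = known key
vals = tf.constant([0.5, 1.5, 2.5, 3.5])
keys = tf.constant([11, 22, 33, 44], dtype=tf.int64)
layer = Partition()  # partitions=2
(vals_0, vals_1), (keys_0, keys_1), (idx_0, idx_1) = layer(partition_ids, vals, keys)
# vals_0/keys_0 hold entries where partition_ids == 0; vals_1/keys_1 where it == 1;
# idx_0/idx_1 remember the original positions so Stitch can merge the results back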
209
twml/twml/layers/percentile_discretizer.py
Normal file
209
twml/twml/layers/percentile_discretizer.py
Normal file
@ -0,0 +1,209 @@
|
||||
# pylint: disable=no-member, attribute-defined-outside-init, too-many-instance-attributes
|
||||
"""
|
||||
Implementing PercentileDiscretizer Layer
|
||||
"""
|
||||
|
||||
|
||||
import libtwml
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import twml
|
||||
from twml.layers import Layer
|
||||
|
||||
|
||||
class PercentileDiscretizer(Layer):
|
||||
"""
|
||||
PercentileDiscretizer layer is constructed by PercentileDiscretizerCalibrator after
|
||||
accumulating data and performing percentile bucket calibration.
|
||||
|
||||
PercentileDiscretizer takes sparse continuous features and converts then to sparse
|
||||
binary features. Each binary output feature is associated to an PercentileDiscretizer bin.
|
||||
Each PercentileDiscretizer input feature is converted to n_bin bins.
|
||||
Each PercentileDiscretizer calibration tries to find bin delimiters such
|
||||
that the number of features values per bin is roughly equal (for
|
||||
each given PercentileDiscretizer feature). In other words, bins are calibrated to be approx.
|
||||
equiprobable, according to the given calibration data.
|
||||
Note that if an input feature is rarely used, so will its associated output bin/features.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_feature, n_bin, out_bits,
|
||||
bin_values=None, hash_keys=None, hash_values=None,
|
||||
bin_ids=None, feature_offsets=None, num_parts=1, cost_per_unit=100, **kwargs):
|
||||
"""
|
||||
Creates a non-initialized `PercentileDiscretizer` object.
|
||||
Before using the table you will have to initialize it. After initialization
|
||||
the table will be immutable.
|
||||
|
||||
If there are no calibrated features, then the discretizer will only apply
|
||||
twml.util.limit_bits to the the feature keys (aka "feature_ids"). Essentially,
|
||||
the discretizer will be a "no-operation", other than obeying `out_bits`
|
||||
|
||||
Parent class args:
|
||||
see [tf.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/layers/Layer)
|
||||
for documentation of parent class arguments.
|
||||
|
||||
Required args:
|
||||
n_feature:
|
||||
number of unique features accumulated during PercentileDiscretizer calibration.
|
||||
This is the number of features in the hash map.
|
||||
Used to initialize bin_values, hash_keys, hash_values,
|
||||
bin_ids, bin_values and feature_offsets.
|
||||
n_bin:
|
||||
number of PercentileDiscretizer bins used for PercentileDiscretizer calibration.
|
||||
Used to initialize bin_values, hash_keys, hash_values,
|
||||
bin_ids, bin_values and feature_offsets.
|
||||
out_bits:
|
||||
Determines the maximum value for output feature IDs.
|
||||
The dense_shape of the SparseTensor returned by lookup(x)
|
||||
will be [x.shape[0], 1 << output_bits].
|
||||
|
||||
Optional args:
|
||||
hash_keys:
|
||||
contains the features ID that PercentileDiscretizer discretizes and knows about.
|
||||
The hash map (hash_keys->hash_values) is used for two reasons:
|
||||
1. divide inputs into two feature spaces:
|
||||
PercentileDiscretizer vs non-PercentileDiscretizer
|
||||
2. transate the PercentileDiscretizer features into a hash_feature ID that
|
||||
PercentileDiscretizer understands.
|
||||
The hash_map is expected to contain n_feature items.
|
||||
hash_values:
|
||||
translates the feature IDs into hash_feature IDs for PercentileDiscretizer.
|
||||
bin_ids:
|
||||
a 1D Tensor of size n_feature * n_bin + 1 which contains
|
||||
unique IDs to which the PercentileDiscretizer features will be translated to.
|
||||
For example, tf.Tensor(np.arange(n_feature * n_bin)) would produce
|
||||
the most efficient output space.
|
||||
bin_values:
|
||||
a 1D Tensor aligned with bin_ids.
|
||||
For a given hash_feature ID j, it's value bin's are indexed between
|
||||
`j*n_bin` and `j*n_bin + n_bin-1`.
|
||||
As such, bin_ids[j*n_bin+i] is translated from a hash_feature ID of j
|
||||
and a inputs value between
|
||||
`bin_values[j*n_bin + i]` and `bin_values[j*n_bin+i+1]`.
|
||||
feature_offsets:
|
||||
a 1D Tensor specifying the starting location of bins for a given feature id.
|
||||
For example, tf.Tensor(np.arange(0, bin_values.size, n_bin, dtype='int64')).
|
||||
"""
|
||||
|
||||
super(PercentileDiscretizer, self).__init__(**kwargs)
|
||||
|
||||
if not self.built:
|
||||
self.build(input_shape=None)
|
||||
|
||||
max_discretizer_feature = n_feature * (n_bin + 1)
|
||||
self._n_feature = n_feature
|
||||
self._n_bin = n_bin
|
||||
|
||||
# build variables
|
||||
self._out_bits = out_bits
|
||||
self._output_size = tf.convert_to_tensor(1 << out_bits, tf.int64)
|
||||
self._hash_keys = (hash_keys if hash_keys is not None else
|
||||
np.empty(n_feature, dtype=np.int64))
|
||||
self._hash_values = (hash_values if hash_values is not None else
|
||||
np.empty(n_feature, dtype=np.int64))
|
||||
self._bin_ids = (bin_ids if bin_ids is not None else
|
||||
np.empty(max_discretizer_feature, dtype=np.int64))
|
||||
self._bin_values = (bin_values if bin_values is not None else
|
||||
np.empty(max_discretizer_feature, dtype=np.float32))
|
||||
self._feature_offsets = (feature_offsets if feature_offsets is not None else
|
||||
np.empty(n_feature, dtype=np.int64))
|
||||
self.num_parts = num_parts
|
||||
self.cost_per_unit = cost_per_unit
|
||||
|
||||
def build(self, input_shape): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates the variables of the layer
|
||||
"""
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs, keep_inputs=False, **kwargs):
|
||||
"""Looks up `keys` in a table, outputs the corresponding values.
|
||||
|
||||
Implements PercentileDiscretizer inference where inputs are intersected with a hash_map.
|
||||
Input features that were not calibrated have their feature IDs truncated, so as
|
||||
to be less than 1<<output_bits, but their values remain untouched (not discretized)
|
||||
|
||||
If there are no calibrated features, then the discretizer will only apply
|
||||
twml.util.limit_bits to the the feature keys (aka "feature_ids"). Essentially,
|
||||
the discretizer will be a "no-operation", other than obeying `out_bits`
|
||||
|
||||
Args:
|
||||
inputs: A 2D SparseTensor that is input to PercentileDiscretizer for discretization.
|
||||
It has a dense_shape of [batch_size, input_size]
|
||||
keep_inputs:
|
||||
Include the original inputs in the output.
|
||||
Note - if True, undiscretized features will be passed through, but will have
|
||||
their values doubled (unless there are no calibrated features to discretize).
|
||||
name: A name for the operation (optional).
|
||||
Returns:
|
||||
A `SparseTensor` of the same type as `inputs`.
|
||||
Its dense_shape is [shape_input.dense_shape[0], 1 << output_bits].
|
||||
"""
|
||||
|
||||
    if isinstance(inputs, tf.SparseTensor):
      inputs = twml.SparseTensor.from_tf(inputs)

    assert isinstance(inputs, twml.SparseTensor)

    # sparse column indices
    ids = inputs.ids
    # sparse row indices
    keys = inputs.indices
    # sparse values
    vals = inputs.values

    if self._n_feature > 0:
      discretizer_keys, discretizer_vals = libtwml.ops.percentile_discretizer_v2(
          input_ids=keys,  # inc key assigned to feature_id, or -1
          input_vals=vals,  # the observed feature values
          bin_ids=self._bin_ids,  # n_feat X (n_bin+1) 2D arange
          bin_vals=self._bin_values,  # bin boundaries
          feature_offsets=self._feature_offsets,  # 0 : nbin_1 : max_feat
          output_bits=self._out_bits,
          feature_ids=tf.make_tensor_proto(self._hash_keys),  # feature ids to build internal hash map
          feature_indices=tf.make_tensor_proto(self._hash_values),  # keys associated w/ feat. indices
          start_compute=tf.constant(0, shape=[], dtype=tf.int64),
          end_compute=tf.constant(-1, shape=[], dtype=tf.int64),
          cost_per_unit=self.cost_per_unit
      )
    else:
      discretizer_keys = twml.util.limit_bits(keys, self._out_bits)
      discretizer_vals = vals
      # don't 2x the input.
      keep_inputs = False

    batch_size = tf.to_int64(inputs.dense_shape[0])
    output_shape = [batch_size, self._output_size]

    output = twml.SparseTensor(ids, discretizer_keys, discretizer_vals, output_shape).to_tf()

    if keep_inputs:
      # Note the non-discretized features will end up doubled,
      # since these are already in `output`
      # handle output ID conflicts
      mdl_size = self._n_feature * (self._n_bin + 1)
      non_mdl_size = tf.subtract(self._output_size, mdl_size)
      input_keys = tf.add(tf.floormod(keys, non_mdl_size), mdl_size)

      new_input = twml.SparseTensor(
          ids=ids, indices=input_keys, values=vals, dense_shape=output_shape).to_tf()

      # concatenate discretizer output with original input
      sparse_add = tf.sparse_add(new_input, output)
      output = tf.SparseTensor(sparse_add.indices, sparse_add.values, output_shape)

    return output

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError
160
twml/twml/layers/sequential.py
Normal file
@ -0,0 +1,160 @@
"""
Implementing Sequential Layer container
"""


from .layer import Layer

from tensorflow import keras
from tensorflow.python.layers import base


class Sequential(Layer):
  """
  A sequential stack of layers.

  Arguments:
    layers: list of layers to add to the model.

  Output:
    the output of the sequential layers
  """

  def __init__(self, layers=None, **kwargs):
    self._layers = []  # Stack of layers.
    self._layer_names = []  # Stack of layer names
    self._layer_outputs = []
    # Add to the model any layers passed to the constructor.
    if layers:
      for layer in layers:
        self.add(layer)
    super(Sequential, self).__init__(**kwargs)

  def add(self, layer):
    """Adds a layer instance on top of the layer stack.

    Arguments:
      layer:
        layer instance.

    Raises:
      TypeError:
        if the layer argument is not an instance of base.Layer or keras.layers.Layer.
    """
    if not isinstance(layer, base.Layer) and not isinstance(layer, keras.layers.Layer):
      raise TypeError('The added layer must be an instance of class Layer')

    if layer.name in self._layer_names:
      raise ValueError('Layer with name %s already exists in sequential layer' % layer.name)

    self._layers.append(layer)
    self._layer_names.append(layer.name)

  def pop(self):
    """Removes the last layer in the model.

    Raises:
      TypeError:
        if there are no layers in the model.
    """
    if not self._layers or not self._layer_names:
      raise TypeError('There are no layers in the model.')
    self._layers.pop()
    self._layer_names.pop()

  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """The logic of the layer lives here.

    Arguments:
      inputs:
        input tensor(s).

    Returns:
      The output of the sequential layers
    """
    self._layer_outputs = []
    for layer in self._layers:
      # don't use layer.call because you want to build individual layers
      inputs = layer(inputs)  # overwrites the current input after it has been processed
      self._layer_outputs.append(inputs)
    return inputs

  @property
  def layers(self):
    """ Return the layers in the sequential layer """
    return self._layers

  @property
  def layer_names(self):
    """ Return the layer names in the sequential layer """
    return self._layer_names

  @property
  def layer_outputs(self):
    """ Return the layer outputs in the sequential layer """
    return self._layer_outputs

  def get(self, key):
    """Retrieves the n-th layer.

    Arguments:
      key:
        index of the layer

    Output:
      The n-th layer, where n is equal to the key.
    """
    return self._layers[key]

  def get_output(self, key):
    """Retrieves the n-th layer output.

    Arguments:
      key:
        index of the layer

    Output:
      The intermediary output of the n-th layer, where n is equal to the key.
    """
    return self._layer_outputs[key]

  def get_layer_by_name(self, name):
    """Retrieves the layer corresponding to the name.

    Arguments:
      name:
        name of the layer

    Output:
      the layer with the desired name
    """
    return self._layers[self._layer_names.index(name)]

  def get_layer_output_by_name(self, name):
    """Retrieves the layer output corresponding to the name.

    Arguments:
      name:
        name of the layer

    Output:
      the output of the layer with the desired name
    """
    return self._layer_outputs[self._layer_names.index(name)]

  @property
  def init(self):
    """ returns a list of initialization ops (one per layer) """
    return [layer.init for layer in self._layers]

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raise NotImplementedError.

    """
    raise NotImplementedError
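
A brief usage sketch of the container above. It assumes a working twml install and uses stock Keras layers (which `add` explicitly accepts); the layer names and sizes are illustrative, not part of the twml API.

import tensorflow as tf
from twml.layers import Sequential

# Stack two Keras dense layers; each intermediate output is cached in layer_outputs.
seq = Sequential([
    tf.keras.layers.Dense(16, activation='relu', name='hidden'),
    tf.keras.layers.Dense(1, name='logit'),
])

x = tf.ones([4, 8])                                   # toy batch: 4 examples, 8 features
logits = seq(x)                                       # runs the layers in order -> shape [4, 1]
hidden_out = seq.get_layer_output_by_name('hidden')   # intermediate activation of the first layer
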
221
twml/twml/layers/sparse_max_norm.py
Normal file
@ -0,0 +1,221 @@
# pylint: disable=no-member, attribute-defined-outside-init, duplicate-code
"""
Contains the twml.layers.SparseMaxNorm layer.
"""
from .layer import Layer

from libtwml import OPLIB
import tensorflow.compat.v1 as tf
import twml


class SparseMaxNorm(Layer):
  """
  Computes a max-normalization of the sparse_input and adds a bias to it.
  The resulting sparse tensor is typically forwarded through a sparse affine
  transform (e.g. twml.layers.FullSparse) followed by a non-linear activation
  on the resulting dense representation.

  This layer has two parameters, one of which learns through gradient descent:
    bias_x (optional):
      vector of shape [input_size]. Learned through gradient descent.
    max_x:
      vector of shape [input_size]. Holds the maxima of input ``x`` for normalization.
      Either calibrated through the SparseMaxNorm calibrator, or calibrated online, or both.

  The pseudo-code for this layer looks like:

  .. code-block:: python

    abs_x = abs(x)
    normed_x = clip_by_value(x / max_x, -1, 1)
    biased_x = normed_x + bias_x
    return biased_x


  Args:
    max_x_initializer:
      initializer vector of shape [input_size] used by variable `max_x`
    bias_x_initializer:
      initializer vector of shape [input_size] used by parameter `bias_x`
    is_training:
      Are we training the layer to learn the normalization maxima?
      If set to True, max_x will be able to learn. This is independent of bias_x.
    epsilon:
      The minimum value used for max_x. Defaults to 1E-5.
    use_bias:
      Default True. Set to False to not use a bias term.

  Returns:
    A layer representing the output of the sparse_max_norm transformation.
  """

  def __init__(
      self,
      input_size=None,
      max_x_initializer=None,
      bias_x_initializer=None,
      is_training=True,
      epsilon=1E-5,
      use_bias=True,
      **kwargs):

    super(SparseMaxNorm, self).__init__(**kwargs)
    if input_size:
      raise ValueError('input_size is deprecated - it is now automatically \
inferred from your input.')
    if max_x_initializer is None:
      max_x_initializer = tf.zeros_initializer()
    self.max_x_initializer = max_x_initializer

    self._use_bias = use_bias
    if use_bias:
      if bias_x_initializer is None:
        bias_x_initializer = tf.zeros_initializer()
      self.bias_x_initializer = bias_x_initializer

    self.epsilon = epsilon
    self.is_training = is_training

  def build(self, input_shape):  # pylint: disable=unused-argument
    """Creates the max_x and bias_x tf.Variables of the layer."""

    self.max_x = self.add_variable(
        'max_x',
        initializer=self.max_x_initializer,
        shape=[input_shape[1]],
        dtype=tf.float32,
        trainable=False)

    if self._use_bias:
      self.bias_x = self.add_variable(
          'bias_x',
          initializer=self.bias_x_initializer,
          shape=[input_shape[1]],
          dtype=tf.float32,
          trainable=True)

    self.built = True

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError

  def _call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """
    The forward propagation logic of the layer lives here.

    Arguments:
      inputs:
        A 2D ``tf.SparseTensor`` of dense_shape ``[batch_size, input_size]``
    Returns:
      A ``tf.SparseTensor`` representing the output of the max_norm transformation; this can
      be fed into twml.layers.FullSparse in order to be transformed into a ``tf.Tensor``.
    """

    if isinstance(inputs, twml.SparseTensor):
      inputs = inputs.to_tf()
    elif not isinstance(inputs, tf.SparseTensor):
      raise TypeError("The inputs must be of type tf.SparseTensor or twml.SparseTensor")

    indices_x = inputs.indices[:, 1]
    values_x = inputs.values

    if self.is_training is False:
      normalized_x = OPLIB.sparse_max_norm_inference(self.max_x,
                                                     indices_x,
                                                     values_x,
                                                     self.epsilon)

      update_op = tf.no_op()
    else:
      max_x, normalized_x = OPLIB.sparse_max_norm_training(self.max_x,
                                                           indices_x,
                                                           values_x,
                                                           self.epsilon)

      update_op = tf.assign(self.max_x, max_x)

    with tf.control_dependencies([update_op]):
      normalized_x = tf.stop_gradient(normalized_x)

    # add input bias
    if self._use_bias:
      normalized_x = normalized_x + tf.gather(self.bias_x, indices_x)

    # convert back to sparse tensor
    return tf.SparseTensor(inputs.indices, normalized_x, inputs.dense_shape)

  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """
    The forward propagation logic of the layer lives here.

    Arguments:
      inputs:
        A 2D ``tf.SparseTensor`` of dense_shape ``[batch_size, input_size]``
    Returns:
      A ``tf.SparseTensor`` representing the output of the max_norm transformation; this can
      be fed into twml.layers.FullSparse in order to be transformed into a ``tf.Tensor``.
    """
    with tf.device(self.max_x.device):
      return self._call(inputs, **kwargs)


# For backwards compatibility, and also because I don't want to change all the tests.
MaxNorm = SparseMaxNorm


def sparse_max_norm(inputs,
                    input_size=None,
                    max_x_initializer=None,
                    bias_x_initializer=None,
                    is_training=True,
                    epsilon=1E-5,
                    use_bias=True,
                    name=None,
                    reuse=None):
  """
  Functional interface to SparseMaxNorm.

  Args:
    inputs:
      A sparse tensor (can be twml.SparseTensor or tf.SparseTensor)
    input_size:
      number of input units (deprecated - now inferred from the input)
    max_x_initializer:
      initializer vector of shape [input_size] used by variable `max_x`
    bias_x_initializer:
      initializer vector of shape [input_size] used by parameter `bias_x`
    is_training:
      Are we training the layer to learn the normalization maxima?
      If set to True, max_x will be able to learn. This is independent of bias_x.
    epsilon:
      The minimum value used for max_x. Defaults to 1E-5.
    use_bias:
      Default True. Set to False to not use a bias term.

  Returns:
    Output after normalizing with the max value.
  """
  if input_size:
    raise ValueError('input_size is deprecated - it is now automatically \
inferred from your input.')

  if isinstance(inputs, twml.SparseTensor):
    inputs = inputs.to_tf()

  layer = SparseMaxNorm(max_x_initializer=max_x_initializer,
                        bias_x_initializer=bias_x_initializer,
                        is_training=is_training,
                        epsilon=epsilon,
                        use_bias=use_bias,
                        name=name,
                        _scope=name,
                        _reuse=reuse)
  return layer(inputs)
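
A minimal sketch of calling the functional wrapper above on a toy sparse batch, assuming the compiled libtwml custom ops are available (OPLIB backs the normalization); the shapes and values are illustrative only.

import tensorflow.compat.v1 as tf
import twml

# Toy [batch_size=2, input_size=4] sparse input with three non-zero entries.
sparse_input = tf.SparseTensor(
    indices=[[0, 1], [0, 3], [1, 2]],
    values=[0.5, 4.0, -2.0],
    dense_shape=[2, 4])

# Normalizes each value by its column's running max (updated online because
# is_training=True) and adds the learned per-column bias.
normed = twml.layers.sparse_max_norm(sparse_input, is_training=True, name='max_norm')
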
54
twml/twml/layers/stitch.py
Normal file
@ -0,0 +1,54 @@
# pylint: disable=useless-super-delegation
"""
Implementing Stitch Layer
"""


from .layer import Layer

import tensorflow.compat.v1 as tf


class Stitch(Layer):
  """
  This layer is responsible for stitching a partitioned layer back together.

  Output:
    A layer that performs stitching
  """

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
        be fully defined (e.g. the batch size may be unknown).

    Raises NotImplementedError.

    """
    raise NotImplementedError

  def call(self, partioned_val, partioned_keys,
           partioned_indices, **kwargs):  # pylint: disable=unused-argument, arguments-differ
    """
    Stitches a partitioned layer back together.

    Input:
      partioned_val:
        a list of partitioned Tensors which represent the values of the hashmap
      partioned_keys:
        a list of partitioned Tensors which represent the keys of the hashmap
      partioned_indices:
        a list of partitioned Tensors which represent the indices of the hashmap
    Output:
      List which contains: [output_vals, output_keys]
      output_vals:
        Values of the HashMap (float)
      output_keys:
        Keys of the HashMap (float)
    """
    indices = [tf.to_int32(index) for index in partioned_indices]
    concat_keys = tf.dynamic_stitch(indices, partioned_keys)
    concat_vals = tf.dynamic_stitch(indices, partioned_val)
    return [concat_vals, concat_keys]
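
To make the stitching behaviour above concrete, here is a small standalone illustration using the same TensorFlow primitives (`tf.dynamic_partition` to split, `tf.dynamic_stitch` to reassemble); it is not part of the twml API, and the values are illustrative only.

import tensorflow.compat.v1 as tf

# Split keys/vals into two partitions, then stitch them back in the original order.
keys = tf.constant([10, 11, 12, 13], dtype=tf.int64)
vals = tf.constant([1.0, 2.0, 3.0, 4.0])
assignments = tf.constant([0, 1, 0, 1])      # which partition each row belongs to
positions = tf.range(tf.size(keys))          # original row positions (int32)

part_keys = tf.dynamic_partition(keys, assignments, num_partitions=2)
part_vals = tf.dynamic_partition(vals, assignments, num_partitions=2)
part_idx = tf.dynamic_partition(positions, assignments, num_partitions=2)

# dynamic_stitch inverts dynamic_partition: output[positions[i]] == input[i].
stitched_keys = tf.dynamic_stitch(part_idx, part_keys)
stitched_vals = tf.dynamic_stitch(part_idx, part_vals)
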