TensorFlow Estimators: Managing Simplicity vs. Flexibility in
High-Level Machine Learning Frameworks

Heng-Tze Cheng† Zakaria Haque† Lichan Hong† Mustafa Ispir† Clemens Mewald†∗
Illia Polosukhin† Georgios Roumpos† D. Sculley† Jamie Smith† David Soergel†
Yuan Tang‡ Philipp Tucker† Martin Wicke†∗ Cassandra Xia† Jianwei Xie†

†Google, Inc.   ‡Uptake Technologies, Inc.

ABSTRACT

We present a framework for specifying, training, evaluating, and deploying machine learning models. Our focus is on simplifying cutting-edge machine learning for practitioners in order to bring such technologies into production. Recognizing the fast evolution of the field of deep learning, we make no attempt to capture the design space of all possible model architectures in a domain-specific language (DSL) or similar configuration language. We allow users to write code to define their models, but provide abstractions that guide developers to write models in ways conducive to productionization. We also provide a unifying Estimator interface, making it possible to write downstream infrastructure (e.g., distributed training, hyperparameter tuning) independent of the model implementation.

We balance the competing demands for flexibility and simplicity by offering APIs at different levels of abstraction, making common model architectures available out of the box, while providing a library of utilities designed to speed up experimentation with model architectures. To make out-of-the-box models flexible and usable across a wide range of problems, these canned Estimators are parameterized not only over traditional hyperparameters, but also using feature columns, a declarative specification describing how to interpret input data.

We discuss our experience in using this framework in research and production environments, and show the impact on code health, maintainability, and development speed.

∗ Corresponding authors: {clemensm,wicke}@google.com

Permission to make digital or hard copies of part or all of this work for personal or classroom use is granted without fee provided that copies are not made or distributed for profit or commercial advantage and that copies bear this notice and the full citation on the first page. Copyrights for third-party components of this work must be honored. For all other uses, contact the owner/author(s).
KDD'17, August 13–17, 2017, Halifax, NS, Canada.
© 2017 Copyright held by the owner/author(s). 978-1-4503-4887-4/17/08.
DOI: http://dx.doi.org/10.1145/3097983.3098171

1 INTRODUCTION

Machine learning, and in particular, deep learning, is a field of growing importance. With the deployment of large GPU clusters in datacenters and cloud computing services, it is now possible to apply these methods not only in theory, but to integrate them successfully into production systems.

Engineers working on production systems have only recently gained the ability to apply advanced machine learning, driven in large part by the availability of machine learning frameworks that implement the lower-level numerical computations in efficient ways and allow engineers to focus on application-specific logic (see e.g., [2–5, 7, 8, 11, 14, 17–20]). However, the huge amounts of data involved in training, especially for deep learning models, as well as the complications of running high-intensity computations efficiently on heterogeneous and distributed systems, have prevented the most advanced methods from being widely adopted in production.

As the field of deep learning is still young and developing fast, any framework hoping to remain relevant must be expressive enough to represent not only today's model architectures, but also next year's. If the framework is to be used for experimentation with model architectures (most serious product work requires at least some experimentation), it is also crucial to offer the flexibility to change details of models without having to change components that are deeply embedded and which have a highly optimized, low-level implementation.

There is a natural tension between such flexibility on the one hand, and simplicity and robustness on the other. We use simplicity in a broad sense: from a practitioner's point of view, implementing models should not require fundamentally new skills, assuming that the model architecture is known. Experimenting with model features should be transparent, and should not require deep insights into the inner workings of the framework used to implement the model. We talk of robustness both as a quality of the software development process and as a quality of the resulting software. We call a framework robust if it is easy to write correct and high-quality software using it, but hard to write broken or poorly performing software. A framework which nudges the developer to use best practices, and which makes it hard to "shoot yourself in the foot", is robust.

Because of the need to keep up with and enable research, many deep learning frameworks value flexibility above all else (e.g., [2, 11, 20]). They achieve this flexibility by providing relatively low-level primitive operations (e.g., matmul, add, tanh) and require the user to write code in a regular programming language in order to specify their model. To simplify life for their users and speed up development, these frameworks often provide some higher-level components, such as layers (e.g., a fully connected neural network layer with an optional activation function). Development in a fully-fledged programming language is inherently dangerous.
Working at a low level can also lead to a lot of code duplication, with the software maintenance headaches that come with that.

On the other end of the spectrum are systems which use a DSL to describe the model architecture (e.g., [3, 5, 13, 17]). Such systems are more likely to be geared for specific production use cases. They can make common cases very simple to implement (the most common models may even be built-in primitives). Their higher level of abstraction allows these frameworks to make optimizations that are inaccessible to their more flexible peers. They are also robust: users are strongly guided towards model architectures that work, and it is hard to write down models that are fundamentally broken. Apart from the lack of flexibility when it comes to new model types and architectures, these DSL-based systems can be hard to maintain in the face of an inexorably advancing body of new research. Adding more and more primitives to a DSL, or adding more and more options to existing primitives, can be fatal. Google's own experience with such a system [13] prompted the development of TensorFlow [2].

TensorFlow is an open source software library for machine learning, and especially deep learning. It represents computation as a generalized data flow graph. The graph is first built, and then executed separately from graph construction. Operations such as mul, add, etc., are represented as nodes in the graph. Edges represent the data flowing between nodes as a Tensor containing a multi-dimensional array. In the following, we use op and Tensor interchangeably to denote a node in the graph (op) and the output that is created when the node is executed. Most ops are stateless tensor-in-tensor-out functions. State is represented in the graph as Variables, special stateful ops. Users can assign ops and variables to any device. A device can be a CPU, GPU, or TPU, and can live on the local machine or a remote TensorFlow server. TensorFlow then seamlessly handles communication between these devices. This is one of the most powerful aspects of TensorFlow, and we rely on it heavily to enable scaling models from a single machine to datacenter-scale.

The framework described in this paper is implemented on top of TensorFlow¹, and has been made available as part of the TensorFlow open-source project. Faced with competing demands, our goal is to provide users with utilities that simplify common use cases while still allowing access to the full generality of TensorFlow. Consequently, we do not attempt to capture the design space of machine learning algorithms in a DSL. Instead, we offer a harness which removes boilerplate by providing best practice implementations of common code patterns. The components we provide are reusable, and integration points for users are strategically placed to encourage reusable user code. The user configuration is performed by writing regular TensorFlow code, but a number of lower-level TensorFlow concepts are safely encapsulated, and users do not have to reason about them, eliminating a source of common problems.

¹ While we hope that our description of the features in this paper is largely self-contained, basic familiarity with TensorFlow will give valuable context to the reader.

Some of the lower-level components, such as layers, are closely related to those in similar frameworks aimed at simplifying model construction [10, 15, 16, 21].

The highest-level object in our framework is an Estimator, which provides an interface similar to that of Scikit-learn [19], with some adaptations to simplify productionization. Scikit-learn has been used in a large number of small to medium scale machine learning tasks. Using a widely known interface allows practitioners who are not specialists in TensorFlow to start working productively immediately.

In the remainder of the paper, we will first discuss the overall design of our framework (Sec. 2), before describing in detail all major components (Sec. 3) and our mechanisms for distributed computations (Sec. 4). We then discuss case studies and show experimental results (Sec. 5).

2 DESIGN OVERVIEW

The design of our framework is guided by the overarching principle that users should be led to best practices, without having to abandon established idioms, wherever this is possible. Because our framework is built on TensorFlow, we inherit a number of common design patterns: there is a preference for functions and closures over objects, wherever such closures are sufficient; callbacks are common. Our layer design is informed by the underlying TensorFlow style: our layer functions are also tensor-in-tensor-out operations. These preferences are stylistic in nature and have no impact on the performance or expressivity of the framework, but they allow users to transition easily if they are used to working with TensorFlow.

Because one of the greatest strengths of TensorFlow is its flexibility, it is crucial for us not to restrict what users can accomplish. While we provide guides that nudge people towards best practices, we provide escape hatches and extension points that allow users to use the full power of TensorFlow whenever they need to.

Our requirements include simplifying model building in general, offering a harness that encourages best practices and guides users to a production-ready implementation, as well as implementing the most common types of machine learning model architectures, and providing an interface for developers of downstream frameworks and infrastructure. We are therefore dealing with three distinct (but not necessarily disjoint) classes of users: users who want to build custom machine learning models, users who want to use common models, and users who want to build infrastructure using the concept of a model, but without knowledge of the specifics.

These user classes inform the high-level structure of our framework. At the heart is the Estimator class (see Section 3.2). Its interface (modeled after the eponymous concept in Scikit-learn [19]) provides an abstraction for a machine learning model, detailed enough to allow for downstream infrastructure to be written, but general enough to not constrain the type of model represented by an Estimator. Estimators are given input by a user-defined input function. We provide implementations for common types of inputs (e.g., input from numpy [12]).
The Estimator itself is configured using the model_fn, a function which builds a TensorFlow graph and returns the information necessary to train a model, evaluate it, and predict with it. Users writing custom Estimators only have to implement this function. It is possible, and in fact common, that model_fn contains regular TensorFlow code that does not use any other component of our framework. This is often the case because existing models are being adapted or converted to be implemented in terms of an Estimator. We do provide a number of utilities to simplify building models, which can be used independently of Estimator (see Sec. 3.1). This mutual independence of the abstraction layers is an important feature of our design, as it enables users to choose freely the level of abstraction best suited for the problem at hand.

It is worth noting that an Estimator can be constructed from a Keras Model. Users of this compatibility feature cannot use all features of Estimator (in particular, one cannot specify a separate inference graph with this method), but it is nevertheless useful for comparisons, and to use existing models inside downstream infrastructure (such as [6]).
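As a sketch of how this compatibility feature can be used (assuming a TensorFlow build that ships tf.keras and its model_to_estimator conversion utility; the tiny architecture below is purely illustrative):

    import tensorflow as tf

    # Build and compile a Keras model as usual.
    keras_model = tf.keras.Sequential([
        tf.keras.layers.Dense(50, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(3, activation='softmax')])
    keras_model.compile(optimizer='adam',
                        loss='categorical_crossentropy')

    # Wrap the compiled model in an Estimator; train, evaluate,
    # and predict then work as described in this section.
    estimator = tf.keras.estimator.model_to_estimator(
        keras_model=keras_model)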
We also provide a number of Estimator implementations for common machine learning algorithms, which we called Canned Estimators (these are subclasses of Estimator, see Section 3.3). In our implementations, we use the same mechanisms that a user who writes a custom model would use. This ensures that we are users of our own framework. To make them useful for a wide variety of problems, canned Estimators expose a number of configuration options, the most important of which is the ability to specify input structure using feature columns.

3 COMPONENTS

In this section we will describe in detail the various components that make up our framework and their relationships. We start with layers, lower-level utilities that can be used independently of Estimator, before discussing various aspects of Estimator itself.

3.1 Layers

One of the advantages of deep learning is that common model architectures are built up from composable parts. For deep neural networks, the smallest of these components are called network layers, and we have adopted this name even though the concept is more widely applicable. A layer is simply a reusable part of code, and can be as simple as a fully connected neural network layer or as complex as a full Inception network. We provide a library of layers which is well tested and whose implementations follow best practices.

We have given our layers a consistent interface in order to ease the cognitive burden on users. In our framework, layers are implemented as free functions, taking Tensors as input arguments (along with other parameters), and returning Tensors. TensorFlow itself contains a large number of ops that behave in the same manner, so layers are a natural extension of TensorFlow and should feel natural to users of TensorFlow. Because layers accept and produce regular Tensors, layers and regular TensorFlow ops can be mixed without requiring special care.

We implement layer functions with best practices in mind: layers are generally wrapped in a variable scope. This ensures that they are properly grouped in the TensorBoard visualization tool, which is essential when inspecting large models. All variables that are created as part of a layer are obtained using get_variable, which ensures that variables can be reused or shared in different parts of the model. All layers assume that the first dimension of input tensors is the batch dimension, and accept variable batch size input. This allows changing the batch size as a hyperparameter during tuning, and it ensures that the model can be reused for inference, where inputs don't necessarily arrive in batches.
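A minimal sketch of a layer function that follows these conventions (my_dense and its variable names are illustrative, not part of the library):

    import tensorflow as tf

    def my_dense(inputs, units, activation=None, name="my_dense"):
        # Wrapping the layer in a variable scope groups its variables
        # in TensorBoard and allows sharing via scope reuse.
        with tf.variable_scope(name):
            input_dim = inputs.shape[-1].value
            # get_variable (rather than tf.Variable) makes the kernel
            # and bias reusable in other parts of the model.
            kernel = tf.get_variable("kernel", shape=[input_dim, units])
            bias = tf.get_variable("bias", shape=[units],
                                   initializer=tf.zeros_initializer())
            # The first (batch) dimension stays unconstrained.
            outputs = tf.matmul(inputs, kernel) + bias
            return activation(outputs) if activation else outputs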
As an example, let's create a simple convolutional net to classify an image. The network comprises three convolutional and three pooling layers, as well as a final fully connected layer. We have set sensible defaults on many arguments, so the invocations are compact unless uncommon behavior is desired:

    # Input images as a 4D tensor (batch, width,
    # height, and channels)
    net = inputs
    # instantiate 3 convolutional layers with pooling
    for _ in range(3):
        net = layers.conv2d(net,
                            filters=4,
                            kernel_size=3,
                            activation=relu)
        net = layers.max_pooling2d(net,
                                   pool_size=2,
                                   strides=1)
    logits = layers.dense(net, units=num_classes)

We separate out some classes of layers that share a more restricted interface. Losses are functions which take an input, a label, and a weight, and return a scalar loss. These functions, such as l1_loss or l2_loss, are used to produce a loss for optimization.

Metrics are another special class of layers commonly used in evaluation: they again take a label, a prediction, and optionally a weight, and compute a metric such as log-likelihood, accuracy, or a simple mean squared error. While superficially similar to losses, they support aggregating a metric across many minibatches, an important feature whenever the evaluation dataset does not fit into memory. Metrics return two Tensors: update_op, which should be run for each minibatch, and a value_op which computes the final metric value. The update_op does not return a value, and only updates internal variables, aggregating the new information contained in the input minibatch. The value_op uses only the internal state to compute a metric value and returns it. The Estimator's evaluation functionality relies on this usage pattern (see below). Properly implementing metrics is nontrivial, and our experience shows that metrics that are naively implemented from scratch lead to problems when using large datasets (using TensorFlow queues in evaluation requires extra finesse to avoid losing examples to logging or TensorBoard summary writing).
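To illustrate the two-op pattern, the following is a deliberately naive sketch of a streaming mean metric (the framework's own metric implementations additionally handle weights, resets, and the distributed and out-of-memory cases; all names here are illustrative):

    import tensorflow as tf

    def streaming_mean(values):
        # Internal state, aggregated across minibatches.
        total = tf.get_variable("total", shape=[], trainable=False,
                                initializer=tf.zeros_initializer())
        count = tf.get_variable("count", shape=[], trainable=False,
                                initializer=tf.zeros_initializer())
        # update_op folds the current minibatch into the state and
        # is run once per minibatch.
        update_op = tf.group(
            tf.assign_add(total, tf.reduce_sum(values)),
            tf.assign_add(count, tf.cast(tf.size(values), tf.float32)))
        # value_op reads only the internal state.
        value_op = total / tf.maximum(count, 1.0)
        return value_op, update_op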
Figure 1: Simplified overview of the Estimator interface.

3.2 Estimator

At the heart of our framework is Estimator, a class that provides both an interface for downstream infrastructure and a convenient harness for developers. The interface for users of Estimator is loosely modeled after Scikit-learn and consists of only four methods: train trains the model, given training data; evaluate computes evaluation metrics over test data; predict performs inference on new data given a trained model; and finally, export_savedmodel exports a SavedModel, a serialization format which allows the model to be used in TensorFlow Serving, a prebuilt production server for TensorFlow models [1].
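In outline, a user's interaction with these four methods might look as follows (a sketch only: the model function, input functions, export directory, and serving input function are assumed to be defined as described below and in Sec. 3.2):

    estimator = Estimator(model_fn=my_model_fn)
    # Train, then evaluate on held-out data.
    estimator.train(input_fn=train_input_fn)
    metrics = estimator.evaluate(input_fn=eval_input_fn)
    # Inference on new data with the trained model.
    predictions = estimator.predict(input_fn=predict_input_fn)
    # Export a SavedModel for TensorFlow Serving.
    estimator.export_savedmodel(EXPORT_DIR, serving_input_fn)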
The user configures an Estimator by passing a callback, the model_fn, to the constructor. When one of its methods is called, Estimator creates a TensorFlow graph, sets up the input pipeline specified by the user in the arguments to the method (see Sec. 3.2), and then calls the model_fn with appropriate arguments to generate the graph representing the model. The Estimator class itself contains the necessary code to run a training or evaluation loop, to predict using a trained model, or to export a prediction model for use in production.

Estimator hides some TensorFlow concepts, such as Graph and Session, from the user. The Estimator constructor also receives a configuration object called RunConfig which communicates everything that this Estimator needs to know about the environment in which the model will be run: how many workers are available, how often to save intermediate checkpoints, etc.

To ensure encapsulation, Estimator creates a new graph, and possibly restores from checkpoint, every time a method is called. Rebuilding the graph is expensive, and it could be cached to make it more economical to run, say, evaluate or predict in a loop. However, we found it very useful to explicitly recreate the graph, trading off performance for clarity. Even if we did not rebuild the graph, writing such loops is highly suboptimal in terms of performance. Making this cost very visible discourages users from accidentally writing badly performing code.

A schematic of Estimator can be found in Figure 1. Below, we first describe how to provide inputs to the train, evaluate, and predict methods using input functions. Then we discuss model specification with model_fn, followed by how to specify outputs within the model_fn using Heads.

Specifying inputs with input_fn. The methods train, evaluate, and predict all take an input function, which is expected to produce two dictionaries: one containing Tensors with inputs (features), and one containing Tensors with labels. Whenever a method of Estimator is called, a new graph is created, the input_fn passed as an argument to the method call is called to produce the input pipeline of the Estimator, and then the model_fn is called with the appropriate mode argument to build the actual model graph. Decoupling the core model from input processing allows users to easily swap datasets. If used in larger infrastructure, being able to control the inputs completely is very valuable to downstream frameworks. A typical input_fn has the following form:

    def my_input_fn(file_pattern):
        feature_dict = learn.io.read_batch_features(
            # path to data in tf.Example format
            file_pattern=file_pattern,
            batch_size=BATCH_SIZE,
            # whether sparse or dense ...
            features=FEATURE_SPEC,
            # such as TFRecordReader
            reader=READER,
            ...)

    estimator.train(input_fn=lambda:
        my_input_fn(TRAINING_FILES), ...)
    estimator.evaluate(input_fn=lambda:
        my_input_fn(EVAL_FILES), ...)
Specifying the model with model_fn. We chose to configure Estimator with a single callback, the model_fn, which returns ops for training, evaluation, or prediction, depending on which graph is being requested (which method of Estimator is being called). For example, if the train method is called, model_fn will be called with an argument mode=TRAIN, which the user can then use to build a custom graph in the knowledge that it is going to be used for training. Conceptually, three entirely different graphs can be built, and different information is returned, depending on the mode parameter representing the called method. Nevertheless, we found it useful to require only a single function for configuration. One of the main sources of error in production systems is training/serving skew. One type of training/serving skew happens when a different model is trained than is later served in production. Of course, models are routinely trained slightly differently than they are served. For instance, dropout and batch normalization layers are only active during training. However, it is easy to make mistakes if one has to rewrite the whole model three times. Therefore we chose to require a single function, effectively encouraging the model developer to write the model only once. For complex models, appropriate Python conditionals can be used to ensure that legitimate differences are explicitly represented in the model (see the sketch after the following example). A typical model_fn for a simple model may look like this:

    def model_fn(features, target, mode, params):
        predictions = tf.stack(tf.fully_connected,
                               [50, 50, 1])
        loss = tf.losses.mean_squared_error(target,
                                            predictions)
        train_op = tf.train.create_train_op(
            loss, tf.train.get_global_step(),
            params['learning_rate'], params['optimizer'])
        return EstimatorSpec(mode=mode,
                             predictions=predictions,
                             loss=loss,
                             train_op=train_op)
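For instance, such a conditional might look like the following sketch (kept in the same schematic style as the example above; ModeKeys.TRAIN stands in for the framework's training-mode constant, and the dropout call and its keep_prob argument are illustrative):

    def model_fn(features, target, mode, params):
        net = layers.fully_connected(features, 50)
        # Dropout is a legitimate train/serve difference; the
        # conditional states it explicitly in a single model definition.
        if mode == ModeKeys.TRAIN:
            net = layers.dropout(net, keep_prob=0.5)
        predictions = layers.fully_connected(net, 1)
        loss = tf.losses.mean_squared_error(target, predictions)
        train_op = tf.train.create_train_op(
            loss, tf.train.get_global_step(),
            params['learning_rate'], params['optimizer'])
        return EstimatorSpec(mode=mode,
                             predictions=predictions,
                             loss=loss,
                             train_op=train_op)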
Specifying outputs with Heads. The Head API is an abstraction for the part of the model behind the last hidden layer. The key goals of the design are to simplify writing model_fn, to be compatible with a wide range of models, and to simplify supporting multiple heads. A Head knows how to compute loss, relevant evaluation metrics, predictions, and metadata about the predictions that other systems (like serving, model validation) can use. To support different types of models (e.g., DNN, linear, Wide & Deep [9], gradient boosted trees, etc.), Head takes logits and labels as input and generates Tensors for loss, metrics, and predictions. Heads can also take the activation of the last hidden layer as input, to support DNNs with a large number of classes where we want to avoid computing the full logit Tensor. A typical model_fn for a simple single objective model may look like this:

    def model_fn(features, target, mode, params):
        last_layer = tf.stack(tf.fully_connected,
                              [50, 50])
        head = tf.multi_class_head(n_classes=10)
        return head.create_estimator_spec(
            features, mode, last_layer,
            label=target,
            train_op_fn=lambda loss:
                my_optimizer.minimize(
                    loss, tf.train.get_global_step()))

The abstraction is designed in a way that combining multiple Heads for multi-objective learning is as simple as creating a special type of Head with a list of other heads. Model functions can take a Head as a parameter while remaining agnostic to what kind of Head they are using. A typical model_fn for a simple model with two multi-class objectives can look like this:

    def model_fn(features, target, mode, params):
        last_layer = tf.stack(tf.fully_connected,
                              [50, 50])
        head1 = tf.multi_class_head(n_classes=2,
                                    label_name='y', head_name='h1')
        head2 = tf.multi_class_head(n_classes=10,
                                    label_name='z', head_name='h2')
        head = tf.multi_head([head1, head2])
        return head.create_model_fn_ops(
            features, mode, last_layer,
            label=target,
            train_op_fn=lambda loss:
                my_optimizer.minimize(
                    loss, tf.train.get_global_step()))

Executing computations. Once the graph is built, the Estimator initializes a Session, prepares it appropriately, and runs the training loop, evaluation loop, or iterates over the inputs to produce predictions.

Most machine learning algorithms are iterative nonlinear optimizations, and therefore have a particularly simple algorithmic form: a single loop which runs the same computation over and over again, with different input data in each iteration. When used during training, this is called the training loop. In evaluation using mini-batches, much the same structure is used, except that variables are not updated, and typically, more metrics than just the loss are computed.

An idealized training loop implemented in TensorFlow is simple: start a Session, then run a training op in a loop. However, we have to at least initialize variables and special data structures like tables which are used in embeddings. Queue runners (implemented as Python threads) have to be started, and should be stopped at the end to ensure a clean exit. Summaries (which provide data to the TensorBoard visualization tool) have to be computed and written to file. The real challenge begins when distributed training is taken into account. While TensorFlow takes care of distribution of the computation and communication between workers, it requires many coordinated steps before a model can be successfully trained. The distributed computation introduces a number of opportunities for users to make mistakes: certain variables must be initialized on all workers, most only on one. The model state should be saved periodically to ensure that the computation can recover when workers go down, and needs to be recovered safely when they restart. End-of-input signals have to be handled gracefully.

Because the training loop is so ubiquitous, a good implementation removes a lot of duplicated user code. Because it is simple only in theory, we can remove a source of error and frustration for users. Therefore, Estimator implements and controls the training loop. It automatically assigns Variables to parameter servers to simplify distributed computation, and it gives the user only limited access to the underlying TensorFlow primitives. Users must specify the graph and the op(s) to run in each iteration, and they may override the device placement.
Code injection using Hooks. Hooks make it possible to implement advanced optimization techniques that break the simple loop abstraction in a safe manner. They are also useful for custom processing that has to happen alongside the main loop, for recordkeeping, debugging, monitoring, or reporting. Hooks let users define custom behaviour at Session creation, before and after each iteration, and at the end of training. They also let users add ops other than those specified by the model_fn to be run within the same Session.run call. For example, a user who wants to train not for a given number of steps, but a given amount of wall time, could implement a Hook as follows:

    class TimeBasedStopHook(tf.train.SessionRunHook):
        def begin(self):
            self.started_at = time.time()
        def after_run(self, run_context, run_values):
            if time.time() - self.started_at >= TRAIN_TIME:
                run_context.request_stop()

Hooks are activated by passing them to the train call. When the Hook shown above is passed to train, the model training will end after the set time. Much of the functionality that Estimator provides (for instance, summaries, step counting, and checkpointing) is internally implemented using such Hooks.

3.3 Canned Estimators

There are many model architectures commonly used by researchers and practitioners. We decided to provide those architectures as canned Estimators so that users don't need to rewrite the same models again and again. Canned Estimators are a good example of how to use Estimator itself. They are direct subclasses of Estimator that only override their constructors. As such, users of canned Estimators only need to know how to use an Estimator, and how to configure the canned Estimator. This means that canned Estimators are mainly restricted to defining a canned model_fn. There are two main reasons behind this restrictive design. First, we are expecting an increasing number of canned Estimators to be implemented. To minimize the cognitive load on users, all these canned Estimators should behave identically. Second, this restriction makes the canned Estimator developer a user of Estimator. This leads to an implicit comprehensive flexibility test of our API.

Neural networks rely on operations which take dense Tensors and output dense Tensors. Many machine learning problems have sparse features such as query keywords, product id, url, video id, etc. For models with many inputs, specifying how these features are attached to the model often consumes a large fraction of the total setup time. Based on our experience, one of the most error-prone parts of building a model is converting these features into a single dense Tensor.

We offer the FeatureColumn abstraction to simplify input ingestion. FeatureColumns are a declarative way of specifying inputs. Canned Estimators take FeatureColumns as a constructor argument and handle the conversion of sparse or dense features of all types to a dense Tensor usable by the core model. As an example, the following code shows a canned Estimator implementation for the Wide & Deep architecture [9]. The deep part of the model uses embeddings while the linear part uses crosses of base features.

    # Define wide model features and crosses.
    query_x_docid = crossed_column(
        ["query", "docid"], num_buckets)
    wide_cols = [query_x_docid, ...]

    # Define deep model features and embeddings.
    query = categorical_column_with_hash_bucket(
        "query", num_buckets)
    docid = categorical_column_with_hash_bucket(
        "docid", num_buckets)
    query_emb = embedding_column(query, dimension=32)
    docid_emb = embedding_column(docid, dimension=32)
    deep_cols = [query_emb, docid_emb, ...]

    # Define model structure and start training.
    estimator = DNNLinearCombinedClassifier(
        wide_cols, deep_cols,
        dnn_hidden_units=[500, 200, 100])
    estimator.train(input_fn, ...)

4 DISTRIBUTED EXECUTION

With the built-in functionality and utilities mentioned above, Estimators are ready for training, evaluating, and exporting the model on a single machine. For production use, and for models with large amounts of training data, utilities for distributed execution are also provided together with Estimators; these take advantage of TensorFlow's distributed training support. The core of the distributed execution support is the Experiment class, which groups the Estimator with two input functions for training and evaluation. The architecture is summarized in Figure 2.

Figure 2: Simplified overview of the Experiment interface.
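A minimal sketch of this grouping (assuming the contrib-era Experiment class shipped with TensorFlow 1.x; my_input_fn and the file constants are as in Sec. 3.2):

    from tensorflow.contrib.learn import Experiment

    experiment = Experiment(
        estimator=estimator,
        train_input_fn=lambda: my_input_fn(TRAINING_FILES),
        eval_input_fn=lambda: my_input_fn(EVAL_FILES))
    # Runs training with periodic evaluation on a single machine;
    # in a cluster, each task instead executes the work assigned
    # to it by its RunConfig (see below).
    experiment.train_and_evaluate()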
In each TensorFlow cluster, there are several parameter servers and several worker tasks. Most workers are handling the training process, which basically calls the Estimator train method with the training input_fn. One of the workers is designated leader and is responsible for managing checkpoints and other maintenance work. Currently, the primary mode of replica training in TensorFlow Estimators is between-graph replication and asynchronous training. However, it could easily be extended to support other replicated training settings. With this architecture, gradient descent training can be executed in parallel.

We have evaluated the scaling of TensorFlow Estimators by running different numbers of workers with a fixed number of parameter servers. We trained a DNN model on a large internal recommendation dataset (100s of billions of examples) for 48 hours and present the average number of training steps per second. Figure 3 shows that we achieve almost linear scaling of global steps per second with the number of workers.

Figure 3: Measuring scaling of DNN model training implemented with TensorFlow Estimators, varying the number of workers. Shown are measurements as well as the theoretical perfect linear scaling.

There is a special worker handling the evaluation process for the Experiment to evaluate the performance and export the model. It runs in a continuous loop and calls the Estimator evaluate method with the evaluation input_fn. In order to avoid race conditions and inconsistent model parameter states, the evaluation process always begins with loading the latest checkpoint, and calculates the evaluation metrics based on the model parameters from that checkpoint. As a simple extension, the Experiment also supports evaluation with the training input_fn, which is very useful in practice to detect overfitting in deep learning.

Furthermore, we also provide utilities, RunConfig and runner, to ease the use and configuration of Experiment in a cluster for distributed training. RunConfig holds all the execution-related configuration the Experiment/Estimator requires, including cluster specification, model output directory, checkpoint configuration, etc. In particular, RunConfig specifies the task type of the current task, which allows all tasks to share the same binary while running in different modes, such as parameter server, training, or continual evaluation. The runner is simply a utility method to construct the RunConfig, e.g., by parsing environment variables, and execute the Experiment/Estimator with that RunConfig. With this design, Experiment/Estimator can easily be shared by various execution frameworks, including end-to-end machine learning pipelines [6] and even hyperparameter tuning.

5 CASE STUDIES AND ADOPTION

For machine learning practitioners within Google, this framework has dramatically reduced the time to launch a working model. Before TensorFlow Estimators, the typical model construction cycle involved writing custom TensorFlow code to ingest and represent features (sparse features were especially tricky), construction of the model layers itself, establishing training and validation loops, productionizing the system to run on distributed training clusters, adding evaluation metrics, debugging training NaNs, and debugging poor model quality.

TensorFlow Estimators simplify or automate all but the debugging steps. Estimators give the practitioner confidence that, when debugging NaNs or poor quality, these problems arise either from their choice of hyperparameters or their choice of features — but not from a bug in the wiring of the model itself.

When TensorFlow Estimators became available, several TensorFlow models under development greatly benefited from transitioning to the framework. One multiclass classification model attained 37% better model accuracy by switching from a custom model that performed multiple logistic regressions to a standard Estimator that properly used a softmax cross-entropy loss — the switch also reduced the lines of code required from 800 to 200. A different TensorFlow CTR model was stuck in the debugging phase for several weeks, but was transitioned to the framework within two days and achieved launchable offline metrics.

It is worth noting that using Estimators and the associated machinery also requires considerably less expertise than would be required to implement the equivalent functionality from scratch. Recently, a cohort of Google data scientists with limited Python experience and no TensorFlow experience were able to bootstrap real models in a two-day class setting.

5.1 Experience in YouTube Watch Next

Using TensorFlow Estimators, we have productionized and launched a deep model (DNNClassifier) in the Watch Next video recommender system of YouTube. Watch Next is a product recommending a ranked set of videos for a user to choose from after the user is done watching the current video. One unique aspect of our model is that it is trained over multiple days, with the training data being continuously updated.

Our input features consist of both sparse categorical features and real-valued features. The sparse features are further
transformed into embedding columns before being fed into the hidden layers. The FeatureColumn API greatly simplifies how we construct the input layer of our model. Additionally, the train-to-serve support of TensorFlow Estimators considerably reduced the engineering effort to productionize the Watch Next model. Furthermore, the Estimator framework made it easy to implement new Estimators and experiment with new model architectures such as multiple-objective learning to accommodate specific product needs.

The initial version of the model pipeline was developed using low-level TensorFlow primitives prior to the release of Estimators. While debugging why the model quality failed to match our expectation, we discovered critical bugs related to how the network layers were constructed and how the input data were processed.

As an early adopter, Watch Next prompted the development of missing features such as shared embedding columns. Shared embedding columns allow multiple semantically similar features to share a common embedding space, with the benefit of transfer learning across features and a smaller model size.
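A sketch of what specifying such columns can look like (assuming a shared_embedding_columns utility in the feature column API, following the style of the Wide & Deep example in Sec. 3.3; the feature names and bucket count are illustrative):

    # Two semantically similar sparse features...
    video_id = categorical_column_with_hash_bucket(
        "video_id", num_buckets)
    watched_id = categorical_column_with_hash_bucket(
        "watched_video_id", num_buckets)
    # ...embedded into one shared space, so that both features
    # train (and reuse) the same embedding table.
    shared_embs = shared_embedding_columns(
        [video_id, watched_id], dimension=32)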
5.2 Adoption within Google

Software engineers at Google have a variety of choices for how to implement their machine learning models. Before we developed the higher-level framework in TensorFlow, engineers were effectively forced to implement one-off versions of the components in our framework.

An internal survey has shown that, since we introduced this framework and Estimators less than a year ago, close to 1,000 Estimators have been checked into the Google codebase and more than 120,000 experiments have been recorded (an experiment in this context is a complete training run; not all runs are recorded, so the true number is significantly higher). Of those, over half (57%) use implementations of canned Estimators (e.g., LinearClassifier, DNNLinearCombinedRegressor). There are now over 20 Estimator classes implementing various standard machine learning algorithms in the TensorFlow code base. Examples include DynamicRnnEstimator (implementing dynamically unrolled RNNs for classification or regression problems) and TensorForestEstimator (implementing random forests). Figure 4 shows the current distribution of Estimator usage. This framework allowed teams to build high-quality machine learning models within an average of one engineer-week, sometimes as fast as within 2 hours. 74% of respondents say that development with this framework is faster than other machine learning APIs they used before. Most importantly, users note that they can focus their time on the machine learning problem as opposed to the implementation of underlying basics. Among existing users, quick ramp-up, ease of use, reuse of common code, and readability of a commonly used framework are the most frequently mentioned benefits.

Figure 4: Current usage of Estimators at Google.

REFERENCES
[1] Running your models in production with TensorFlow Serving. https://research.googleblog.com/2016/02/running-your-models-in-production-with.html, accessed 2017-02-08.
[2] Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, Manjunath Kudlur, Josh Levenberg, Rajat Monga, Sherry Moore, Derek Gordon Murray, Benoit Steiner, Paul A. Tucker, Vijay Vasudevan, Pete Warden, Martin Wicke, Yuan Yu, and Xiaoqiang Zheng. 2016. TensorFlow: A System for Large-Scale Machine Learning. In OSDI. 265–283.
[3] Amit Agarwal, Eldar Akchurin, Chris Basoglu, Guoguo Chen, Scott Cyphers, Jasha Droppo, Adam Eversole, Brian Guenter, Mark Hillebrand, Ryan Hoens, Xuedong Huang, Zhiheng Huang, Vladimir Ivanov, Alexey Kamenev, Philipp Kranen, Oleksii Kuchaiev, Wolfgang Manousek, Avner May, Bhaskar Mitra, Olivier Nano, Gaizka Navarro, Alexey Orlov, Marko Padmilac, Hari Parthasarathi, Baolin Peng, Alexey Reznichenko, Frank Seide, Michael L. Seltzer, Malcolm Slaney, Andreas Stolcke, Yongqiang Wang, Huaming Wang, Kaisheng Yao, Dong Yu, Yu Zhang, and Geoffrey Zweig. 2014. An Introduction to Computational Networks and the Computational Network Toolkit. Technical Report MSR-TR-2014-112. http://research.microsoft.com/apps/pubs/default.aspx?id=226641
[4] Rami Al-Rfou, Guillaume Alain, Amjad Almahairi, Christof Angermueller, Dzmitry Bahdanau, Nicolas Ballas, Frédéric Bastien, Justin Bayer, Anatoly Belikov, Alexander Belopolsky, Yoshua Bengio, Arnaud Bergeron, James Bergstra, Valentin Bisson, Josh Bleecher Snyder, Nicolas Bouchard, Nicolas Boulanger-Lewandowski, Xavier Bouthillier, Alexandre de Brébisson, Olivier Breuleux, Pierre-Luc Carrier, Kyunghyun Cho, Jan Chorowski, Paul Christiano, Tim Cooijmans, Marc-Alexandre Côté, Myriam Côté, Aaron Courville, Yann N. Dauphin, Olivier Delalleau, Julien Demouth, Guillaume Desjardins, Sander Dieleman, Laurent Dinh, Mélanie Ducoffe, Vincent Dumoulin, Samira Ebrahimi Kahou, Dumitru Erhan, Ziye Fan, Orhan Firat, Mathieu Germain, Xavier Glorot, Ian Goodfellow, Matt Graham, Caglar Gulcehre, Philippe Hamel, Iban Harlouchet, Jean-Philippe Heng, Balázs Hidasi, Sina Honari, Arjun Jain, Sébastien Jean, Kai Jia, Mikhail Korobov, Vivek Kulkarni, Alex Lamb, Pascal Lamblin, Eric Larsen, César Laurent, Sean Lee, Simon Lefrancois, Simon Lemieux, Nicholas Léonard, Zhouhan Lin, Jesse A. Livezey, Cory Lorenz, Jeremiah Lowin, Qianli Ma, Pierre-Antoine Manzagol, Olivier Mastropietro, Robert T. McGibbon, Roland Memisevic, Bart van Merriënboer, Vincent Michalski, Mehdi Mirza, Alberto Orlandi, Christopher Pal, Razvan Pascanu, Mohammad Pezeshki, Colin Raffel, Daniel Renshaw, Matthew Rocklin, Adriana Romero, Markus Roth, Peter Sadowski, John Salvatier, François Savard, Jan Schlüter, John Schulman, Gabriel Schwartz, Iulian Vlad Serban, Dmitriy Serdyuk, Samira Shabanian, Étienne Simon, Sigurd Spieckermann, S. Ramana Subramanyam, Jakub Sygnowski, Jérémie Tanguay, Gijs van Tulder, Joseph Turian, Sebastian Urban, Pascal Vincent, Francesco Visin, Harm de Vries, David Warde-Farley, Dustin J. Webb, Matthew Willson, Kelvin Xu, Lijun Xue, Li Yao, Saizheng Zhang, and Ying Zhang. 2016. Theano: A Python framework for fast computation of mathematical expressions. arXiv e-prints abs/1605.02688 (May 2016). http://arxiv.org/abs/1605.02688
[5] Amazon. 2016. Dsstne. https://github.com/amznlabs/amazon-dsstne. (2016).
[6] Denis Baylor, Eric Breck, Heng-Tze Cheng, Noah Fiedel, Chuan Yu Foo, Zakaria Haque, Salem Haykal, Mustafa Ispir, Vihan Jain, Levent Koc, Chiu Yuen Koo, Lukasz Lew, Clemens Mewald, Akshay Naresh Modi, Neoklis Polyzotis, Sukriti Ramesh, Sudip Roy, Steven Euijong Whang, Martin Wicke, Jarek Wilkiewicz, Xin Zhang, and Martin Zinkevich. 2017. The Anatomy of a Production-Scale Continuously-Training ML Platform. KDD [under review]. (2017).
[7] Tianqi Chen and Carlos Guestrin. 2016. XGBoost: A Scalable Tree Boosting System. CoRR abs/1603.02754 (2016). http://arxiv.org/abs/1603.02754
[8] Tianqi Chen, Mu Li, Yutian Li, Min Lin, Naiyan Wang, Minjie Wang, Tianjun Xiao, Bing Xu, Chiyuan Zhang, and Zheng Zhang. 2015. MXNet: A Flexible and Efficient Machine Learning Library for Heterogeneous Distributed Systems. CoRR abs/1512.01274 (2015). http://arxiv.org/abs/1512.01274
[9] Heng-Tze Cheng, Levent Koc, Jeremiah Harmsen, Tal Shaked, Tushar Chandra, Hrishi Aradhye, Glen Anderson, Greg Corrado, Wei Chai, Mustafa Ispir, Rohan Anil, Zakaria Haque, Lichan Hong, Vihan Jain, Xiaobing Liu, and Hemal Shah. 2016. Wide & Deep Learning for Recommender Systems. In DLRS. 7–10.
[10] François Chollet. 2015. keras. https://github.com/fchollet/keras. (2015).
[11] Ronan Collobert, Samy Bengio, and Johnny Marithoz. 2002. Torch: A Modular Machine Learning Software Library. (2002).
[12] The SciPy community. 2012. NumPy Reference Guide. SciPy.org. http://docs.scipy.org/doc/numpy/reference/
[13] Jeffrey Dean, Greg S. Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Quoc V. Le, Mark Z. Mao, Marc'Aurelio Ranzato, Andrew Senior, Paul Tucker, Ke Yang, and Andrew Y. Ng. 2012. Large Scale Distributed Deep Networks. In Proceedings of the 25th International Conference on Neural Information Processing Systems (NIPS'12). Curran Associates Inc., USA, 1223–1231. http://dl.acm.org/citation.cfm?id=2999134.2999271
[14] Deeplearning4j Development Team. 2016. Deeplearning4j: Open-source distributed deep learning for the JVM, Apache Software Foundation License 2.0. http://deeplearning4j.org. (2016).
[15] Sander Dieleman, Jan Schlüter, Colin Raffel, Eben Olson, Søren Kaae Sønderby, Daniel Nouri, Daniel Maturana, Martin Thoma, Eric Battenberg, Jack Kelly, Jeffrey De Fauw, Michael Heilman, diogo149, Brian McFee, Hendrik Weideman, takacsg84, peterderivaz, Jon, instagibbs, Dr. Kashif Rasul, CongLiu, Britefury, and Jonas Degrave. 2015. Lasagne: First release. (Aug. 2015). DOI: http://dx.doi.org/10.5281/zenodo.27878
[16] Sergio Guadarrama and Nathan Silberman. 2016. TF Slim. https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim. (2016).
[17] Yangqing Jia, Evan Shelhamer, Jeff Donahue, Sergey Karayev, Jonathan Long, Ross Girshick, Sergio Guadarrama, and Trevor Darrell. 2014. Caffe: Convolutional Architecture for Fast Feature Embedding. In Proceedings of the 22nd ACM International Conference on Multimedia (MM '14). ACM, New York, NY, USA, 675–678. DOI: http://dx.doi.org/10.1145/2647868.2654889
[18] Xiangrui Meng, Joseph Bradley, Burak Yavuz, Evan Sparks, Shivaram Venkataraman, Davies Liu, Jeremy Freeman, DB Tsai, Manish Amde, Sean Owen, Doris Xin, Reynold Xin, Michael J. Franklin, Reza Zadeh, Matei Zaharia, and Ameet Talwalkar. 2016. MLlib: Machine Learning in Apache Spark. J. Mach. Learn. Res. 17, 1 (Jan. 2016), 1235–1241. http://dl.acm.org/citation.cfm?id=2946645.2946679
[19] Fabian Pedregosa, Gaël Varoquaux, Alexandre Gramfort, Vincent Michel, Bertrand Thirion, Olivier Grisel, Mathieu Blondel, Peter Prettenhofer, Ron Weiss, Vincent Dubourg, Jake Vanderplas, Alexandre Passos, David Cournapeau, Matthieu Brucher, Matthieu Perrot, and Édouard Duchesnay. 2011. Scikit-learn: Machine Learning in Python. J. Mach. Learn. Res. 12 (Nov. 2011), 2825–2830. http://dl.acm.org/citation.cfm?id=1953048.2078195
[20] Seiya Tokui, Kenta Oono, Shohei Hido, and Justin Clayton. 2015. Chainer: a Next-Generation Open Source Framework for Deep Learning. In Proceedings of Workshop on Machine Learning Systems (LearningSys) in The Twenty-ninth Annual Conference on Neural Information Processing Systems (NIPS). http://learningsys.org/papers/LearningSys_2015_paper_33.pdf
[21] Bart van Merriënboer, Dzmitry Bahdanau, Vincent Dumoulin, Dmitriy Serdyuk, David Warde-Farley, Jan Chorowski, and Yoshua Bengio. 2015. Blocks and Fuel: Frameworks for deep learning. CoRR abs/1506.00619 (2015). http://arxiv.org/abs/1506.00619
