0

I want to classify text into 2 classes using this embedding: https://tfhub.dev/google/universal-sentence-encoder-multilingual/3

I also want to add extra features after the embedding layer, so the model has two inputs:

import tensorflow          as tf
import tensorflow_hub      as tfh
import tensorflow_datasets as tfds
import tensorflow_text     as tft


# Hyperparameters for the two-input (text + numeric features) classifier.
hp = {
    'embedding':     'https://tfhub.dev/google/universal-sentence-encoder-multilingual/3',
    'units':         64,
    'learning_rate': 1e-3,
    'dropout':       0.2,
    'layers':        2,
}

# NOTE(review): shape=(1,) makes the text input a (batch, 1) string tensor. The working
# Sequential variant uses input_shape=[] (one scalar string per example); shape=() is the
# functional-API equivalent and is worth trying for this error.
textInput     = tf.keras.Input(shape=(1,), name='text', dtype=tf.string)
# 36 numeric features per example, concatenated with the sentence embedding below.
featuresInput = tf.keras.Input(shape=(36,), name='features')

# Frozen sentence encoder (trainable=False): only the dense head below is trained.
x = tfh.KerasLayer(hp.get('embedding'), dtype=tf.string, trainable=False)(textInput)
x = tf.keras.layers.concatenate([x, featuresInput])

# Dense + Dropout head; depth and width come from hp.
for _ in range(hp.get('layers')):
    x = tf.keras.layers.Dense(hp.get('units'), activation='relu')(x)
    x = tf.keras.layers.Dropout(hp.get('dropout'))(x)

# INITIAL_BIAS and METRICS are defined elsewhere (not shown here).
# Fall back to Dense's own default 'zeros' rather than passing None: an explicit None
# bias_initializer makes Keras's add_weight fall back to glorot_uniform for the bias.
# `is not None` also keeps a legitimate INITIAL_BIAS of 0 and avoids truth-testing arrays.
output = tf.keras.layers.Dense(
    1,
    activation='sigmoid',
    bias_initializer=(
        tf.keras.initializers.Constant(INITIAL_BIAS)
        if INITIAL_BIAS is not None
        else 'zeros'
    ),
)(x)

model = tf.keras.Model(inputs=[textInput, featuresInput], outputs=output)
model.compile(
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias that newer
    # Keras releases no longer accept.
    optimizer=tf.keras.optimizers.Adam(learning_rate=hp.get('learning_rate')),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=METRICS,
)

The code fails with the following error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-17-61aed6f885c9> in <module>
     10 featuresInput = tf.keras.Input(shape=(36, ), name = 'features')
     11 
---> 12 x = tfh.KerasLayer(hp.get('embedding'), dtype = tf.string, trainable = False)(textInput)
     13 x = tf.keras.layers.concatenate([x, featuresInput])
     14 

~/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    920                     not base_layer_utils.is_in_eager_or_tf_function()):
    921                   with auto_control_deps.AutomaticControlDependencies() as acd:
--> 922                     outputs = call_fn(cast_inputs, *args, **kwargs)
    923                     # Wrap Tensors in `outputs` in `tf.identity` to avoid
    924                     # circular dependencies.

~/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    263       except Exception as e:  # pylint:disable=broad-except
    264         if hasattr(e, 'ag_error_metadata'):
--> 265           raise e.ag_error_metadata.to_exception(e)
    266         else:
    267           raise

AssertionError: in user code:

    /home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow_hub/keras_layer.py:222 call  *
        result = f()
    /home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py:486 _call_attribute  **
        return instance.__call__(*args, **kwargs)
    /home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:580 __call__
        result = self._call(*args, **kwds)
    /home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:650 _call
        return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
    /home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1665 _filtered_call
        self.captured_inputs)
    /home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1759 _call_flat
        "StatefulPartitionedCall": self._get_gradient_function()}):
    /usr/lib/python3.6/contextlib.py:81 __enter__
        return next(self.gen)
    /home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:4735 _override_gradient_function
        assert not self._gradient_function_map

    AssertionError: 

But it works if I use a Sequential model:

    # Single-input baseline: sentence embedding -> Dense -> Dropout -> sigmoid output.
    # `embedding`, `output_bias` and `metrics` are defined elsewhere (not shown here).
    model = tf.keras.Sequential([
        # input_shape=[]: each example is a single (scalar) string.
        hub.KerasLayer(embedding, input_shape=[], dtype=tf.string, trainable=True),
        # input_shape is only used on the first layer of a Sequential model, so it is
        # not repeated here (it was ignored on this layer anyway).
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(1, activation='sigmoid', bias_initializer=output_bias),
    ])

    model.compile(
        # `learning_rate` replaces the deprecated `lr` alias.
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=metrics,
    )

Am I doing something wrong with the functional API? Can you please help me with this error?

1 Answer 1

1

I've faced a similar problem. My solution looks like this:

def build_model(num_classes=3, hidden_units=32):
    """Build and compile a two-input text classifier on a shared hub embedding.

    Both inputs (premise and hypothesis) take one scalar string per example
    (shape ``()``). They pass through the *same* ``hub.KerasLayer`` instance, so
    the embedding weights are shared. The two embeddings are concatenated and fed
    to a small dense head.

    Args:
        num_classes: number of output logits (default 3, as in the original).
        hidden_units: width of the hidden ReLU Dense layer (default 32).

    Returns:
        A compiled ``keras.Model`` named ``"elementary_model"`` whose output is
        unnormalised logits (the loss is built with ``from_logits=True``).
    """
    premise = keras.Input(shape=(), dtype=tf.string)
    hypothesis = keras.Input(shape=(), dtype=tf.string)
    # `embed` is defined elsewhere (presumably a TF-Hub handle/URL).
    # output_shape is written as a list: `(512)` is just the int 512, not a tuple.
    # NOTE(review): 512 assumes a Universal-Sentence-Encoder-sized output; confirm for
    # other handles.
    keras_emb = hub.KerasLayer(
        embed, input_shape=(), output_shape=[512], dtype=tf.string, trainable=True
    )
    # Reusing one layer instance for both inputs shares its weights.
    prem_emb = keras_emb(premise)
    hyp_emb = keras_emb(hypothesis)
    emb = layers.Concatenate()([prem_emb, hyp_emb])
    dense = layers.Dense(hidden_units, activation="relu")(emb)
    # No activation: the layer outputs logits, matching from_logits=True below.
    classifier = layers.Dense(num_classes)(dense)
    model = keras.Model(
        inputs=[premise, hypothesis], outputs=classifier, name="elementary_model"
    )
    model.compile(
        loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer="adam",
        metrics=['accuracy'],
    )
    return model

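Note: the text input shape should be `()` (an empty tuple), not `(1,)`. Each example is then a single scalar string, which matches `input_shape=[]` in the working Sequential model.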

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Not the answer you're looking for? Browse other questions tagged or ask your own question.