I want to classify text to 2 classes by using this embedding: https://tfhub.dev/google/universal-sentence-encoder-multilingual/3
And I also want to add additional features after the embedding. So I have two inputs:
import tensorflow as tf
import tensorflow_hub as tfh
import tensorflow_datasets as tfds
import tensorflow_text as tft
hp = {
'embedding': 'https://tfhub.dev/google/universal-sentence-encoder-multilingual/3' EMBEDDINGS['senm'],
'units': 64,
'learning_rate': 1e-3,
'dropout': 0.2,
'layers': 2
}
textInput = tf.keras.Input(shape=(1, ), name = 'text', dtype = tf.string)
featuresInput = tf.keras.Input(shape=(36, ), name = 'features')
x = tfh.KerasLayer(hp.get('embedding'), dtype = tf.string, trainable = False)(textInput)
x = tf.keras.layers.concatenate([x, featuresInput])
for index in range(hp.get('layers')):
x = tf.keras.layers.Dense(hp.get('units'), activation = 'relu')(x)
x = tf.keras.layers.Dropout(hp.get('dropout'))(x)
output = tf.keras.layers.Dense(
1,
activation = 'sigmoid',
bias_initializer = tf.keras.initializers.Constant(INITIAL_BIAS) if INITIAL_BIAS else None
)(x)
model = tf.keras.Model(inputs = [textInput, featuresInput], outputs = output)
model.compile(
optimizer = tf.keras.optimizers.Adam(lr = hp.get('learning_rate')),
loss = tf.keras.losses.BinaryCrossentropy(),
metrics = METRICS,
)
And the code fails with error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-17-61aed6f885c9> in <module>
10 featuresInput = tf.keras.Input(shape=(36, ), name = 'features')
11
---> 12 x = tfh.KerasLayer(hp.get('embedding'), dtype = tf.string, trainable = False)(textInput)
13 x = tf.keras.layers.concatenate([x, featuresInput])
14
~/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
920 not base_layer_utils.is_in_eager_or_tf_function()):
921 with auto_control_deps.AutomaticControlDependencies() as acd:
--> 922 outputs = call_fn(cast_inputs, *args, **kwargs)
923 # Wrap Tensors in `outputs` in `tf.identity` to avoid
924 # circular dependencies.
~/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
263 except Exception as e: # pylint:disable=broad-except
264 if hasattr(e, 'ag_error_metadata'):
--> 265 raise e.ag_error_metadata.to_exception(e)
266 else:
267 raise
AssertionError: in user code:
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow_hub/keras_layer.py:222 call *
result = f()
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py:486 _call_attribute **
return instance.__call__(*args, **kwargs)
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:580 __call__
result = self._call(*args, **kwds)
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:650 _call
return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds) # pylint: disable=protected-access
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1665 _filtered_call
self.captured_inputs)
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1759 _call_flat
"StatefulPartitionedCall": self._get_gradient_function()}):
/usr/lib/python3.6/contextlib.py:81 __enter__
return next(self.gen)
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:4735 _override_gradient_function
assert not self._gradient_function_map
AssertionError:
BUT it works if I use Sequential
model = tf.keras.Sequential([
hub.KerasLayer(embedding, input_shape=[], dtype = tf.string, trainable = True),
tf.keras.layers.Dense(16, activation = 'relu', input_shape = (train_features.shape[-1],)),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(1, activation = 'sigmoid', bias_initializer = output_bias),
])
model.compile(optimizer = tf.keras.optimizers.Adam(lr=1e-3), loss = tf.keras.losses.BinaryCrossentropy(), metrics = metrics)
Is there anything I'm doing wrong with the functual api? Can you please help me with the error