0

I am new to Trax and am trying to code a transformer using it. When I try to run the training loop, it shows the error:

---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

<ipython-input-153-0ebdab0be47f> in <cell line: 1>()
----> 1 loop = training_loop(transformer, training_pipeline, val_pipeline)
      2 loop.new_rng()
      3 loop.run(max_iterations)

1 frames

<ipython-input-83-f9f54473fec0> in training_loop(transformer, train_data_stream, val_data_stream)
     16   )
     17 
---> 18   loop = training.Loop (
     19       transformer(vocab_size, embedding_len, context_length, num_heads, num_decoder_blocks, fnn_factor, dropout_prob, mode='train'),
     20       train,

/usr/local/lib/python3.10/dist-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    242     # different hosts, leading to different weights on the different hosts.
    243     self._batch_signature = shapes.signature(tasks[0].sample_batch)
--> 244     self._model.rng = self.new_rng()
    245     # In the memory-efficient case, we initialize in init_trainer.
    246     if not use_memory_efficient_trainer:

AttributeError: 'list' object has no attribute 'rng'  

Here is the code I have so far:

def preprocess(data, EOS=1, SEP=0):
  #Adding [EOS], [SEP] to article
  for article_tokens, summary_tokens in data:
    article_tokens = list(article_tokens) + [EOS] + [SEP]
    summary_tokens = list(summary_tokens) + [EOS]

    tokens = np.array(list(article_tokens) + list(summary_tokens))
    summary_mask = np.array(list([0]*(len(article_tokens) + 2) + [1]*(len(summary_tokens) + 1)))

    yield tokens, tokens, summary_mask
def train_data(e):
  for x in dataset['train']:
    yield (x['document'], x['summary'])

def val_data(e):
  for x in dataset['validation']:
    yield (x['document'], x['summary'])
training_pipeline = trax.data.Serial(
  train_data,  
  trax.data.Tokenize(vocab_dir=VOCAB_DIR, vocab_file=VOCAB_FILE),
  preprocess,
  trax.data.BucketByLength(
      boundaries=[128, 256, 512, 1024],
      batch_sizes=[16, 8, 4, 2, 1]
  )
)

val_pipeline = trax.data.Serial(
  val_data,  
  trax.data.Tokenize(vocab_dir=VOCAB_DIR, vocab_file=VOCAB_FILE),
  preprocess,
  trax.data.BucketByLength(
      boundaries=[128, 256, 512, 1024],
      batch_sizes=[16, 8, 4, 2, 1]
  )
)
training_pipeline = training_pipeline(train_data)
val_pipeline = val_pipeline(val_data)

...

loop = training_loop(transformer, training_pipeline, val_pipeline)
loop.run(max_iterations)

I have tried using trax.data.inputs.Inputs to convert the list into a data stream but I was unable to understand how to go about it. I tried using Google and even ChatGPT for help but to no avail.

Edit: Here are the transformer and decoder blocks:

def transformer(vocab_size, embedding_len, context_length, num_heads, num_decoder_blocks, fnn_factor, dropout_prob, mode):

  decoder_blocks = [DecoderBlock(embedding_len, num_heads, fnn_factor, dropout_prob, mode) for _ in range(num_decoder_blocks)]

  positional_encoding = [
      tl.Embedding(vocab_size, embedding_len),
      tl.PositionalEncoding(context_length, dropout_prob, d_feature=embedding_len, mode=mode)
  ]

  return [
      tl.ShiftRight(mode=mode),
      positional_encoding,
      decoder_blocks,
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  ]
def DecoderBlock(embedding_len, num_heads, fnn_factor, dropout_prob, mode):
  fnn = [
      tl.LayerNorm(),
      tl.Dense(embedding_len*fnn_factor),
      tl.activation_fns.Relu(),
      tl.Dense(embedding_len*fnn_factor),
      tl.Dropout(dropout_prob, mode=mode)
  ]

  return [
      tl.Residual (
          tl.LayerNorm(),
          tl.CausalAttention(d_feature=embedding_len, n_heads=num_heads, dropout=dropout_prob, mode=mode),
          tl.Dropout(dropout_prob, mode=mode)
      ),
      tl.Residual (
          fnn
      )
  ]
4
  • Please include the full traceback error.
    – ewokx
    Commented Jul 3, 2023 at 8:12
  • I have added the full traceback Commented Jul 3, 2023 at 8:46
  • How is transformer defined?
    – Daraan
    Commented Jul 3, 2023 at 8:48
  • Added transformer and decoder block definitions Commented Jul 3, 2023 at 9:29

0

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.