Merge pull request #114 from hhou435/pipeline
add support for pipeline parallelism
hhou435 authored Jan 24, 2024
2 parents 1dc9dd1 + fac9488 commit f5ce4e5
Showing 12 changed files with 322 additions and 60 deletions.
50 changes: 50 additions & 0 deletions models/deepspeed_mp_config.json
@@ -0,0 +1,50 @@
{
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu":1,
"steps_per_print": 100,
"optimizer": {
"type": "Adam",
"params": {
"lr": 2e-5,
"weight_decay": 1e-2
}
},
"flops_profiler": {
"enabled": false,
"profile_step": 1,
"module_depth": -1,
"top_modules": 3,
"detailed": true
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_dtype": "bf16",
"fp32_allreduce": true,
"data_types": {
"grad_accum_dtype":"fp32"
},
"zero_optimization": {
"stage": 1,
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"activation_checkpointing": {
"partition_activations": false,
"contiguous_memory_optimization": false,
"cpu_checkpointing": false
},
"wall_clock_breakdown": false,
"zero_allow_untested_optimizer": true,
"zero_force_ds_cpu_optimization": false
}
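The settings above enable fp16 training with dynamic loss scaling, fp32 gradient accumulation, and ZeRO stage 1 with parameter and optimizer offload to CPU. A minimal sketch of how such a file is typically consumed, assuming it is handed to deepspeed.initialize (the launch scripts that actually reference this config are not part of the excerpt shown here):

import json

# Load the config to inspect it; deepspeed.initialize also accepts the file path or the dict directly.
with open("models/deepspeed_mp_config.json") as f:
    ds_config = json.load(f)

# Per-rank batch size per optimizer step = micro batch size * gradient accumulation steps.
per_rank_batch = (ds_config["train_micro_batch_size_per_gpu"]
                  * ds_config["gradient_accumulation_steps"])
print(per_rank_batch)

# engine, optimizer, _, _ = deepspeed.initialize(model=model,
#                                                model_parameters=model.parameters(),
#                                                config=ds_config)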
@@ -0,0 +1,78 @@
import argparse
import os
import collections
import torch


parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input_model_path", type=str, default="models/input_model",
help=".")
parser.add_argument("--output_model_path", type=str, default="models/output_model.bin",
help=".")
parser.add_argument("--layers_num", type=int, default=32)
parser.add_argument("--tensor_model_parallel_size", type=int, default=4)

args = parser.parse_args()

# Make sure the directory that will hold the merged model file exists.
output_dir = os.path.dirname(os.path.abspath(args.output_model_path))
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

model_piece_list = []
for n in range(args.tensor_model_parallel_size):
model_piece = collections.OrderedDict()
model_index = str(n) if len(str(n))==2 else '0'+str(n)
for i in range(args.layers_num+2):
layer_index = str(i) if len(str(i))==2 else '0'+str(i)
weight_name = f"layer_{layer_index}-model_{model_index}-model_states.pt"
tmp_weight = torch.load(os.path.join(args.input_model_path, weight_name), map_location="cpu")
if i == 0:
model_piece["embedding.word.embedding.weight"] = tmp_weight["embeddings.word.embedding.weight"]
elif i == args.layers_num+1:
model_piece["target.lm.output_layer.weight"] = tmp_weight["target_layer.lm.output_layer.weight"]
else:
for j in range(3):
model_piece["encoder.transformer." + str(i-1) + ".self_attn.linear_layers."+ str(j) +".weight"] = tmp_weight["layer.self_attn.linear_layers."+ str(j) +".weight"]
model_piece["encoder.transformer." + str(i-1) + ".self_attn.final_linear.weight"] = tmp_weight["layer.self_attn.final_linear.weight"]
model_piece["encoder.transformer." + str(i-1) + ".feed_forward.linear_1.weight"] = tmp_weight["layer.feed_forward.linear_1.weight"]
model_piece["encoder.transformer." + str(i-1) + ".feed_forward.linear_2.weight"] = tmp_weight["layer.feed_forward.linear_2.weight"]
model_piece["encoder.transformer." + str(i-1) + ".feed_forward.linear_gate.weight"] = tmp_weight["layer.feed_forward.linear_gate.weight"]
model_piece["encoder.transformer." + str(i-1) + ".layer_norm_1.weight"] = tmp_weight["layer.layer_norm_1.weight"]
model_piece["encoder.transformer." + str(i-1) + ".layer_norm_2.weight"] = tmp_weight["layer.layer_norm_2.weight"]
if i == args.layers_num:
model_piece["encoder.layer_norm.weight"] = tmp_weight["layer_norm.weight"]
model_piece_list.append(model_piece)

output_model = model_piece_list[0]

for n in range(1, args.tensor_model_parallel_size):
model_piece = model_piece_list[n]
output_model["embedding.word.embedding.weight"] = torch.cat((output_model["embedding.word.embedding.weight"], model_piece["embedding.word.embedding.weight"]),dim=-2)

for i in range(args.layers_num):
for j in range(3):
tensor_a=output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers."+ str(j) +".weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".self_attn.linear_layers."+ str(j) +".weight"]
output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers."+ str(j) +".weight"]=torch.cat((tensor_a,tensor_b),dim=-2)

tensor_a=output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"]

output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"]=torch.cat((tensor_a,tensor_b),dim=-1)

tensor_a=output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"]
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"]=torch.cat((tensor_a,tensor_b),dim=-2)

tensor_a=output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"]
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"]=torch.cat((tensor_a,tensor_b),dim=-2)

tensor_a=output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"]
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"]=torch.cat((tensor_a,tensor_b),dim=-1)

tensor_a=output_model["target.lm.output_layer.weight"]
tensor_b=model_piece["target.lm.output_layer.weight"]
output_model["target.lm.output_layer.weight"]=torch.cat((tensor_a,tensor_b),dim=-2)

torch.save(output_model, args.output_model_path)
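The script above reassembles a checkpoint saved in DeepSpeed's pipeline layout (one layer_XX-model_YY-model_states.pt file per layer and per tensor-parallel rank) into a single TencentPretrain state dict. The concatenation axes follow the usual Megatron-style partitioning, stated here as an assumption about the mpu layers used elsewhere in this commit: column-parallel weights are split across rows (dim 0) and row-parallel weights across columns (dim 1), which is what the dim=-2 and dim=-1 concatenations above express for 2-D weights.

# Summary of the merge rule applied above (for 2-D weights, dim=-2 is dim 0 and dim=-1 is dim 1).
CONCAT_DIM = {
    "embedding.word.embedding.weight": 0,        # vocab-parallel embedding: split by vocab rows
    "self_attn.linear_layers.{j}.weight": 0,     # Q/K/V projections: column-parallel
    "self_attn.final_linear.weight": 1,          # attention output projection: row-parallel
    "feed_forward.linear_1.weight": 0,           # feed-forward up projection: column-parallel
    "feed_forward.linear_gate.weight": 0,        # feed-forward gate projection: column-parallel
    "feed_forward.linear_2.weight": 1,           # feed-forward down projection: row-parallel
    "target.lm.output_layer.weight": 0,          # LM output head: column-parallel
}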
12 changes: 12 additions & 0 deletions tencentpretrain/embeddings/embedding.py
@@ -32,3 +32,15 @@ def forward(self, src, seg):
emb = self.layer_norm(emb)
emb = self.dropout(emb)
return emb


class EmbeddingPipe(torch.nn.Module):
    def __init__(self, args, model):
        super(EmbeddingPipe, self).__init__()
        self.embeddings = model.embedding

    def forward(self, inputs):
        # Pipeline stages exchange tensor tuples: embed src and pass seg through
        # for the transformer and target stages that follow.
        src, seg = inputs
        emb = self.embeddings(src, seg)

        return emb, seg
2 changes: 1 addition & 1 deletion tencentpretrain/embeddings/word_embedding.py
Expand Up @@ -9,7 +9,7 @@ class WordEmbedding(nn.Module):

def __init__(self, args, vocab_size):
super(WordEmbedding, self).__init__()
-        if args.use_mp:
+        if args.tensor_model_parallel_size > 1:
self.embedding = mpu.VocabParallelEmbedding(vocab_size, args.emb_size)
else:
self.embedding = nn.Embedding(vocab_size, args.emb_size)
8 changes: 4 additions & 4 deletions tencentpretrain/encoders/transformer_encoder.py
Expand Up @@ -20,7 +20,7 @@ def __init__(self, args):
self.relative_position_embedding = args.relative_position_embedding
self.rotary_position_embedding = args.rotary_position_embedding
self.has_residual_attention = args.has_residual_attention
-        self.use_mp = args.use_mp
+        self.tensor_model_parallel_size = args.tensor_model_parallel_size

if self.relative_position_embedding:
args.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num,
@@ -40,12 +40,12 @@ def __init__(self, args):
self.linear = nn.Linear(args.emb_size, args.hidden_size)

if self.parameter_sharing:
-            if self.use_mp:
+            if self.tensor_model_parallel_size > 1:
self.transformer = ParallelTransformerLayer(args)
else:
self.transformer = TransformerLayer(args)
else:
-            if self.use_mp:
+            if self.tensor_model_parallel_size > 1:
self.transformer = nn.ModuleList(
[ParallelTransformerLayer(args) for _ in range(self.layers_num)]
)
@@ -123,7 +123,7 @@ def custom_forward(*inputs):
return inputs

return custom_forward
-        if self.use_mp:
+        if self.tensor_model_parallel_size > 1:
mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.layers_num:
2 changes: 1 addition & 1 deletion tencentpretrain/initialize.py
Expand Up @@ -41,7 +41,7 @@ def finish_mpu_init():
if args.deepspeed:
import deepspeed
deepspeed.init_distributed(dist_backend=args.backend)
-    if args.use_mp:
+    if args.tensor_model_parallel_size > 1 or args.pipeline_model_parallel_size > 1:
set_global_variables(args)
args = get_args()
finish_mpu_init()
55 changes: 55 additions & 0 deletions tencentpretrain/layers/transformer.py
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
from tencentpretrain.layers.multi_headed_attn import MultiHeadedAttention, ParallelMultiHeadedAttention
from tencentpretrain.layers import *
@@ -200,6 +201,60 @@ def forward(self, inputs):
return output, mask


class ParallelTransformerLayerPipe(nn.Module):

def __init__(self, args, model, layer_idx):
super(ParallelTransformerLayerPipe, self).__init__()
self.layer_idx = layer_idx
self.layer_num = args.layers_num
self.layernorm_positioning = args.layernorm_positioning
self.has_residual_attention = args.has_residual_attention
if self.layernorm_positioning == "pre":
self.layer_norm = model.encoder.layer_norm
self.layer = model.encoder.transformer[layer_idx]

def generate_mask(self, seq_length, batch_size, device):
mask = torch.ones(seq_length, seq_length, device=device)
mask = torch.tril(mask)
mask = (1.0 - mask) * -10000
mask = mask.repeat(batch_size, 1, 1, 1)
return mask

    def forward(self, inputs):
        """
        Args:
            inputs: (hidden, seg) from the embedding stage, or
                    (hidden, seg, prev_attn) when residual attention is enabled.
                hidden: [batch_size x seq_length x emb_size]
                seg: [batch_size x seq_length]
        Returns:
            (hidden, seg), plus the updated attention scores when residual
            attention is enabled; hidden is [batch_size x seq_length x hidden_size]
        """
        if len(inputs) == 2:
            hidden, seg = inputs
            prev_attn = None
        else:
            hidden, seg, prev_attn = inputs
        batch_size, seq_length, _ = hidden.size()
        # The causal mask is rebuilt locally rather than being passed between pipeline stages.
        mask = self.generate_mask(seq_length, batch_size, hidden.device)
        layer_inputs = hidden, mask, prev_attn
        outputs = self.layer(layer_inputs)

        if self.has_residual_attention:
            hidden, mask, prev_attn_out = outputs
        else:
            hidden, mask = outputs

        if self.layer_idx == self.layer_num - 1 and self.layernorm_positioning == "pre":
            hidden = self.layer_norm(hidden)

        if self.has_residual_attention:
            # Propagate the updated attention scores to the next pipeline layer.
            return hidden, seg, prev_attn_out
        else:
            return hidden, seg
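generate_mask above builds the causal mask each stage needs locally, instead of threading it through the pipeline. A small worked example with values matching that code (the final repeat then expands the result to [batch_size x 1 x seq_length x seq_length]):

import torch

# seq_length = 3: 0 where attention is allowed, -10000 where it is masked.
mask = torch.tril(torch.ones(3, 3))
mask = (1.0 - mask) * -10000
print(mask)
# tensor([[     0., -10000., -10000.],
#         [     0.,      0., -10000.],
#         [     0.,      0.,      0.]])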


class TransformerDecoderLayer(nn.Module):
def __init__(self, args):
super(TransformerDecoderLayer, self).__init__()
4 changes: 0 additions & 4 deletions tencentpretrain/opts.py
@@ -228,12 +228,8 @@ def deepspeed_opts(parser):


def mp_opts(parser):
parser.add_argument("--use_mp", action="store_true",
help=".")
parser.add_argument("--tensor_model_parallel_size", type=int, default=1,
help="Degree of tensor model parallelism.")
parser.add_argument("--use_pipe", action="store_true",
help=".")
parser.add_argument("--pipeline_model_parallel_size", type=int, default=1,
help="Degree of pipeline model parallelism.")
parser.add_argument("--virtual_pipeline_model_parallel_size", type=int, default=None,
16 changes: 10 additions & 6 deletions tencentpretrain/targets/lm_target.py
@@ -14,7 +14,8 @@ def __init__(self, args, vocab_size):
super(LmTarget, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = args.hidden_size
-        self.use_mp = args.use_mp
+        self.tensor_model_parallel_size = args.tensor_model_parallel_size
+        self.pipeline_model_parallel_size = args.pipeline_model_parallel_size
if "label_smoothing" in args:
self.label_smoothing = args.label_smoothing
else:
@@ -24,7 +25,7 @@ def __init__(self, args, vocab_size):
else:
self.ignore_index = None
self.prefix_lm_loss = args.prefix_lm_loss
-        if self.use_mp:
+        if self.tensor_model_parallel_size > 1:
self.output_layer = mpu.ColumnParallelLinear(self.hidden_size, self.vocab_size, gather_output=False,
skip_bias_add=True, bias=False)
else:
@@ -34,8 +35,6 @@ def lm(self, memory_bank, tgt_lm, seg):

def lm(self, memory_bank, tgt_lm, seg):
# Language modeling (LM) with full softmax prediction.

-        tgt_lm = tgt_lm.contiguous().view(-1)
seg = seg.contiguous().view(-1)

memory_bank = memory_bank.contiguous().view(-1, self.hidden_size)
@@ -44,10 +43,15 @@ def lm(self, memory_bank, tgt_lm, seg):
        # For example, with seg = [1, 1, 1, 2, 2, 0]: when loss_prefix = 0, the loss is computed on tokens with seg > 0;
        # when loss_prefix = 1, the loss is computed only on tokens with seg = 2.
memory_bank = memory_bank[seg > loss_mask, :]
-        tgt_lm = tgt_lm[seg > loss_mask]
+        if tgt_lm is not None:
+            tgt_lm = tgt_lm.contiguous().view(-1)
+            tgt_lm = tgt_lm[seg > loss_mask]

output = self.output_layer(memory_bank)
-        if self.use_mp:
+        if self.pipeline_model_parallel_size > 1:
+            # With pipeline parallelism the loss is computed later by the pipeline
+            # engine's loss function (CrossEntropy in tencentpretrain/targets/target.py).
+            return output, loss_mask
+        elif self.tensor_model_parallel_size > 1:
losses = mpu.vocab_parallel_cross_entropy(output, tgt_lm)
loss = torch.sum(losses.view(-1)) / len(losses)
else:
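A worked example of the loss-mask selection above, using the seg values from the code comment (assuming seg = 0 marks padding, seg = 1 the prefix and seg = 2 the target segment):

import torch

seg = torch.tensor([1, 1, 1, 2, 2, 0])
print(seg > 0)  # loss_mask = 0: tensor([ True,  True,  True,  True,  True, False])
print(seg > 1)  # loss_mask = 1: tensor([False, False, False,  True,  True, False])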
26 changes: 26 additions & 0 deletions tencentpretrain/targets/target.py
@@ -1,5 +1,8 @@
import torch
import torch.nn as nn

from tencentpretrain import mpu


class Target(nn.Module):
def __init__(self):
@@ -21,3 +24,26 @@ def forward(self, memory_bank, tgt, seg):
self.loss_info = target(memory_bank, tgt, seg)

return self.loss_info


class TargetPipe(nn.Module):
    def __init__(self, args, model):
        super(TargetPipe, self).__init__()
        self.target_layer = model.target

    def forward(self, inputs):
        # Final pipeline stage: tgt is not available here, so the target layer
        # returns (output, loss_mask) and the loss is computed by CrossEntropy below.
        hidden, seg = inputs
        loss_info = self.target_layer(hidden, None, seg)

        return loss_info


def CrossEntropy(outputs, labels):
    # Loss function for the pipeline engine: outputs is (output, loss_mask) from
    # TargetPipe, labels is (tgt_lm, seg) from the training batch.
    output, loss_mask = outputs
    tgt_lm, seg = labels[0], labels[1]
seg = seg.contiguous().view(-1)
tgt_lm = tgt_lm.contiguous().view(-1)
tgt_lm = tgt_lm[seg > loss_mask]
losses = mpu.vocab_parallel_cross_entropy(output, tgt_lm)
loss = torch.sum(losses.view(-1)) / len(losses)

return loss
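Taken together, EmbeddingPipe, ParallelTransformerLayerPipe, TargetPipe, and CrossEntropy follow DeepSpeed's pipeline-parallel pattern: the model is expressed as a flat list of stages that exchange tensor tuples, and the loss is computed by a separate loss function. The code that actually assembles the pipeline belongs to files of this commit that are not shown in this excerpt, so the builder below is only a sketch under that assumption; the function name and argument choices are illustrative:

from deepspeed.pipe import PipelineModule
from tencentpretrain.embeddings.embedding import EmbeddingPipe
from tencentpretrain.layers.transformer import ParallelTransformerLayerPipe
from tencentpretrain.targets.target import TargetPipe, CrossEntropy


def build_pipeline_model(args, model):
    # Stage list mirrors the forward pass: embedding -> transformer layers -> target.
    layers = [EmbeddingPipe(args, model)]
    layers += [ParallelTransformerLayerPipe(args, model, i) for i in range(args.layers_num)]
    layers.append(TargetPipe(args, model))
    # CrossEntropy receives (output, loss_mask) from the last stage and (tgt_lm, seg) as labels.
    return PipelineModule(layers=layers,
                          num_stages=args.pipeline_model_parallel_size,
                          loss_fn=CrossEntropy)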