Pytorch-Lightning #21

Open · wants to merge 18 commits into main
Changes from 1 commit
Weekly backup
ValterFallenius committed Apr 21, 2022
commit 89bcfd986cbd3ad591fd23fbc53943e994afd96c
34 changes: 22 additions & 12 deletions MetNet_lightning.py
@@ -8,46 +8,56 @@
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
import time

wandb.login()

model = MetNetPylight(
hidden_dim=256, #384 original paper
forecast_steps=8, #240 original paper
forecast_steps=1, #240 original paper
input_channels=15, #46 original paper, hour/day/month = 3, lat/long/elevation = 3, GOES+MRMS = 40
output_channels=128, #512
input_size=112, # 112
n_samples = None, #None = All ~ 23000
num_workers = 4,
batch_size = 8,
learning_rate = 1e-2,
num_att_layers = 8,
num_att_layers = 4,
plot_every = None, #Plot every global_step
rain_step = 0.2,
momentum = 0.9,
att_heads=16,
keep_biggest = 0.15,
leadtime_spacing = 3, #1: 5 minutes, 3: 15 minutes
keep_biggest = 1,
leadtime_spacing = 12, #1: 5 minutes, 3: 15 minutes
)
#PATH_cp = "/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/epoch=430-step=22842.ckpt"
#PATH_cp = "/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/lit-wandb/jbd0j048/checkpoints/epoch=210-step=14558.ckpt"
#PATH_cp = "/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/8leads with agg.ckpt"
#PATH_cp = "/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/full_run_continue.ckpt"
#model = MetNetPylight.load_from_checkpoint(PATH_cp)
print(model)

#model.n_samples = 2000
model.testing = False
model.learning_rate = 1e-2
'''print(model)
print(model.forecast_steps)
print(model.input_channels)
print(model.output_channels)
print(model.n_samples)
print(model.file_name)
print(model.keep_biggest)
input()
model.keep_biggest = 0.9
model.n_samples = None'''
#model.printer = True
#model.plot_every = None
#MetNetPylight expects already preprocessed data. This can be changed by uncommenting the preprocessing step.
#print(model)

#wandb.restore("/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/wandb/run-20220331_162434-21o2u2sj"
#wandb.init(run_id = "21o2u2sj",resume="must")
wandb_logger = WandbLogger(project="lit-wandb")

wandb_logger = WandbLogger(project="lit-wandb", log_model="all")
checkpoint_callback = ModelCheckpoint(monitor="validation/loss_epoch", mode="min", save_top_k=5)


trainer = pl.Trainer(num_sanity_val_steps=2, track_grad_norm = 2, max_epochs=1000, gpus=-1,log_every_n_steps=50, logger = wandb_logger,strategy="ddp", callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer = pl.Trainer(num_sanity_val_steps=2, track_grad_norm = 2, max_epochs=1000, gpus=-1,log_every_n_steps=50, logger = wandb_logger,strategy="ddp", callbacks=[checkpoint_callback])
start_time = time.time()

trainer.fit(model)
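
Note on the checkpointing change above: ModelCheckpoint can only monitor values that the LightningModule itself logs, so monitor="validation/loss_epoch" assumes MetNetPylight logs its validation loss under that key. A minimal sketch of how such a key is typically produced (the module and layer here are illustrative stand-ins, not code from this repo) — Lightning appends the _epoch suffix when a metric is logged with both on_step and on_epoch enabled:

import torch
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):
    # Illustrative stand-in for MetNetPylight's logging, not the actual model.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # on_step=True and on_epoch=True make Lightning log "validation/loss_step"
        # and an epoch-aggregated "validation/loss_epoch" -- the latter is what
        # ModelCheckpoint(monitor="validation/loss_epoch", mode="min") watches.
        self.log("validation/loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)

With log_model="all", WandbLogger additionally uploads the checkpoints written by the callback as W&B artifacts, so the save_top_k=5 best checkpoints stay retrievable after the run.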
5 changes: 4 additions & 1 deletion data_prep/metnet_dataloader.py
@@ -101,6 +101,7 @@ def __init__(self,ID, N=None, keep_biggest = 0.5, leadtime_spacing = 1, lead_tim
plt.title("resampling of leadtimes")
plt.show()'''


self.n_samples = len(self.file_names)

def __getitem__(self, index):
@@ -110,14 +111,16 @@ def __getitem__(self, index):
x = np.load(name_x)

y = np.load(name_y)

persistence = y[0]
y = y[self.leadtime_spacing-1::self.leadtime_spacing]

x = torch.from_numpy(x)
y = torch.from_numpy(y)
if self.ID == "train":

return x, y, self.rainy_leadtimes[index]
if self.ID == "test":
return x, y, persistence
return x, y

def __len__(self):
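
For context on the __getitem__ change: the slice y[self.leadtime_spacing-1::self.leadtime_spacing] keeps every leadtime_spacing-th target frame, while y[0] is kept separately as a persistence baseline returned for the test split. A small worked example (the array length of 36 is an assumption for illustration, not taken from the repo):

import numpy as np

leadtime_spacing = 12                 # 12 * 5 min = hourly targets
y = np.arange(36)                     # stand-in for 36 lead-time frames at 5-minute steps

persistence = y[0]                    # first target frame, used as the persistence baseline
y_spaced = y[leadtime_spacing - 1::leadtime_spacing]
print(y_spaced)                       # [11 23 35] -> lead times of 60, 120 and 180 minutes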