
Pytorch-Lightning #21

Open
ValterFallenius wants to merge 18 commits into base: main

Conversation

@ValterFallenius

Pull Request

Description

Converted to pytorch-lightning for easy parallelization. Needs work on the forward pass. How do we create a forward pass that depends on a random lead time, while also coupling it to the loss function so that it is paired with the correct slice of the Y tensor? The paper also mentions that random lead times are used during training, but that for validation every possible lead time is checked once.

Right now the code creates forecast_steps = 6 copies of the input tensor and concatenates them with 6 different one-hot encodings. This copies the input tensor 6 times, which will be unacceptable for larger lead times. I want to go back to using self.ct(x, random_int) in the forward pass to reduce memory usage, but I haven't figured out how to do that while still pairing the prediction with the correct ground truth.
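To make the memory trade-off concrete, here is a rough sketch contrasting the two approaches; the shapes and helper names are hypothetical, not this PR's actual code:

```python
import torch

def expand_all_lead_times(x, forecast_steps=6):
    """Current approach: concatenate every one-hot lead-time encoding,
    duplicating x forecast_steps times. x: (batch, time, channels, h, w)."""
    xs = []
    for t in range(forecast_steps):
        onehot = torch.zeros(forecast_steps)
        onehot[t] = 1.0
        enc = onehot.view(1, 1, forecast_steps, 1, 1).expand(
            x.shape[0], x.shape[1], forecast_steps, x.shape[3], x.shape[4])
        xs.append(torch.cat([x, enc], dim=2))
    return torch.stack(xs)              # forecast_steps x the input memory

def condition_single_lead_time(x, forecast_steps=6):
    """Desired approach: condition on one sampled lead time, no copies.
    Returning t lets the caller pair the prediction with y[:, t]."""
    t = torch.randint(0, forecast_steps, (1,)).item()
    onehot = torch.zeros(forecast_steps)
    onehot[t] = 1.0
    enc = onehot.view(1, 1, forecast_steps, 1, 1).expand(
        x.shape[0], x.shape[1], forecast_steps, x.shape[3], x.shape[4])
    return torch.cat([x, enc], dim=2), t
```

The open question is then only how to carry t through to the loss, since y[:, t] is the matching target.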

Dataset

A subsample of my dataset is linked in the README. It is not the "prettified" fields that I will use for the thesis, but it has the same dimensions and value range, so it is good enough for testing. The dataset consists of dBZ radar fields along with longitude, latitude and elevation encodings (the southern half of Sweden).

Valterf and others added 3 commits February 24, 2022 17:10
The setup() method is ugly and needs fixing before running with default parameters. I still haven't figured out how to do this properly.

We want the train_dataset to have one set of days and the val_dataset to have the other days. Furthermore, training should be done by sampling a random lead_time and a random day, whereas validation should not be random (see the sketch after this commit list).

Link to data files in README.
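A minimal sketch of one way that setup() could split whole days between train and validation, assuming a dataset that exposes a days list mapping each sample index to its day (all names here are hypothetical, not code from this PR):

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Subset

class RadarDataModule(pl.LightningDataModule):
    def __init__(self, dataset, val_days, batch_size=8):
        super().__init__()
        self.dataset = dataset        # assumed indexable, with a .days list
        self.val_days = set(val_days)
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Disjoint day-level split: each day goes entirely to train or val.
        train_idx = [i for i, d in enumerate(self.dataset.days)
                     if d not in self.val_days]
        val_idx = [i for i, d in enumerate(self.dataset.days)
                   if d in self.val_days]
        self.train_dataset = Subset(self.dataset, train_idx)
        self.val_dataset = Subset(self.dataset, val_idx)

    def train_dataloader(self):
        # shuffle=True gives the random-day sampling during training
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        # deterministic order for validation
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          shuffle=False)
```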
@jacobbieker
Member

First of all, thanks for submitting this, and for releasing some of the data! Would it be okay if I rehost the sample you released on HuggingFace Datasets? Even though it is not the 'prettified' data, and only a subsample, it could still be quite nice to have training data from somewhere other than the US. But if not, that's fine too!

For training, one idea would be to randomly pick the lead time in the dataloader and only load that future time slice as the target, returning the history frames, the single future lead-time frame, and the lead-time index to the model, if that makes sense. That should reduce the memory requirements quite a bit.
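A minimal sketch of that suggestion, with hypothetical tensor shapes and class names (this is not the repo's actual dataset code):

```python
import torch
from torch.utils.data import Dataset

class RandomLeadTimeDataset(Dataset):
    """Sample the lead time inside __getitem__ so only one future frame
    is ever materialized per example."""

    def __init__(self, inputs, targets, forecast_steps=6):
        self.inputs = inputs            # (n, history, channels, h, w)
        self.targets = targets          # (n, forecast_steps, h, w)
        self.forecast_steps = forecast_steps

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        lead_time = torch.randint(0, self.forecast_steps, (1,)).item()
        history = self.inputs[idx]              # history frames
        target = self.targets[idx, lead_time]   # single future slice only
        return history, target, lead_time       # index goes to the model
```

The model then consumes the lead-time index (e.g. as a one-hot channel), and the loss only ever sees the single matching target frame.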

#20 changes MetNet and MetNet-2 to use only a single lead time by default; you can look at test_model.py to see how to predict for multiple lead times, which should help with memory consumption.

For the overall PR, I still need to go through it a bit more. Ideally, though, we want to keep this repo quite general, with minimal dependencies for importing the models. So I would recommend removing the scripts that are specific to running on your machine or your cluster, and moving the rest under a train/ folder instead. That way this PR adds an example of how to wrap the model in PyTorch Lightning, load some data, and train from scratch, which we can then point to in the documentation. We can also make PyTorch Lightning an optional dependency for people who want to use it, while the basic models need only plain PyTorch.

@ValterFallenius
Author

Let me get back to you on the dataset; my supervisor still hasn't answered me... In the meantime, you can use it yourself.

@ValterFallenius
Author

Damn, I just solved the conditional lead time problem. It was such an easy fix too... I feel stupid.

I didn't change the dataloader; instead I added a random integer in the training loop, fed it as input to the forward pass, and then sliced the Y tensor accordingly.
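For anyone following along, here is a hedged sketch of what that fix might look like as a LightningModule training step, assuming a model whose forward accepts a lead-time index (names and signatures are illustrative, not the actual PR code):

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class LitMetNet(pl.LightningModule):
    def __init__(self, model, forecast_steps=6, lr=1e-3):
        super().__init__()
        self.model = model          # assumed: model(x, lead_time) -> (b, h, w)
        self.forecast_steps = forecast_steps
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch                # y: (batch, forecast_steps, h, w)
        # Random lead time drawn in the training loop, as described above
        t = torch.randint(0, self.forecast_steps, (1,)).item()
        y_hat = self.model(x, t)
        loss = F.mse_loss(y_hat, y[:, t])  # slice Y to match the lead time
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Per the paper: evaluate every lead time once during validation
        x, y = batch
        losses = [F.mse_loss(self.model(x, t), y[:, t])
                  for t in range(self.forecast_steps)]
        loss = torch.stack(losses).mean()
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```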

Berzelius has been under maintenance yesterday and today; I hope to run the network on some TPUs next week 😄
