I am using U-Net and implementing the weighting technique described in the 2015 paper (U-Net: Convolutional Networks for Biomedical Image Segmentation) and the 2019 paper (U-Net – Deep Learning for Cell Counting, Detection, and Morphometry). That technique has a variance σ and a weight w_0. I would like these, and especially σ, to be learnable parameters instead of guessing the best value for each dataset.
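For reference, the border term of the weight map in the 2015 paper is w_0 · exp(−(d_1 + d_2)² / (2σ²)), where d_1 and d_2 are the distances to the border of the nearest and second-nearest cell. A minimal sketch of that term in PyTorch follows. The function name border_weight and the toy distances are made up for illustration, and the defaults w_0 = 10 and σ = 5 are the values suggested in that paper.

```python
import torch

def border_weight(d1, d2, w0=10.0, sigma=5.0):
    """Border term w0 * exp(-(d1 + d2)^2 / (2 * sigma^2)) of the U-Net weight map.

    border_weight is a hypothetical helper name, not part of any library.
    """
    return w0 * torch.exp(-(d1 + d2) ** 2 / (2 * sigma ** 2))

# Toy distances (in pixels) for three pixel positions.
d1 = torch.tensor([0.0, 2.0, 10.0])
d2 = torch.tensor([0.0, 3.0, 10.0])
w = border_weight(d1, d2)  # largest where both distances are small
```

Written with torch operations like this, the term is differentiable with respect to sigma whenever sigma is passed in as a tensor that requires gradients.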
- From what I found, I can do this using nn.Parameter.
- To use the learned σ from epoch to epoch, I somehow need to pass the new value to the __getitem__ method of the Dataset through the DataLoader.
My current idea is to extend torch.utils.data.DataLoader so that its __init__ takes an extra argument for the user-specified/learnable parameters. Looking at the source code of torch.utils.data.DataLoader, I do not understand where and how the DataLoader calls the Dataset instance, and hence how to pass these parameters on.
Code-wise, the Dataset definition has the method
def __getitem__(self, index):
that I could change to
def __getitem__(self, index, sigma):
so that it uses the updated, newly learned σ.
My problem is that during training I iterate over the training dataset as
for epoch in range(checkpoint['epoch'], num_epochs):
    ....
    for ii, (X, y, y_weight, fname) in enumerate(dataLoader[phase]):
In that enumeration of the DataLoader, how can I pass the new σ to the DataLoader so that it passes it on to the Dataset's __getitem__ method mentioned above?
EDIT
Currently, I define a parameter sigma inside the Dataset class
class MedicalImageDataset(Dataset):
    def __init__(self, fname, img_transform=None, mask_transform=None, weight_transform=None, sigma=8):
        ...
        self.sigma = sigma

    def __getitem__(self, index):
        sigma = self.sigma
        ...
which I update through the DataLoader as
dataLoader[ 'train'].dataset.sigma = model.sigma
where model.sigma is a custom parameter defined as
model.register_parameter( name = 'sigma', param = torch.nn.Parameter( torch.tensor( 16, dtype = torch.float16), requires_grad = True))
after creating the model.
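The attribute-update approach above can be sketched in isolation. DataLoader indexes the dataset with the index only, so the value has to live on the dataset object, which is reachable as loader.dataset. ToyDataset and sigma_param below are hypothetical stand-ins for MedicalImageDataset and model.sigma, the in-place add_ merely imitates an optimizer step, and num_workers=0 is assumed so that the dataset is not copied into worker processes.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical stand-in for MedicalImageDataset; returns sigma so updates are visible."""

    def __init__(self, n=4, sigma=8.0):
        self.n = n
        self.sigma = sigma

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        # The DataLoader only ever passes the index; sigma is read from the attribute.
        return torch.tensor(float(index)), torch.tensor(float(self.sigma))

loader = DataLoader(ToyDataset(), batch_size=2, num_workers=0)
sigma_param = torch.nn.Parameter(torch.tensor(16.0))  # stands in for model.sigma

seen = []
for epoch in range(2):
    for x, sigma_batch in loader:
        seen.append(sigma_batch[0].item())
    # Hand the dataset a plain float rather than the Parameter itself.
    loader.dataset.sigma = sigma_param.item()
    # Imitate an optimizer step changing the parameter (assumption for the demo).
    with torch.no_grad():
        sigma_param.add_(1.0)
```

Note that a value handed over as a plain number carries no autograd history, so a weight map computed from it inside __getitem__ cannot send a gradient back to the parameter.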
My problem is that model.sigma does not appear to be updated from epoch to epoch; it stays at its initial value. Why is this?
Looking at optimizer.state_dict(), I couldn't find any parameter named 'sigma', whereas I can find one in model.named_parameters().
Finally, this parameter sigma is not attached to any layer; it's kinda "free".
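Two things can be checked here, sketched below under stated assumptions. First, an optimizer only updates the tensors it was given at construction, so a parameter registered afterwards is not in optimizer.param_groups (and optimizer.state_dict() typically refers to parameters by integer index rather than by name). Second, a parameter only receives a gradient if the loss depends on it through torch operations. The nn.Linear model, the SGD optimizer, the construction order, the float32 dtype, and the helper in_optimizer are all assumptions made for this demo, not taken from the code above.

```python
import torch
from torch import nn

model = nn.Linear(4, 1)  # toy model standing in for the U-Net
# Assumption: the optimizer is created BEFORE sigma is registered.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model.register_parameter("sigma", nn.Parameter(torch.tensor(16.0)))

def in_optimizer(param, opt):
    """Hypothetical helper: True if this exact tensor is in any param group."""
    return any(p is param for group in opt.param_groups for p in group["params"])

registered_late = in_optimizer(model.sigma, optimizer)

# A late parameter can be handed to the optimizer explicitly.
optimizer.add_param_group({"params": [model.sigma]})
registered_after_add = in_optimizer(model.sigma, optimizer)

x = torch.randn(8, 4)

# Loss that does not involve sigma: no gradient reaches it.
optimizer.zero_grad()
model(x).pow(2).mean().backward()
grad_without = model.sigma.grad

# Loss that uses sigma through torch ops: a gradient is produced.
optimizer.zero_grad()
weight = torch.exp(-1.0 / (2 * model.sigma ** 2))
(weight * model(x).pow(2)).mean().backward()
grad_with = model.sigma.grad
```

If the same checks on the real model show sigma missing from the optimizer, or model.sigma.grad being None after backward(), that would explain a value that never moves.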