Skip to content

Commit

Permalink
[BUG] Added level domain protection for normality and permbu me…
Browse files Browse the repository at this point in the history
…thods. (#166)

* Added mypy/flake8 installation to CONTRIBUTING instructions

* Added level (0,100] domain protection

Co-authored-by: fede <[email protected]>
  • Loading branch information
kdgutier and AzulGarza authored Jan 23, 2023
1 parent 7cce341 commit 2eadc28
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ nbdev_export
```

### Check syntax with Linters
This project uses a couple of linters to validate different aspects of the code. Before opening a PR, please make sure that it passes all the linting tasks by following the next steps.
This project uses a couple of linters to validate different aspects of the code. Before opening a PR, install them with `pip install flake8` and `pip install mypy`, then make sure your changes pass all the linting tasks by following the next steps.

* `mypy hierarchicalforecast/`
* `flake8 --select=F hierarchicalforecast/`
Expand Down
2 changes: 1 addition & 1 deletion hierarchicalforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,4 @@
'hierarchicalforecast.utils.is_strictly_hierarchical': ( 'utils.html#is_strictly_hierarchical',
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils.numpy_balance': ( 'utils.html#numpy_balance',
'hierarchicalforecast/utils.py')}}}
'hierarchicalforecast/utils.py')}}}
16 changes: 12 additions & 4 deletions hierarchicalforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ def _prepare_fit(self,
S_df: pd.DataFrame,
Y_df: Optional[pd.DataFrame],
tags: Dict[str, np.ndarray],
intervals_method: str,
sort_df: bool):
level: Optional[List[int]] = None,
intervals_method: str = 'normality',
sort_df: bool = True):
"""
Performs preliminary wrangling and protections
"""
Expand Down Expand Up @@ -122,7 +123,13 @@ def _prepare_fit(self,

if self.insample or (intervals_method in ['bootstrap', 'permbu']):
if Y_df is None:
raise Exception('you need to pass `Y_df`')
raise Exception('you need to pass `Y_df`')

# Protect level list
if (level is not None):
level_outside_domain = np.any((np.array(level) <= 0)|(np.array(level) > 100 ))
if level_outside_domain and (intervals_method in ['normality', 'permbu']):
raise Exception('Level outside domain, send `level` list in (0,100]')

# Declare output names
drop_cols = ['ds', 'y'] if 'y' in Y_hat_df.columns else ['ds']
Expand Down Expand Up @@ -186,7 +193,7 @@ def reconcile(self,
If a class of `self.reconciles` receives `y_hat_insample`, `Y_df` must include them as columns.<br>
`S`: pd.DataFrame with summing matrix of size `(base, bottom)`, see [aggregate method](https://nixtla.github.io/hierarchicalforecast/utils.html#aggregate).<br>
`tags`: Each key is a level and its value contains tags associated to that level.<br>
`level`: float list 0-100, confidence levels for prediction intervals.<br>
`level`: list of floats in the interval (0, 100], confidence levels for prediction intervals.<br>
`intervals_method`: str, method used to calculate prediction intervals, one of `normality`, `bootstrap`, `permbu`.<br>
`num_samples`: int=-1, if positive return that many probabilistic coherent samples.<br>
`seed`: int=0, random seed for numpy generator's replicability.<br>
Expand All @@ -201,6 +208,7 @@ def reconcile(self,
S_df=S,
Y_df=Y_df,
tags=tags,
level=level,
intervals_method=intervals_method,
sort_df=sort_df)

Expand Down
37 changes: 33 additions & 4 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,9 @@
" S_df: pd.DataFrame,\n",
" Y_df: Optional[pd.DataFrame],\n",
" tags: Dict[str, np.ndarray],\n",
" intervals_method: str, \n",
" sort_df: bool):\n",
" level: Optional[List[int]] = None,\n",
" intervals_method: str = 'normality',\n",
" sort_df: bool = True):\n",
" \"\"\"\n",
" Performs preliminary wrangling and protections\n",
" \"\"\"\n",
Expand Down Expand Up @@ -222,7 +223,13 @@
"\n",
" if self.insample or (intervals_method in ['bootstrap', 'permbu']):\n",
" if Y_df is None:\n",
" raise Exception('you need to pass `Y_df`') \n",
" raise Exception('you need to pass `Y_df`')\n",
" \n",
" # Protect level list\n",
" if (level is not None):\n",
" level_outside_domain = np.any((np.array(level) <= 0)|(np.array(level) > 100 ))\n",
" if level_outside_domain and (intervals_method in ['normality', 'permbu']):\n",
" raise Exception('Level outside domain, send `level` list in (0,100]')\n",
"\n",
" # Declare output names\n",
" drop_cols = ['ds', 'y'] if 'y' in Y_hat_df.columns else ['ds']\n",
Expand Down Expand Up @@ -286,7 +293,7 @@
" If a class of `self.reconciles` receives `y_hat_insample`, `Y_df` must include them as columns.<br>\n",
" `S`: pd.DataFrame with summing matrix of size `(base, bottom)`, see [aggregate method](https://nixtla.github.io/hierarchicalforecast/utils.html#aggregate).<br>\n",
" `tags`: Each key is a level and its value contains tags associated to that level.<br>\n",
" `level`: float list 0-100, confidence levels for prediction intervals.<br>\n",
" `level`: positive float list (0-100], confidence levels for prediction intervals.<br>\n",
" `intervals_method`: str, method used to calculate prediction intervals, one of `normality`, `bootstrap`, `permbu`.<br>\n",
" `num_samples`: int=-1, if positive return that many probabilistic coherent samples.\n",
" `seed`: int=0, random seed for numpy generator's replicability.<br>\n",
Expand All @@ -301,6 +308,7 @@
" S_df=S,\n",
" Y_df=Y_df,\n",
" tags=tags,\n",
" level=level,\n",
" intervals_method=intervals_method,\n",
" sort_df=sort_df)\n",
"\n",
Expand Down Expand Up @@ -910,6 +918,27 @@
"bootstrap_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test level protection for PERMBU and Normality probabilistic methods\n",
"hrec = HierarchicalReconciliation([BottomUp()])\n",
"test_fail(\n",
" hrec.reconcile,\n",
" contains='Level outside domain',\n",
" args=(hier_grouped_hat_df, S_grouped_df, tags_grouped, hier_grouped_df, [0, 80, 90], 'permbu',)\n",
")\n",
"test_fail(\n",
" hrec.reconcile,\n",
" contains='Level outside domain',\n",
" args=(hier_grouped_hat_df, S_grouped_df, tags_grouped, hier_grouped_df, [80, 90, 101], 'normality',)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down

0 comments on commit 2eadc28

Please sign in to comment.