Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always enable flat net construction #1002

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Next Next commit
Always enable flat net construction
Fix #992
  • Loading branch information
albertz committed Mar 17, 2022
commit fa90463596fea1c311e350b07014718b42a1edfd
26 changes: 7 additions & 19 deletions returnn/tf/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,17 +786,6 @@ def get_layer(src_name):
self.used_data_keys.update(extra_net.used_data_keys)
return created_layers

def _flat_construction_enabled(self):
"""
:return: whether to use flat construction algorithm in :func:`construct_layer`.
Use this if you get stack overflow errors, such as:
``Fatal Python error: Cannot recover from stack overflow``
or
``RuntimeError: maximum recursion depth exceeded``.
:rtype: bool
"""
return self.get_config().bool("flat_net_construction", False)

def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_existing=True):
"""
This triggers the construction of the layer `name` if it is not constructed yet.
Expand Down Expand Up @@ -918,14 +907,13 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_
layer_name=full_name, network=self)
return sub_layer

if self._flat_construction_enabled():
delayed_exc = _DelayedConstructionException(
network=self, layer_name=name,
other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing))
if not self._construction_stack.in_flat_construct_count:
return self._construction_stack.flat_construct(delayed_exc)
if self._construction_stack.layers:
raise delayed_exc
delayed_exc = _DelayedConstructionException(
network=self, layer_name=name,
other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing))
if not self._construction_stack.in_flat_construct_count:
return self._construction_stack.flat_construct(delayed_exc)
if self._construction_stack.layers:
raise delayed_exc

layer_desc = layer_desc.copy()
layer_desc.pop("class")
Expand Down