Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always enable flat net construction #1002

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
flat construction, small fix
  • Loading branch information
albertz committed Mar 17, 2022
commit 5c4b1bcb5ab9fdd09d61eefb554b69bcaf8e0fda
10 changes: 8 additions & 2 deletions returnn/tf/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,11 @@ class _NetworkConstructionStack:
# for things like CondLayer or RecLayer.
_flat_construction_stack = [] # type: typing.List[_NetworkConstructionStack]

def __init__(self):
def __init__(self, network):
"""
:param TFNetwork network:
"""
self.network = network
self.layers = [] # type: typing.List[str]
self.flat_construct_stack = [] # type: typing.List[typing.Tuple[TFNetwork, str, typing.Dict[str, typing.Any]]]

Expand Down Expand Up @@ -389,6 +393,8 @@ def flat_construct(self, initial_exc):
assert not stack
return res
except _DelayedConstructionException as delayed_exc:
if delayed_exc.network is not self.network:
raise # some parent flat_construct() should handle this
stack.append((delayed_exc.network, delayed_exc.layer_name, delayed_exc.other_kwargs))
except Exception as exc:
attr = "_RETURNN_layer_construction_stack"
Expand Down Expand Up @@ -501,7 +507,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
self.extra_nets = {} # type: typing.Dict[str,TFNetwork]
self.subnets = {} # type: typing.Dict[str,Subnetwork]
self._selected_train_layers = None
self._construction_stack = _NetworkConstructionStack()
self._construction_stack = _NetworkConstructionStack(self)
self.layers_desc = {} # type: typing.Dict[str,typing.Dict[str]]
self.layers = {} # type: typing.Dict[str,LayerBase]
self.losses_dict = {} # type: typing.Dict[str,LossHolder]
Expand Down