The Annotated S4
Efficiently Modeling Long Sequences with Structured State Spaces
The Structured State Space for Sequence Modeling (S4) architecture is a new approach to very
long-range sequence modeling tasks for vision, language, and audio, showing a capacity to
capture dependencies over tens of thousands of steps. Especially impressive are the model’s
results on the challenging Long Range Arena benchmark, showing an ability to reason over
sequences of up to 16,000+ elements with high accuracy.
Table of Contents
Part 1: State Space Models (Modeling)
    Discrete-time SSM: The Recurrent Representation
    Tangent: A Mechanics Example
    Training SSMs: The Convolutional Representation
    An SSM Neural Network.
Part 1b: Addressing Long-Range Dependencies with HiPPO
Part 2: Implementing S4 (Advanced)
    Step 1. SSM Generating Functions
    Step 2: Diagonal Case
    Step 3: Diagonal Plus Low-Rank
    Diagonal Plus Low-Rank RNN.
    Turning HiPPO to DPLR
    Final Check
Part 3: S4 in Practice (NN Implementation)
    S4 CNN / RNN Layer
    Sampling and Caching
    Experiments: MNIST
    Experiments: QuickDraw
    Experiments: Spoken Digits
Conclusion
Note that this project uses JAX with the Flax NN library. While we personally mainly use
Torch, the functional nature of JAX is a good fit for some of the complexities of S4. We make
heavy use of vmap, scan, their NN cousins, and most importantly jax.jit to compile fast and
efficient S4 layers.
from functools import partial
import jax
import jax.numpy as np
from flax import linen as nn
from jax.nn.initializers import lecun_normal, normal
from jax.numpy.linalg import eigh, inv, matrix_power
from jax.scipy.signal import convolve

if __name__ == "__main__":
    # For this tutorial, construct a global JAX rng key
    # But we don't want it when importing as a library
    rng = jax.random.PRNGKey(1)
The state space model is defined by this simple equation. It maps a 1-D input signal u(t)
to an N-D latent state x(t) before projecting to a 1-D output signal y(t).
$$
\begin{aligned}
x'(t) &= A x(t) + B u(t) \\
y(t) &= C x(t) + D u(t)
\end{aligned}
$$
Our goal is to simply use the SSM as a black-box representation in a deep sequence
model, where A, B, C, D are parameters learned by gradient descent. For the
remainder, we will omit the parameter D for exposition (or equivalently, assume D =
0 because the term Du can be viewed as a skip connection and is easy to compute).
An SSM maps an input u(t) to a state representation vector x(t) and an output y(t).
For simplicity, we assume the input and output are one-dimensional, and the state
representation is N-dimensional. The first equation defines the change in x(t) over
time.
Our SSMs will be defined by three matrices – A, B, C – which we will learn. For now we
begin with a random SSM, to define sizes,
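As a minimal sketch of what such a random SSM might look like, here is a version in plain NumPy (rather than JAX, so it runs standalone). The helper name random_ssm, the uniform initialization, and the seed handling are our assumptions for illustration; only the shapes follow from the text: A is N × N, B is N × 1, and C is 1 × N.

```python
import numpy as np


def random_ssm(seed, N):
    """Draw a random (A, B, C) triple; only the shapes matter here."""
    gen = np.random.default_rng(seed)
    A = gen.uniform(size=(N, N))  # state matrix, N x N
    B = gen.uniform(size=(N, 1))  # input projection, N x 1
    C = gen.uniform(size=(1, N))  # output projection, 1 x N
    return A, B, C


A, B, C = random_ssm(seed=0, N=4)
```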
To discretize the continuous-time SSM, we use the bilinear method, which converts the
state matrix A into an approximation Ā. The discrete SSM is:
def discretize(A, B, C, step):
    I = np.eye(A.shape[0])
    BL = inv(I - (step / 2.0) * A)
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C
$$
\begin{aligned}
x_k &= \bar{A} x_{k-1} + \bar{B} u_k \\
y_k &= \bar{C} x_k
\end{aligned}
$$
As the paper says, this “step” function does look superficially like that of an RNN. We can
implement this with a scan in JAX,
Putting everything together, we can run the SSM by first discretizing, then iterating step by step,
    # Run recurrence
    return scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)))[1]
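To make the two steps concrete, here is a hedged sketch in plain NumPy of discretizing and then iterating step by step. A Python loop stands in for the JAX scan, and the names discretize_np and run_recurrence are hypothetical. We assume BL is the inverse of (I − step/2 · A), and a random SSM with a step of 1/L purely for illustration.

```python
import numpy as np


def discretize_np(A, B, C, step):
    """Bilinear discretization: returns (A_bar, B_bar, C)."""
    I = np.eye(A.shape[0])
    BL = np.linalg.inv(I - (step / 2.0) * A)
    return BL @ (I + (step / 2.0) * A), (BL * step) @ B, C


def run_recurrence(Ab, Bb, Cb, u, x0):
    """Iterate x_k = Ab x_{k-1} + Bb u_k, y_k = Cb x_k over a 1-D input u."""
    x, ys = x0, []
    for u_k in u:
        x = Ab @ x + Bb[:, 0] * u_k
        ys.append((Cb @ x)[0])
    return np.array(ys)


N, L = 4, 16
gen = np.random.default_rng(0)
A = gen.uniform(size=(N, N))
B = gen.uniform(size=(N, 1))
C = gen.uniform(size=(1, N))
Ab, Bb, Cb = discretize_np(A, B, C, step=1.0 / L)
u = gen.uniform(size=L)
y = run_recurrence(Ab, Bb, Cb, u, np.zeros(N))
```

Starting from a zero state, the first output is just the input scaled by Cb @ Bb, which is a quick way to sanity check the loop.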
In this example, we consider the forward position y(t) of a mass attached to a wall with a
spring. Over time, varying force u(t) is applied to this mass. The system is parameterized by
mass (m), spring constant (k), friction constant (b). We can relate these with the following
differential equation:
$$ m y''(t) = u(t) - b y'(t) - k y(t) $$
Rewriting this in matrix form yields an SSM with:
$$ A = \begin{bmatrix} 0 & 1 \\ -k/m & -b/m \end{bmatrix}, \quad B = \begin{bmatrix} 0 \\ 1/m \end{bmatrix}, \quad C = \begin{bmatrix} 1 & 0 \end{bmatrix} $$
Looking at C, we should be able to convince ourselves that the first dimension of the hidden
state is the position (since that becomes y(t)). The second dimension is the velocity, as it is
impacted by u(t) through B. The transition A relates these terms.
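One way to check this reading is to build the matrices and compare them with the differential equation directly. The sketch below uses plain NumPy; the helper name mass_spring_ssm and the test state are our own choices for illustration.

```python
import numpy as np


def mass_spring_ssm(k, b, m):
    """(A, B, C) for m y'' = u - b y' - k y, with state x = (position, velocity)."""
    A = np.array([[0.0, 1.0], [-k / m, -b / m]])
    B = np.array([[0.0], [1.0 / m]])
    C = np.array([[1.0, 0.0]])
    return A, B, C


A, B, C = mass_spring_ssm(k=40, b=5, m=1)

# Pick an arbitrary state and force, then evaluate x' = Ax + Bu.
pos, vel, force = 0.3, -0.2, 1.5
x_dot = A @ np.array([pos, vel]) + B[:, 0] * force
```

The first entry of x_dot is the velocity (the derivative of position), and the second matches the acceleration (u − b y' − k y) / m from the ODE.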
@partial(np.vectorize, signature="()->()")
def example_force(t):
    x = np.sin(10 * t)
    return x * (x > 0.5)


def example_ssm():
    # SSM
    ssm = example_mass(k=40, b=5, m=1)

    # L samples of u(t).
    L = 100
    step = 1.0 / L
    ks = np.arange(L)
    u = example_force(ks * step)

    # Approximation of y(t).
    y = run_SSM(*ssm, u)

    # Plotting ---
    import matplotlib.pyplot as plt
    import seaborn
    from celluloid import Camera

    seaborn.set_context("paper")
    fig, (ax1, ax2, ax3) = plt.subplots(3)
    camera = Camera(fig)
    ax1.set_title("Force $u_k$")
    ax2.set_title("Position $y_k$")
    ax3.set_title("Object")
    ax1.set_xticks([], [])
    ax2.set_xticks([], [])


if False:
    example_ssm()
Neat! And that was just 1 SSM, with 2 hidden states over 100 steps. The final model will have
100s of stacked SSMs over thousands of steps. But first – we need to make these models
practical to train.
The recurrent SSM is not practical for training on modern hardware due to its sequential
nature. Instead, there is a well-known connection between linear time-invariant (LTI)
SSMs and continuous convolutions. Correspondingly, the recurrent SSM can actually be
written as a discrete convolution.
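For simplicity let the initial state be $x_{-1} = 0$. Then unrolling explicitly yields:
$$
\begin{aligned}
x_0 &= \bar{B} u_0 & x_1 &= \bar{A}\bar{B} u_0 + \bar{B} u_1 & x_2 &= \bar{A}^2\bar{B} u_0 + \bar{A}\bar{B} u_1 + \bar{B} u_2 \\
y_0 &= \bar{C}\bar{B} u_0 & y_1 &= \bar{C}\bar{A}\bar{B} u_0 + \bar{C}\bar{B} u_1 & y_2 &= \bar{C}\bar{A}^2\bar{B} u_0 + \bar{C}\bar{A}\bar{B} u_1 + \bar{C}\bar{B} u_2
\end{aligned}
$$
This can be vectorized into a convolution with an explicit formula for the convolution
kernel.
$$
\begin{aligned}
y_k &= \bar{C}\bar{A}^k\bar{B} u_0 + \bar{C}\bar{A}^{k-1}\bar{B} u_1 + \dots + \bar{C}\bar{A}\bar{B} u_{k-1} + \bar{C}\bar{B} u_k \\
y &= \bar{K} * u
\end{aligned}
$$
$$ \bar{K} \in \mathbb{R}^L = (\bar{C}\bar{B}, \bar{C}\bar{A}\bar{B}, \dots, \bar{C}\bar{A}^{L-1}\bar{B}) $$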
Note that this is a giant filter. It is the size of the entire sequence!
Warning: this implementation is naive and unstable. In practice it will fail to work for more than
very small lengths. However, we are going to replace it with S4 in Part 2, so for now we just
keep it around as a placeholder.
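For reference, a naive kernel of this kind can be sketched in a few lines of plain NumPy. The function name naive_kernel and the hand-picked 2 × 2 discrete SSM are assumptions for illustration; the point is only that convolving with K reproduces the unrolled recurrence.

```python
import numpy as np


def naive_kernel(Ab, Bb, Cb, L):
    """K = (Cb Bb, Cb Ab Bb, ..., Cb Ab^{L-1} Bb), via explicit matrix powers."""
    return np.array(
        [(Cb @ np.linalg.matrix_power(Ab, l) @ Bb).item() for l in range(L)]
    )


# A small, stable discrete SSM chosen by hand for illustration.
Ab = np.array([[0.9, 0.1], [0.0, 0.8]])
Bb = np.array([[1.0], [0.5]])
Cb = np.array([[1.0, -1.0]])
L = 8
K = naive_kernel(Ab, Bb, Cb, L)

# Compare with unrolling the recurrence from a zero initial state.
u = np.arange(1.0, L + 1.0)
x, y_rec = np.zeros(2), []
for u_k in u:
    x = Ab @ x + Bb[:, 0] * u_k
    y_rec.append((Cb @ x).item())
y_conv = np.convolve(u, K)[:L]  # causal part of the full convolution
```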
We can compute the result of applying this filter either with a standard direct convolution or by
using the convolution theorem with the Fast Fourier Transform (FFT). The discrete convolution
theorem - for circular convolution of two sequences - allows us to efficiently calculate the
output of convolution by first multiplying FFTs of the input sequences and then applying an
inverse FFT. To utilize this theorem for non-circular convolutions as in our case, we need to pad
the input sequences with zeros, and then unpad the output sequence. As the length gets longer
this FFT method will be more efficient than the direct convolution,
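The pad, multiply, and unpad recipe can be sketched as follows in plain NumPy. The name causal_conv_fft is hypothetical, and we assume the filter and the input have the same length L, as they do for the SSM kernel.

```python
import numpy as np


def causal_conv_fft(u, K):
    """Causal convolution of equal-length 1-D arrays via zero-padded FFTs."""
    L = u.shape[0]
    assert K.shape[0] == L
    # Pad to length 2L so the circular convolution does not wrap around.
    ud = np.fft.rfft(np.pad(u, (0, L)))
    Kd = np.fft.rfft(np.pad(K, (0, L)))
    return np.fft.irfft(ud * Kd)[:L]  # unpad: keep the first L outputs


gen = np.random.default_rng(0)
u, K = gen.normal(size=64), gen.normal(size=64)
direct = np.convolve(u, K)[:64]
fast = causal_conv_fft(u, K)
```

Both paths agree up to floating point error; the FFT path costs O(L log L) instead of O(L²).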
The CNN method and the RNN method yield (roughly) the same result,
    # CNN
    ssmb = discretize(*ssm, step=step)
    conv = causal_convolution(u, K_conv(*ssmb, L))

    # Check
    assert np.allclose(rec.ravel(), conv.ravel())

    return init
For the SSM layer most of the work is to build the filter. The actual call to the network is just
the (huge) convolution we specified above.
Note for Torch users: setup in Flax is called each time the parameters are updated. This is
similar to the Torch parameterizations.
As noted above this same layer can be used either as an RNN or a CNN. The argument decode
determines which path is used. In the case of RNN we cache the previous state at each call in a
Flax variable collection called cache.
class SSMLayer(nn.Module):
    N: int
    l_max: int
    decode: bool = False

    def setup(self):
        # SSM parameters
        self.A = self.param("A", lecun_normal(), (self.N, self.N))
        self.B = self.param("B", lecun_normal(), (self.N, 1))
        self.C = self.param("C", lecun_normal(), (1, self.N))
        self.D = self.param("D", nn.initializers.ones, (1,))

        # Step parameter
        self.log_step = self.param("log_step", log_step_initializer(), (1,))

        step = np.exp(self.log_step)
        self.ssm = discretize(self.A, self.B, self.C, step=step)
        self.K = K_conv(*self.ssm, self.l_max)
Since our SSMs operate on scalars, we make H different, stacked copies (H different SSMs!)
with different parameters. Here we use the Flax vmap method to easily define these copies,
def cloneLayer(layer):
    return nn.vmap(
        layer,
        in_axes=1,
        out_axes=1,
        variable_axes={"params": 1, "cache": 1, "prime": 1},
        split_rngs={"params": True},
    )


SSMLayer = cloneLayer(SSMLayer)
This SSM Layer can then be put into a standard NN. Here we add a block that pairs a call to an
SSM with dropout and a linear projection.
class SequenceBlock(nn.Module):
    layer_cls: nn.Module
    layer: dict  # Hyperparameters of inner layer
    dropout: float
    d_model: int
    prenorm: bool = True
    glu: bool = True
    training: bool = True
    decode: bool = False

    def setup(self):
        self.seq = self.layer_cls(**self.layer, decode=self.decode)
        self.norm = nn.LayerNorm()
        self.out = nn.Dense(self.d_model)
        if self.glu:
            self.out2 = nn.Dense(self.d_model)
        self.drop = nn.Dropout(
            self.dropout,
            broadcast_dims=[0],
            deterministic=not self.training,
        )
We can then stack a bunch of these blocks on top of each other to produce a stack of SSM layers.
This can be used for classification or generation in the same standard way as a Transformer.
class Embedding(nn.Embed):
    num_embeddings: int
    features: int

    @nn.compact
    def __call__(self, x):
        y = nn.Embed(self.num_embeddings, self.features)(x[..., 0])
        return np.where(x > 0, y, 0.0)


class StackedModel(nn.Module):
    layer_cls: nn.Module
    layer: dict  # Extra arguments to pass into layer constructor
    d_output: int
    d_model: int
    n_layers: int
    prenorm: bool = True
    dropout: float = 0.0
    embedding: bool = False  # Use nn.Embed instead of nn.Dense encoder
    classification: bool = False
    training: bool = True
    decode: bool = False  # Probably should be moved into layer_args

    def setup(self):
        if self.embedding:
            self.encoder = Embedding(self.d_output, self.d_model)
        else:
            self.encoder = nn.Dense(self.d_model)
        self.decoder = nn.Dense(self.d_output)
        self.layers = [
            SequenceBlock(
                layer_cls=self.layer_cls,
                layer=self.layer,
                prenorm=self.prenorm,
                d_model=self.d_model,
                dropout=self.dropout,
                training=self.training,
                decode=self.decode,
            )
            for _ in range(self.n_layers)
        ]
In Flax we add the batch dimension as a lifted transformation. We need to route through several
variable collections which handle RNN and parameter caching (described below).
BatchStackedModel = nn.vmap(
    StackedModel,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None, "dropout": None, "cache": 0, "prime": None},
    split_rngs={"params": False, "dropout": True},
)
Overall, this defines a sequence-to-sequence map of shape (batch size, sequence length, hidden
dimension), exactly the signature exposed by related sequence models such as Transformers,
RNNs, and CNNs.
While we now have our main model, there are two core problems with SSMs. First, the randomly
initialized SSM actually does not perform very well. Furthermore, computing it naively like
we’ve done so far is really slow and memory inefficient. Next, we’ll complete our discussion of
the modeling aspect of S4 by defining a special initialization for long-range dependencies, and
then figure out how to compute this SSM Layer faster – a lot faster (Part 2)!
Prior work found that the basic SSM actually performs very poorly in practice.
Intuitively, one explanation is that they suffer from gradients scaling exponentially in the
sequence length (i.e., the vanishing/exploding gradients problem). To address this
problem, previous work developed the HiPPO theory of continuous-time memorization.
Previous work found that simply modifying an SSM from a random matrix A to
HiPPO improved its performance on the sequential MNIST classification benchmark
from 60% to 98%.
This matrix is going to be really important, but it is a bit of magic. For our purposes we mainly
need to know that: 1) we only need to calculate it once, and 2) it has a nice, simple structure
(which we will exploit in part 2). Without going into the ODE math, the main takeaway is that
this matrix aims to compress the past history into a state that has enough information to
approximately reconstruct the history.
def make_HiPPO(N):
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A
Diving a bit deeper, the intuitive explanation of this matrix is that it produces a hidden state that
memorizes its history. It does this by keeping track of the coefficients of a Legendre polynomial.
These coefficients let it approximate all of the previous history. Let us look at an example,
def example_legendre(N=8):
    # Random hidden state as coefficients
    import numpy as np
    import numpy.polynomial.legendre

    x = (np.random.rand(N) - 0.5) * 2
    t = np.linspace(-1, 1, 100)
    f = numpy.polynomial.legendre.Legendre(x)(t)

    # Plot
    import matplotlib.pyplot as plt
    import seaborn

    seaborn.set_context("talk")
    fig = plt.figure(figsize=(20, 10))
    # Newer Matplotlib releases no longer accept keyword arguments in fig.gca
    ax = fig.add_subplot(projection="3d")
    ax.plot(
        np.linspace(-25, (N - 1) * 100 + 25, 100),
        [0] * 100,
        zs=-1,
        zdir="x",
        color="black",
    )
    ax.plot(t, f, zs=N * 100, zdir="y", c="r")
    for i in range(N):
        coef = [0] * N
        coef[N - i - 1] = 1
        ax.set_zlim(-4, 4)
        ax.set_yticks([])
        ax.set_zticks([])
        # Plot basis function.
        f = numpy.polynomial.legendre.Legendre(coef)(t)
        ax.bar(
            [100 * i],
            [x[i]],
            zs=-1,
            zdir="x",
            label="x%d" % i,
            color="brown",
            fill=False,
            width=50,
        )
        ax.plot(t, f, zs=100 * i, zdir="y", c="b", alpha=0.5)
    ax.view_init(elev=40.0, azim=-45)
    fig.savefig("images/leg.png")


if False:
    example_legendre()
The red line represents the curve we are approximating, while the black bars represent the
values of our hidden state. Each is a coefficient for one element of the Legendre series shown as
blue functions. The intuition is that the HiPPO matrix updates these coefficients each step.
Part 2: Implementing S4
Warning: this section has a lot of math. Roughly it boils down to finding a way to compute the
filter from Part 1 for “HiPPO-like” matrices really fast. If you are interested, the details are
really neat. If not, skip to Part 3 for some cool applications like MNIST completion.
To set the stage, recall that S4 has two main differences from a basic SSM. The first addresses a
modeling challenge - long-range dependencies - by using a special formula for the A matrix
defined in the previous part. These special SSMs were considered in predecessor works to S4.
The second main feature of S4 solves the computational challenge of SSMs by introducing a special
representation and algorithm to be able to work with this matrix!
Specifically, recall this function here:
The contribution of S4 is a stable method for speeding up this particular operation. To do this we
are going to focus on the case where the SSM has special structure: specifically, Diagonal Plus
Low-Rank (DPLR) in complex space.
A DPLR SSM is $(\Lambda - PQ^*, B, C)$ for some diagonal $\Lambda$ and matrices
$P, Q, B, C \in \mathbb{C}^{N \times 1}$. We assume without loss of generality that the rank is 1,
i.e. these matrices are vectors.
Under this DPLR assumption, S4 overcomes the speed bottleneck in three steps
3. We show the low-rank term can now be corrected by applying the Woodbury
identity which reduces $(\Lambda + PQ^*)^{-1}$ in terms of $\Lambda^{-1}$, truly reducing to
the diagonal case.
    def gen(z):
        return np.sum(K * (z ** np.arange(L)))

    return gen
The generating function essentially converts the SSM convolution filter from the time
domain to frequency domain. This transformation is also called z-transform (up to a
minus sign) in control engineering literature. Importantly, it preserves the same
information, and the desired SSM convolution filter can be recovered. Once the
z-transform of a discrete sequence is known, we can obtain the filter’s discrete Fourier
transform from evaluations of its z-transform at the roots of unity
$\Omega = \{\exp(2\pi i \tfrac{k}{L}) : k \in [L]\}$. Then, we can apply inverse Fourier
transformation, stably in O(L log L) operations by applying an FFT, to recover the filter.
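As a small illustration of this round trip, the sketch below (plain NumPy, hypothetical helper names) builds the truncated generating function of a known filter, evaluates it at the L roots of unity, and recovers the filter with an inverse FFT. We evaluate at exp(−2πik/L); with the generating function written in positive powers of z, that choice lines up with NumPy's forward DFT convention.

```python
import numpy as np


def gen_from_kernel(K):
    """Truncated generating function z -> sum_j K_j z^j of a known filter K."""
    L = K.shape[0]
    return lambda z: np.sum(K * (z ** np.arange(L)))


def kernel_from_gen(gen, L):
    """Evaluate gen at the L roots of unity, then invert with an FFT."""
    omega = np.exp((-2j * np.pi) * (np.arange(L) / L))
    at_roots = np.array([gen(z) for z in omega])
    return np.fft.ifft(at_roots, L).real


K = np.array([0.5, -1.0, 2.0, 0.25, 0.0, 3.0])
K_rec = kernel_from_gen(gen_from_kernel(K), K.shape[0])
```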
And for all $z \in \Omega_L$, we have $z^L = 1$ so that term is removed. We then pull this constant term
into a new $\tilde{C}$. Critically, this function does not call K_conv,
    a = conv_from_gen(K_gen_inverse(*ssm, L=L), L)
    assert np.allclose(a, b)
In summary, Step 1 allows us to replace the matrix power with an inverse by utilizing a
truncated generating function. However this inverse still needs to be calculated L times (for each
of the roots of unity).
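To see the matrix power replaced by an inverse, here is a hedged NumPy sketch. It relies on the geometric-series identity sum over j < L of (Āz)^j = (I − Ā^L z^L)(I − Āz)^{-1}, with z^L = 1 at the roots of unity. The helper name gen_inverse and the hand-picked 2 × 2 discrete SSM are assumptions for illustration.

```python
import numpy as np


def gen_inverse(Ab, Bb, Cb, L):
    """Generating function via one inverse per z; valid where z**L == 1."""
    I = np.eye(Ab.shape[0])
    Ct = Cb @ (I - np.linalg.matrix_power(Ab, L))  # constant term folded into C
    return lambda z: (Ct @ np.linalg.inv(I - Ab * z) @ Bb).item()


Ab = np.array([[0.9, 0.1], [0.0, 0.8]])
Bb = np.array([[1.0], [0.5]])
Cb = np.array([[1.0, -1.0]])
L = 8

omega = np.exp((-2j * np.pi) * (np.arange(L) / L))
gen = gen_inverse(Ab, Bb, Cb, L)
at_roots = np.array([gen(z) for z in omega])
K_inv = np.fft.ifft(at_roots, L).real

# Reference: the naive kernel built from explicit matrix powers.
K_ref = np.array(
    [(Cb @ np.linalg.matrix_power(Ab, l) @ Bb).item() for l in range(L)]
)
```

The single matrix power Ā^L is computed once up front; the per-root work is the inverse, which is what the later steps speed up.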
Now imagine A = Λ for a diagonal Λ. Substituting in the discretization formula the authors
show that the generating function can be written in the following manner:
We have effectively replaced an inverse with a weighted dot product. Let’s make a small helper
function to compute this weighted dot product for use.
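A minimal sketch of such a helper, together with a numerical check of the diagonal case, in plain NumPy. The values of step, Λ, B, and the modified C (here Ct) are made up for illustration. The factors g(z) = (2/Δ)(1 − z)/(1 + z) and c(z) = 2/(1 + z) come from substituting the bilinear discretization into (I − Āz)^{-1} B̄, and we assume the convention in which Ct enters without conjugation.

```python
import numpy as np


def cauchy_dot(a, g, lam):
    """Weighted dot product sum_i a_i / (g - lam_i)."""
    return (a / (g - lam)).sum()


step = 0.1
lam = np.array([-0.5 + 1.0j, -1.0 - 2.0j, -0.2 + 0.3j])  # diagonal of A
B = np.array([1.0, 0.5, -1.0], dtype=complex)
Ct = np.array([0.3, -0.7, 1.1], dtype=complex)  # stands in for the modified C

# Bilinear discretization of the diagonal SSM, kept as dense matrices.
I = np.eye(3)
BL = np.linalg.inv(I - (step / 2.0) * np.diag(lam))
Ab = BL @ (I + (step / 2.0) * np.diag(lam))
Bb = (BL * step) @ B

z = np.exp(-2j * np.pi * 1 / 5)  # a root of unity other than -1
g = (2.0 / step) * ((1.0 - z) / (1.0 + z))
c = 2.0 / (1.0 + z)

via_inverse = Ct @ np.linalg.inv(I - Ab * z) @ Bb
via_cauchy = c * cauchy_dot(Ct * B, g, lam)
```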
While not important for our implementation, it is worth noting that this is a Cauchy kernel and
is the subject of many other fast implementations.
$$ A = \Lambda - PQ^* $$
The Woodbury identity tells us that the inverse of a diagonal plus rank-1 term is equal to the
inverse of the diagonal plus a rank-1 term. We write it out here adding the low-rank term.
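The rank-1 form of the identity is easy to check numerically. The sketch below uses plain NumPy with hand-picked complex vectors (our own values, chosen so that nothing is singular) and compares the Woodbury expression against a dense inverse.

```python
import numpy as np

lam = np.array([-1 + 2j, -2 - 1j, -0.5 + 0.5j, -3 + 0j])  # diagonal entries
P = np.array([1, 0.5j, -1, 2]).reshape(-1, 1)
Q = np.array([0.5, 1, 1j, -1]).reshape(-1, 1)

Lam_inv = np.diag(1.0 / lam)  # inverting a diagonal is elementwise
Qh = Q.conj().T  # Q^*, a 1 x N row vector

# (Lam + P Q^*)^{-1} = Lam^{-1} - Lam^{-1} P (1 + Q^* Lam^{-1} P)^{-1} Q^* Lam^{-1}
scalar = 1.0 / (1.0 + (Qh @ Lam_inv @ P).item())
woodbury = Lam_inv - scalar * (Lam_inv @ P @ Qh @ Lam_inv)
direct = np.linalg.inv(np.diag(lam) + P @ Qh)
```

The only inverse left on the Woodbury side is a scalar, which is why the low-rank correction costs a handful of dot products.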
The code consists of collecting up the terms and applying 4 weighted dot products,
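def K_gen_DPLR(Lambda, P, Q, B, C, step, unmat=False):
    aterm = (C.conj(), Q.conj())
    bterm = (B, P)

    def gen(o):
        g = (2.0 / step) * ((1.0 - o) / (1.0 + o))
        c = 2.0 / (1.0 + o)

        def k(a):
            # Checkpoint this calculation for memory efficiency.
            if unmat:
                return jax.remat(cauchy_dot)(a, g, Lambda)
            else:
                return cauchy_dot(a, g, Lambda)

        # Four weighted dot products, combined by the Woodbury correction.
        k00 = k(aterm[0] * bterm[0])
        k01 = k(aterm[0] * bterm[1])
        k10 = k(aterm[1] * bterm[0])
        k11 = k(aterm[1] * bterm[1])
        return c * (k00 - k01 * (1.0 / (1.0 + k11)) * k10)

    return gen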
This is our final version of the K function. Because conv_from_gen is always called together
with a generating function (e.g. K_gen_DPLR), we’ll fuse them into a dedicated function to
compute the DPLR SSM kernel from all of its parameters. (With fewer layers of indirection,
this could also make it easier for the XLA compiler to optimize.)
@jax.jit
def cauchy(v, omega, lambd):
    """Cauchy matrix multiplication: (n), (l), (n) -> (l)"""
    cauchy_dot = lambda _omega: (v / (_omega - lambd)).sum()
    return jax.vmap(cauchy_dot)(omega)
Now we can check whether it worked. First, let’s generate a random Diagonal Plus Low Rank
(DPLR) matrix,
We can check that the DPLR method yields the same filter as computing A directly,
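We simplify both terms in the definition of Ā independently. The first term is:
$$
\begin{aligned}
I + \frac{\Delta}{2} A &= I + \frac{\Delta}{2}(\Lambda - PQ^*) \\
&= \frac{\Delta}{2}\left[\frac{2}{\Delta} I + (\Lambda - PQ^*)\right] \\
&= \frac{\Delta}{2} A_0
\end{aligned}
$$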
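The second term is known as the Backward Euler method. Although this inverse term
is normally difficult to deal with, in the DPLR case we can simplify it using Woodbury’s
Identity as described above.
$$
\begin{aligned}
\left(I - \frac{\Delta}{2} A\right)^{-1} &= \left(I - \frac{\Delta}{2}(\Lambda - PQ^*)\right)^{-1} \\
&= \frac{2}{\Delta}\left[\frac{2}{\Delta} - \Lambda + PQ^*\right]^{-1} \\
&= \frac{2}{\Delta}\left[D - DP(1 + Q^* D P)^{-1} Q^* D\right] \\
&= \frac{2}{\Delta} A_1
\end{aligned}
$$
where $D = \left(\frac{2}{\Delta} - \Lambda\right)^{-1}$ and $A_1$ is defined as the term in the final brackets.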
$$
\begin{aligned}
x_k &= \bar{A} x_{k-1} + \bar{B} u_k \\
&= A_1 A_0 x_{k-1} + 2 A_1 B u_k \\
y_k &= C x_k
\end{aligned}
$$
    # Forward Euler
    A0 = (2.0 / step) * I + A

    # Backward Euler
    D = np.diag(1.0 / ((2.0 / step) - Lambda))
    Qc = Q.conj().T.reshape(1, -1)
    P2 = P.reshape(-1, 1)
    A1 = D - (D @ P2 * (1.0 / (1 + (Qc @ D @ P2))) * Qc @ D)
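As a sanity check on these two factors, the following NumPy sketch builds A0 and A1 for a small hand-picked DPLR matrix and compares A1 A0 and 2 A1 B with the generic bilinear formulas computed through a dense inverse. All numeric values and the variable names Ab_dplr, Bb_dplr, Ab_ref, and Bb_ref are assumptions for illustration.

```python
import numpy as np

step = 0.05
Lambda = np.array([-1 + 2j, -2 - 1j, -0.5 + 0.5j, -3 + 0j])
P = np.array([1, 0.5j, -1, 2])
Q = np.array([0.5, 1, 1j, -1])
B = np.array([1.0, -1.0, 0.5, 2.0], dtype=complex).reshape(-1, 1)
N = Lambda.shape[0]
I = np.eye(N)
A = np.diag(Lambda) - P.reshape(-1, 1) @ Q.conj().reshape(1, -1)

# Forward Euler factor
A0 = (2.0 / step) * I + A

# Backward Euler factor via Woodbury; D is diagonal, so it is cheap to form
D = np.diag(1.0 / ((2.0 / step) - Lambda))
Qc = Q.conj().reshape(1, -1)
P2 = P.reshape(-1, 1)
A1 = D - (D @ P2 * (1.0 / (1 + (Qc @ D @ P2))) * Qc @ D)

Ab_dplr, Bb_dplr = A1 @ A0, 2 * A1 @ B

# Reference: the generic bilinear formulas with a dense inverse
BL = np.linalg.inv(I - (step / 2.0) * A)
Ab_ref, Bb_ref = BL @ (I + (step / 2.0) * A), (BL * step) @ B
```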
The S4 techniques can apply to any matrix A that can be decomposed as Normal Plus
Low-Rank (NPLR).
$$ A = V \Lambda V^* - P Q^\top = V \left(\Lambda - V^* P (V^* Q)^*\right) V^* $$
For S4, we need to work with a HiPPO matrix for A. This requires first writing it as a normal
plus low-rank term, and then diagonalizing to extract Λ from this decomposition. The appendix
of the paper shows how by writing the normal part as a skew-symmetric matrix (plus a constant
times the identity matrix); such matrices are a special class of normal matrices.
An additional simplification is that there is actually a representation that ties the low-rank
terms P = Q, which was shown in follow-up work to be important for stability.
def make_NPLR_HiPPO(N):
    # Make -HiPPO
    nhippo = make_HiPPO(N)
After extracting the normal part, we can diagonalize to get out the DPLR terms. Because the
normal part is actually skew-symmetric, we can extract the real and complex parts of Λ
separately. This serves two purposes. First, this gives us finer-grained control over the real and
imaginary parts, which can be used to improve stability. Second, this lets us use more powerful
diagonalization algorithms for Hermitian matrices - in fact, the current version of JAX does not
support GPU diagonalization for non-Hermitian matrices!
def make_DPLR_HiPPO(N):
    """Diagonalize NPLR representation"""
    A, P, B = make_NPLR_HiPPO(N)

    P = V.conj().T @ P
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V
def test_nplr(N=8):
    A2, P, B = make_NPLR_HiPPO(N)
    Lambda, Pc, Bc, V = make_DPLR_HiPPO(N)
    Vc = V.conj().T
    P = P[:, np.newaxis]
    Pc = Pc[:, np.newaxis]
    Lambda = np.diag(Lambda)
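The whole NPLR to DPLR route can be sketched end to end in plain NumPy. The choice of rank-1 term P_n = sqrt(n + 1/2) is our assumption here rather than something read off the code above; the checks confirm that with this choice A + P Pᵀ is skew-symmetric minus one half times the identity, and that the diagonalized pieces rebuild A. The helper name hippo is hypothetical and mirrors make_HiPPO.

```python
import numpy as np


def hippo(N):
    """Same construction as make_HiPPO, in plain NumPy."""
    p = np.sqrt(1 + 2 * np.arange(N))
    A = p[:, np.newaxis] * p[np.newaxis, :]
    return -(np.tril(A) - np.diag(np.arange(N)))


N = 8
A = hippo(N)
P = np.sqrt(np.arange(N) + 0.5)  # assumed rank-1 correction

# Normal part: A + P P^T should be skew-symmetric minus (1/2) I.
S = A + P[:, np.newaxis] * P[np.newaxis, :]
K = S + 0.5 * np.eye(N)  # the purely skew-symmetric part

# -iK is Hermitian, so eigh applies; the eigenvalues of K are then 1j * w.
w, V = np.linalg.eigh(-1j * K)
Lambda = -0.5 + 1j * w

# Rebuild A from the DPLR pieces: A = V (Lambda - Pc Pc^*) V^*.
Pc = V.conj().T @ P
A_rebuilt = V @ (np.diag(Lambda) - np.outer(Pc, Pc.conj())) @ V.conj().T
```

Splitting off the constant real part before calling eigh is what lets the real and imaginary parts of Λ be handled separately.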
Final Check
This tests that everything works as planned.
    # CNN form.
    K = kernel_DPLR(Lambda, P, P, B, C, step, L)

    # RNN form.
    Ab, Bb, Cb = discrete_DPLR(Lambda, P, P, B, C, step, L)
    K2 = K_conv(Ab, Bb, Cb, L=L)
    assert np.allclose(K.real, K2.real, atol=1e-5, rtol=1e-5)

    # Apply CNN
    u = np.arange(L) * 1.0
    y1 = causal_convolution(u, K.real)

    # Apply RNN
    _, y2 = scan_SSM(
        Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)).astype(np.complex64)
    )
    assert np.allclose(y1, y2.reshape(-1).real, atol=1e-4, rtol=1e-4)
Part 3: S4 in Practice
That was a lot of work, but now the actual model is concise. In fact we are only using four
functions:
class S4Layer(nn.Module):
    N: int
    l_max: int
    decode: bool = False

    def setup(self):
        # Learned Parameters (C is complex!)
        init_A_re, init_A_im, init_P, init_B = hippo_initializer(self.N)
        self.Lambda_re = self.param("Lambda_re", init_A_re, (self.N,))
        self.Lambda_im = self.param("Lambda_im", init_A_im, (self.N,))
        # Ensure the real part of Lambda is negative
        # (described in the SaShiMi follow-up to S4)
        self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
        self.P = self.param("P", init_P, (self.N,))
        self.B = self.param("B", init_B, (self.N,))
        # C should be init as standard normal
        # This doesn't work due to how JAX handles complex optimizers
        # https://github.com/deepmind/optax/issues/196
        # self.C = self.param("C", normal(stddev=1.0, dtype=np.complex64), (self.N,))
        self.C = self.param("C", normal(stddev=0.5**0.5), (self.N, 2))
        self.C = self.C[..., 0] + 1j * self.C[..., 1]
        self.D = self.param("D", nn.initializers.ones, (1,))
        self.step = np.exp(self.param("log_step", log_step_initializer(), (1,)))

        if not self.decode:
            # CNN mode, compute kernel.
            self.K = kernel_DPLR(
                self.Lambda,
                self.P,
                self.P,
                self.B,
                self.C,
                self.step,
                self.l_max,
            )
        else:
            # RNN mode, discretize

            # RNN Cache
            self.x_k_1 = self.variable(
                "cache", "cache_x_k", np.zeros, (self.N,), np.complex64
            )

    return _init


def hippo_initializer(N):
    Lambda, P, B, _ = make_DPLR_HiPPO(N)
    return init(Lambda.real), init(Lambda.imag), init(P), init(B)
    x = jax.vmap(update)(x, out)
    return x, rng, vars["cache"].unfreeze()
To get this in a good form, we first precompute the discretized version of the RNN for each
S4 layer. We do this through the “prime” collection of variables.
def init_recurrence(model, params, init_x, rng):
    variables = model.init(rng, init_x)
    vars = {
        "params": params,
        "cache": variables["cache"].unfreeze(),
        "prime": variables["prime"].unfreeze(),
    }
    print("[*] Priming")
    _, prime_vars = model.apply(vars, init_x, mutable=["prime"])
    return vars["params"], prime_vars["prime"], vars["cache"]
Experiments: MNIST
Now that we have the model, we can try it out on some MNIST experiments. For these
experiments we linearize MNIST and just treat each image as a sequence of pixels.
The first experiments we ran were on MNIST classification. While not in theory a hard
problem, treating MNIST as a linear sequence classification task is a bit strange. However in
practice, the model with H = 256 and four layers seems to get up near 99% right away.
A more visually interesting task is generating MNIST digits, by predicting entire sequences of
pixels! Here, we simply feed in a sequence of pixels into the model and have it predict the next
one like language modeling. With a little tweaking, we are able to get the model to an NLL of
0.36 on this task with size 512 and 6 layers (~4m parameters).
    # _, dataloader, _, _, _ = Datasets["mnist"](bsz=BATCH)
    it = iter(dataloader)
    for j, im in enumerate(it):
        if n_batches is not None and j >= n_batches:
            break

        image = im[0].numpy()
        image = np.pad(
            image[:, :-1, :], [(0, 0), (1, 0), (0, 0)], constant_values=0
        )
        cur = onp.array(image)
        # cur[:, START + 1 :, 0] = 0
        # cur = np.pad(cur[:, :-1, 0], [(0, 0), (1, 0)], constant_values=256)
        cur = np.array(cur[:, :])

        # Visualization
        out = out.reshape(BATCH, *imshape)
        final = onp.zeros((BATCH, *imshape, 3))
        final2 = onp.zeros((BATCH, *imshape, 3))
        final[:, :, :, 0] = out
        f = final.reshape(BATCH, LENGTH, 3)
        i = image.reshape(BATCH, LENGTH)
        f[:, :START, 1] = i[:, :START]
        f[:, :START, 2] = i[:, :START]

        f = final2.reshape(BATCH, LENGTH, 3)
        f[:, :, 1] = i
        f[:, :START, 0] = i[:, :START]
        f[:, :START, 2] = i[:, :START]
        if save:
            for k in range(BATCH):
                fig, (ax1, ax2) = plt.subplots(ncols=2)
                ax1.set_title("Sampled")
                ax1.imshow(final[k] / 256.0)
                ax2.set_title("True")
                ax1.axis("off")
                ax2.axis("off")
                ax2.imshow(final2[k] / 256.0)
                fig.savefig("im%d.%d.png" % (j, k))
                plt.close()
                print(f"Sampled batch {j} image {k}")
    return final, final2
Experiments: QuickDraw
Next we tried training a model to generate drawings. For this we used the QuickDraw dataset.
The dataset includes a version of the dataset downsampled to MNIST size so we can use roughly
the same model as above. The dataset is much larger though (5M images) and more complex.
We only trained for 1 epoch with an H = 256, 4-layer model. Still, the approach was able to
generate relatively coherent completions. These are prefix samples with 500 pixels given.
Experiments: Spoken Digits
Finally we played with modeling sound waves directly. For these, we use the Free Spoken Digits
Datasets, an MNIST-like dataset of various speakers reading off digits. We first trained a
classification model and found that the approach was able to reach 97% accuracy just from the
raw soundwave. Next we trained a generation model to produce the sound wave directly. With
H = 512 the model seems to pick up the data relatively well. This dataset only has around
3000 examples, but the model can produce reasonably good (cherry-picked) continuations. Note
these sequences are 6400 steps long at an 8kHz sampling rate, discretized to 256 classes with Mu
Law Encoding.
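For readers unfamiliar with it, Mu Law Encoding compresses a waveform logarithmically before quantizing, so quiet samples get finer resolution than loud ones. The sketch below is the standard textbook formulation in plain NumPy, which we assume (but have not confirmed) matches the preprocessing used for these experiments. The function names and the 440 Hz test tone are our own.

```python
import numpy as np

MU = 255  # 256 classes


def mu_law_encode(x, mu=MU):
    """Map a waveform in [-1, 1] to integer classes 0..mu."""
    compressed = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return np.floor((compressed + 1) / 2 * mu + 0.5).astype(np.int64)


def mu_law_decode(q, mu=MU):
    """Approximate inverse of mu_law_encode."""
    compressed = 2 * (q.astype(np.float64) / mu) - 1
    return np.sign(compressed) * np.expm1(np.abs(compressed) * np.log1p(mu)) / mu


t = np.arange(6400) / 8000.0  # 6400 steps at 8 kHz
wave = 0.8 * np.sin(2 * np.pi * 440.0 * t)
classes = mu_law_encode(wave)
recovered = mu_law_decode(classes)
```

The generation model then predicts one of the 256 classes at each step, exactly like the pixel intensities in the MNIST experiments.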
Our full code base contains more examples and infrastructure for training models for generation
and classification.
Conclusion
Putting together this post inspired lots of thoughts about future work in this area. One obvious
conclusion is that long-range models have all sorts of future applications from acoustic modeling
to genomic sequences to trajectories (not to mention our shared area of NLP). Another is some
surprise that linear models can be so effective here, while also opening up a range of efficient
techniques. Finally from a practical level, the transformations in JAX make it really nice to
implement complex models like this in a very concise way (~200 LoC), with similar efficiency
and performance!
We end by thanking the authors Albert Gu and Karan Goel, who were super helpful in putting
this together, and pointing you again to their paper and codebase. Thanks to Ankit Gupta, Ekin
Akyürek, Qinsheng Zhang, Nathan Yan, and Junxiong Wang for contributions. We’re also
grateful to Conner Vercellino and Laurel Orr for providing helpful feedback on this post.
Changelog
v3
Major editing pass from Albert.
Fix bug in HiPPO calculation.
Added training of all S4 parameters.
Fix learning rate / initialization issues.
v2
Added RNN decoding
Added Speech examples
v1 - Original version