Dequantization example
Normalizing flows are designed to model continuous variables, but they can be easily adapted to model discrete data as well. This is achieved by inserting a "dequantizer" into the bijector chain. We will demonstrate the use of a Uniform Dequantizer to model checkerboard data.
(Note this requires pzflow >= 3.4.0)
import numpy as np
import matplotlib.pyplot as plt
from pzflow import Flow
from pzflow.bijectors import (
Chain,
ShiftBounds,
RollingSplineCoupling,
Beta13Dequantizer,
)
from pzflow.examples import get_checkerboard_data
PZFlow includes example data that is discrete in a checkerboard pattern:
data = get_checkerboard_data()
data
|   | x | y |
|---|---|---|
| 0 | 2 | 0 |
| 1 | 3 | 1 |
| 2 | 1 | 1 |
| 3 | 1 | 3 |
| 4 | 3 | 3 |
| ... | ... | ... |
| 99995 | 3 | 3 |
| 99996 | 0 | 2 |
| 99997 | 2 | 2 |
| 99998 | 0 | 2 |
| 99999 | 0 | 2 |
100000 rows × 2 columns
# Let's plot this distribution
fig, ax = plt.subplots(figsize=(3, 3), dpi=150)
R = 4
ax.hist2d(data["x"], data["y"], bins=R, range=((0, R), (0, R)))
ax.set(xlabel="x", ylabel="y", yticks=np.arange(R))
plt.show()
Let's see how the default normalizing flow performs with this data:
flow = Flow(data.columns)
losses = flow.train(data, verbose=True)
# losses = flow.train(data, epochs=35, optimizer=adam(1e-5), verbose=True)
# losses += flow.train(data, epochs=20, optimizer=adam(1e-6), verbose=True)
Training 100 epochs
Loss:
(0) 12.9477 (1) -1.6373 (6) -2.0290 (11) -1.9060 (16) -2.6263 (21) -0.8358 (26) 23.9592 (31) -0.6474 (36) -1.1265 (41) 0.3105 (46) -0.8193 (51) 0.5506 (56) 0.3735
Training stopping after epoch 57 because training loss diverged.
First, let's plot the training loss:
# plot the training losses
plt.plot(losses)
This doesn't look good...
Let's also plot some samples and compare to the truth data.
# plot some samples
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(6.2, 6.2), dpi=150)
ax1.hist2d(data["x"], data["y"], bins=R, range=((0, R), (0, R)))
ax1.set(ylabel="y", title="Truth", yticks=np.arange(5))
ax3.scatter(data["x"], data["y"], marker=".", c="k")
ax3.set(xlabel="x", ylabel="y", yticks=np.arange(5))
samples = flow.sample(data.shape[0])
ax2.hist2d(samples["x"], samples["y"], bins=R, range=((0, R), (0, R)))
ax2.set(title="Samples", yticks=np.arange(5))
ax4.scatter(samples["x"], samples["y"], marker=".", c="k", s=1)
ax4.set(xlabel="x", ylabel="y", yticks=np.arange(5))
plt.show()
Well that's weird... Clearly the default flow does a really bad job with discrete data!
We can do much better if we add a dequantizer to the bijector chain.
The dequantizer will add noise in the range (0, 1) to each variable, and thereby smooth out the distribution.
This allows the RollingSplineCoupling to perform better.
When sampling, everything acts in reverse, so the dequantizer will quantize the samples, resulting in a discrete distribution.
(Note that the dequantizer isn't technically bijective, since many different noisy values map back to the same integer, but that's okay.)
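To make the idea concrete, here is a minimal NumPy sketch of dequantizing and re-quantizing. It is not PZFlow's implementation: the helper names `dequantize` and `quantize` are hypothetical, and the noise is assumed to be uniform on [0, 1), as the name "Uniform Dequantizer" suggests.

```python
import numpy as np

rng = np.random.default_rng(0)


def dequantize(x, rng):
    """Add noise in [0, 1) to integer-valued data (uniform noise is an assumption)."""
    return x + rng.uniform(0.0, 1.0, size=x.shape)


def quantize(u):
    """Reverse direction: floor continuous values back onto the integer grid."""
    return np.floor(u).astype(int)


# Toy discrete data on a 4x4 grid (not the checkerboard data set itself)
x = rng.integers(0, 4, size=(1000, 2))

u = dequantize(x, rng)  # continuous values spread over each unit cell
x_back = quantize(u)    # back on the integer grid
```

Flooring recovers the original integers, but every value in a unit cell maps to the same integer, which is why the operation is not strictly invertible.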
# we build the exact same flow, except we add a dequantizer to the bijector chain
bijector = Chain(
Beta13Dequantizer([0, 1]), # dequantize the data
ShiftBounds(0, 4, B=5), # shift bounds of data from (0, 4) -> (-5, 5)
RollingSplineCoupling(nlayers=2, B=5), # transform distribution
)
dq_flow = Flow(data.columns, bijector)
dq_losses = dq_flow.train(data, verbose=True) # , optimizer=adam(1e-4))
Training 100 epochs
Loss:
(0) 11.9246 (1) 1.8581 (6) 1.2172 (11) 1.1780 (16) 1.1786 (21) 1.1829 (26) 1.1699 (31) 1.1782 (36) 1.1789 (41) 1.1777 (46) 1.1894 (51) 1.1761 (56) 1.1737 (61) 1.1806 (66) 1.1732 (71) 1.1709 (76) 1.1745 (81) 1.1774 (86) 1.1816 (91) 1.1747 (96) 1.1733 (100) 1.1713
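As an aside, the comment on the ShiftBounds step in the chain above describes a map from (0, 4) to (-5, 5). A minimal sketch of such a rescaling, assuming it is a plain linear map, looks like this. The helpers `shift_bounds` and `unshift_bounds` are hypothetical and are not PZFlow functions.

```python
import numpy as np


def shift_bounds(x, xmin, xmax, B):
    """Linearly map values in [xmin, xmax] onto [-B, B] (assumed form)."""
    return B * (2.0 * (x - xmin) / (xmax - xmin) - 1.0)


def unshift_bounds(z, xmin, xmax, B):
    """Inverse map from [-B, B] back to [xmin, xmax]."""
    return (z / B + 1.0) * (xmax - xmin) / 2.0 + xmin


# Cell edges of the dequantized checkerboard data
edges = np.array([0.0, 1.0, 2.0, 3.0, 4.0])

z = shift_bounds(edges, 0.0, 4.0, B=5.0)
edges_back = unshift_bounds(z, 0.0, 4.0, B=5.0)
```

Note that the upper bound is 4 rather than 3: after dequantization, the integer values 0 to 3 fill the whole interval (0, 4).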
# plot the training losses
plt.plot(dq_losses)
plt.yscale("log")
plt.xlabel("Epoch")
plt.ylabel("Log[loss]")
That went much smoother! Let's look at the samples:
# plot some samples
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(6.2, 6.2), dpi=150)
ax1.hist2d(data["x"], data["y"], bins=R, range=((0, R), (0, R)))
ax1.set(ylabel="y", title="Truth", yticks=np.arange(5))
ax3.scatter(data["x"], data["y"], marker=".", c="k")
ax3.set(xlabel="x", ylabel="y", yticks=np.arange(4))
dq_samples = dq_flow.sample(data.shape[0])
ax2.hist2d(dq_samples["x"], dq_samples["y"], bins=R, range=((0, R), (0, R)))
ax2.set(title="Samples", yticks=np.arange(5))
ax4.scatter(dq_samples["x"], dq_samples["y"], marker=".", c="k")
ax4.set(xlabel="x", ylabel="y", yticks=np.arange(4))
plt.show()
This flow produces only discrete samples that lie on the grid. It did place some samples on the empty grid points, but the histogram shows these are exceedingly rare, and overall the distribution of samples closely matches that of the training data. Training was also much easier: a single round with default settings was enough, with no need to adjust the learning rate schedule.
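One way to put a number on "exceedingly rare" is to count the fraction of samples that land on grid points never seen in the training data. The sketch below uses a hypothetical helper, `off_pattern_fraction`, with small hand-made arrays standing in for the real data and samples.

```python
import numpy as np


def off_pattern_fraction(truth, samples):
    """Fraction of sample rows whose (x, y) grid point never appears in truth."""
    occupied = {tuple(row) for row in np.unique(truth, axis=0)}
    misses = sum(tuple(row) not in occupied for row in samples)
    return misses / len(samples)


# Stand-in "truth": the eight occupied points of a 4x4 checkerboard
truth = np.array(
    [[0, 0], [0, 2], [1, 1], [1, 3], [2, 0], [2, 2], [3, 1], [3, 3]]
)

# Stand-in samples: one of the four rows, (1, 2), falls on an empty point
samples = np.array([[0, 0], [1, 1], [1, 2], [3, 3]])

frac = off_pattern_fraction(truth, samples)
```

Assuming `data` and `dq_samples` are pandas DataFrames, as the column indexing above suggests, the same check would be `off_pattern_fraction(data.to_numpy().astype(int), dq_samples.to_numpy().astype(int))`.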
One more thing: you can model discrete and continuous variables side by side in the same normalizing flow! Just drop in Beta13Dequantizer with column_idx set to the list of column indices of the discrete variables.
For example, suppose we want to model data with column names ["a", "b", "c", "d"], and that "a" and "c" are continuous variables, while "b" and "d" are discrete variables. We can use a bijector like this:
bijector = Chain(
Beta13Dequantizer(column_idx=[1, 3]),
...,
)
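To illustrate what column-wise dequantization means for such a mixed data set, here is a small NumPy sketch that adds noise only to the selected columns. It is a conceptual stand-in, not the library's code: `dequantize_columns` is a hypothetical helper, the noise is assumed to be uniform on [0, 1), and the toy columns mimic "a" to "d" from the example.

```python
import numpy as np


def dequantize_columns(data, column_idx, rng):
    """Add noise in [0, 1) to the listed columns only (uniform noise assumed)."""
    out = np.asarray(data, dtype=float).copy()
    noise = rng.uniform(0.0, 1.0, size=(out.shape[0], len(column_idx)))
    out[:, column_idx] += noise
    return out


rng = np.random.default_rng(1)
n = 5

# Toy data: columns 0 and 2 continuous ("a", "c"), columns 1 and 3 discrete ("b", "d")
data = np.column_stack(
    [
        rng.normal(size=n),
        rng.integers(0, 4, size=n),
        rng.normal(size=n),
        rng.integers(0, 4, size=n),
    ]
)

noisy = dequantize_columns(data, [1, 3], rng)
```

The continuous columns pass through unchanged, while flooring the discrete columns of `noisy` recovers their original integer values.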