Marginalizing Variables
# !pip install pzflow matplotlib
Marginalization during posterior calculation
This example notebook demonstrates how to marginalize over missing variables during posterior calculation. We will use the Flow trained in the redshift example.
import jax.numpy as jnp
import matplotlib.pyplot as plt
from pzflow.examples import get_example_flow
plt.rcParams["figure.facecolor"] = "white"
First let's load the pre-trained flow, and use it to generate some samples:
flow = get_example_flow()
samples = flow.sample(2, seed=0)
samples
| | redshift | u | g | r | i | z | y |
|---|---|---|---|---|---|---|---|
| 0 | 0.386102 | 28.406404 | 27.508736 | 26.419685 | 26.196583 | 25.901255 | 25.889046 |
| 1 | 2.118688 | 28.220659 | 27.798203 | 27.371502 | 27.211376 | 26.869678 | 26.502558 |
Remember that we can calculate posteriors for the data in samples. For example, let's plot redshift posteriors:
grid = jnp.linspace(0.1, 2.4, 100)
pdfs = flow.posterior(samples, column="redshift", grid=grid)
fig, axes = plt.subplots(1, 2, figsize=(4, 2), dpi=120, constrained_layout=True)
for i, ax in enumerate(axes.flatten()):
    ax.plot(grid, pdfs[i], label="Redshift posterior")
    ztrue = samples["redshift"][i]
    ax.axvline(ztrue, c="C3", label="True redshift")
    ax.set(
        xlabel="redshift",
        xlim=(ztrue - 0.25, ztrue + 0.25),
        yticks=[]
    )
axes[0].legend(
    bbox_to_anchor=(0.25, 1.05, 1.5, 0.2),
    loc="lower left",
    mode="expand",
    borderaxespad=0,
    ncol=2,
    fontsize=8,
)
plt.show()
But what if we have missing values? For example, let's imagine that galaxy 1 (row 0) wasn't observed in the u band, while galaxy 2 (row 1) wasn't observed in the u or y bands. We will mark these non-observations with the value 99:
# make a new copy of the samples
samples2 = samples.copy()
# mark the non-observations with the flag value 99
samples2.iloc[0, 1] = 99
samples2.iloc[1, 1] = 99
samples2.iloc[1, -1] = 99
# print the new samples
samples2
| | redshift | u | g | r | i | z | y |
|---|---|---|---|---|---|---|---|
| 0 | 0.386102 | 99.0 | 27.508736 | 26.419685 | 26.196583 | 25.901255 | 25.889046 |
| 1 | 2.118688 | 99.0 | 27.798203 | 27.371502 | 27.211376 | 26.869678 | 99.000000 |
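As an optional sanity check, it can help to confirm which entries carry the flag before computing any posteriors. The following is a minimal sketch using plain NumPy and the values from the table above; the names `columns`, `data`, `FLAG`, and `missing_per_row` are illustrative and not part of pzflow.

```python
import numpy as np

# Column names and values copied from the table above
columns = ["redshift", "u", "g", "r", "i", "z", "y"]
data = np.array([
    [0.386102, 99.0, 27.508736, 26.419685, 26.196583, 25.901255, 25.889046],
    [2.118688, 99.0, 27.798203, 27.371502, 27.211376, 26.869678, 99.0],
])

FLAG = 99  # the value used to mark non-observations

# Boolean mask that is True wherever an entry equals the flag
is_missing = np.isclose(data, FLAG)

# For each row, list the names of the flagged columns
missing_per_row = [
    [name for name, missing in zip(columns, row) if missing]
    for row in is_missing
]
print(missing_per_row)
```

Each flagged column found this way needs its own entry in the marginalization rules.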
Now if we want to calculate posteriors, we can't simply call flow.posterior() as before, because the flow will think that 99 is the actual value for those bands, rather than just a flag for a missing value. What we can do, however, is pass marg_rules, which is a dictionary of rules that tells the Flow how to marginalize over missing variables.
marg_rules must include:
- "flag": 99, which tells the posterior method that 99 is the flag for missing values
- "u": callable, which returns an array of values for the u band over which to marginalize
- "y": callable, which returns an array of values for the y band over which to marginalize
"u" and "y" both map to callables, because you can use a function of the other values to decide which values of u and y to marginalize over. For example, if you expect the value of u to be close to the value of g, you might use:
"u": lambda row: jnp.linspace(row["g"] - 1, row["g"] + 1, 100)
The only constraint is that regardless of the values of the other variables, the callable must always return an array of the same length.
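To make the fixed-length constraint concrete, here is a minimal sketch of a row-dependent rule. It uses NumPy in place of jax.numpy and plain dictionaries as stand-ins for rows, on the assumption (taken from the snippet above) that a row can be indexed by column name. The names `N_POINTS`, `u_rule`, and `rows` are hypothetical.

```python
import numpy as np

N_POINTS = 100  # every call returns this many values, whatever the row contains

def u_rule(row):
    """Hypothetical rule: marginalize u over a window of +/- 1 mag around g."""
    return np.linspace(row["g"] - 1, row["g"] + 1, N_POINTS)

# Stand-in rows; only the "g" entry matters for this rule
rows = [{"g": 27.5}, {"g": 27.8}]
grids = [u_rule(row) for row in rows]

# The window moves with g, but the length stays fixed
for row, grid_u in zip(rows, grids):
    print(row["g"], grid_u.min(), grid_u.max(), len(grid_u))
```

The range of each returned array follows the row's g value, while its length never changes, which is what the constraint asks for.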
For this example, we won't make the marginalization rules a function of the other variables, but will instead return a fixed array.
marg_rules = {
    "flag": 99,  # tells the posterior method that 99 means missing value
    "u": lambda row: jnp.linspace(27, 29, 40),  # the array of u values to marginalize over
    "y": lambda row: jnp.linspace(25, 27, 40),  # the array of y values to marginalize over
}
pdfs2 = flow.posterior(samples2, column="redshift", grid=grid, marg_rules=marg_rules)
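For intuition about what marginalizing over a grid means, the sketch below sums a toy joint density over a grid of u values and then normalizes the result over the redshift grid. It is an illustration only: the density `joint_density` is a made-up correlated 2D Gaussian, not the trained flow, and the code is not a description of pzflow's internals.

```python
import numpy as np

def joint_density(z, u):
    """Hypothetical toy joint density p(z, u): an unnormalized correlated Gaussian."""
    rho = 0.8
    dz = (z - 1.0) / 0.3
    du = (u - 28.0) / 0.5
    return np.exp(-(dz**2 - 2 * rho * dz * du + du**2) / (2 * (1 - rho**2)))

z_grid = np.linspace(0.0, 2.0, 201)    # grid for the variable we want a posterior over
u_grid = np.linspace(25.0, 31.0, 121)  # grid for the unobserved variable

# Evaluate the joint density on every (z, u) pair
joint = joint_density(z_grid[:, None], u_grid[None, :])  # shape (201, 121)

# Marginalize: sum over the u axis with equal weights
marginal = joint.sum(axis=1)

# Normalize so the result integrates to 1 over the z grid (simple Riemann sum)
dz_step = z_grid[1] - z_grid[0]
marginal /= marginal.sum() * dz_step
```

The resulting `marginal` depends only on z; the unobserved variable has been summed out rather than fixed to a single value.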
fig, axes = plt.subplots(1, 2, figsize=(4, 2), dpi=120, constrained_layout=True)
for i, ax in enumerate(axes.flatten()):
    ax.plot(grid, pdfs[i], label="Posterior w/ all bands")
    ax.plot(grid, pdfs2[i], label="Posterior w/ missing bands marginalized")
    ztrue = samples["redshift"][i]
    ax.axvline(ztrue, c="C3", label="True redshift")
    ax.set(
        xlabel="redshift",
        xlim=(ztrue - 0.25, ztrue + 0.25),
        yticks=[]
    )
axes[0].legend(
    bbox_to_anchor=(0, 1.05, 2, 0.2),
    loc="lower left",
    mode="expand",
    borderaxespad=0,
    ncol=2,
    fontsize=7.5,
)
plt.show()
You can see that for the low-redshift galaxy, throwing away the $u$ band has almost no impact. But for the high-redshift galaxy, throwing away the $u$ and $y$ bands biases the posterior towards higher redshifts.
Be warned that marginalizing over fine grids quickly becomes very computationally expensive, especially when rows in your data frame are missing multiple values.
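A rough way to see why the cost grows so fast is to count density evaluations. The sketch below assumes a brute-force scheme in which every combination of marginalization values is evaluated at every point of the redshift grid; this is an assumption for illustration, and pzflow's actual implementation may differ. The helper `joint_grid_size` is hypothetical.

```python
from math import prod

def joint_grid_size(marg_grid_lengths):
    """Number of value combinations when each missing variable has its own grid."""
    return prod(marg_grid_lengths)

n_redshift = 100  # length of the redshift grid used in this notebook

# Galaxy 1 is missing u only; galaxy 2 is missing both u and y (40 points each)
one_missing = joint_grid_size([40])
two_missing = joint_grid_size([40, 40])

evals_one = n_redshift * one_missing
evals_two = n_redshift * two_missing
print(evals_one, evals_two)
```

Under this assumption, each additional missing variable multiplies the work by the length of its grid, so coarser grids or narrower ranges are the main way to keep the cost down.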