Distilling Existing Datasets into Priors
In Bayesian Regression, the Prior distribution encodes prior knowledge about the model weights before the model has been conditioned on any training data. The Prior and the training data are then combined to form the Posterior distribution, which describes the model weights that are compatible with both the training data and the Prior.
Conditioning a Posterior distribution on a second set of training data is equivalent to using that Posterior as the Prior for the second dataset (assuming the two datasets are independent given the model weights). In this way, we can “distill” existing datasets into good Prior distributions for the committee models used for selection.
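For Bayesian linear regression with Gaussian noise, this equivalence can be checked numerically. The sketch below uses plain NumPy; the `condition` helper, the noise variance and the random descriptor matrices are hypothetical illustrations and are not part of ase_uhal. It conditions a zero-mean Prior on one dataset, reuses that Posterior as the Prior for a second dataset, and compares the result with conditioning on both datasets at once.

```python
import numpy as np

rng = np.random.default_rng(0)
n_feat = 4
noise_var = 0.1  # hypothetical observation noise variance


def condition(prior_mean, prior_cov, Phi, y, noise_var):
    """Condition a Gaussian prior N(prior_mean, prior_cov) on y = Phi @ w + noise."""
    prior_prec = np.linalg.inv(prior_cov)
    post_prec = prior_prec + Phi.T @ Phi / noise_var
    post_cov = np.linalg.inv(post_prec)
    post_mean = post_cov @ (prior_prec @ prior_mean + Phi.T @ y / noise_var)
    return post_mean, post_cov


# Two synthetic datasets generated from the same "true" weights
w_true = rng.normal(size=n_feat)
Phi1 = rng.normal(size=(30, n_feat))
Phi2 = rng.normal(size=(20, n_feat))
y1 = Phi1 @ w_true + rng.normal(scale=np.sqrt(noise_var), size=30)
y2 = Phi2 @ w_true + rng.normal(scale=np.sqrt(noise_var), size=20)

mean0, cov0 = np.zeros(n_feat), np.eye(n_feat)

# Route 1: condition on dataset 1, then use that Posterior as the Prior for dataset 2
m1, c1 = condition(mean0, cov0, Phi1, y1, noise_var)
m_seq, c_seq = condition(m1, c1, Phi2, y2, noise_var)

# Route 2: condition on both datasets in a single step
m_all, c_all = condition(
    mean0, cov0, np.vstack([Phi1, Phi2]), np.concatenate([y1, y2]), noise_var
)

print(np.allclose(m_seq, m_all), np.allclose(c_seq, c_all))
```

Both comparisons should print `True`: the Posterior precision simply accumulates one `Phi.T @ Phi / noise_var` term per dataset, so the order in which the datasets arrive does not matter.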
To start, let’s construct a committee calculator in the same manner as the previous example:
[1]:
# Imports
import numpy as np
from mace.calculators import mace_mp
from ase_uhal.committee_calculators import MACEHALCalculator
from ase_uhal.bias_calculators import HALBiasCalculator
# normal MACE MPA medium model calculator (from mace_torch)
mace_calc = mace_mp("medium-mpa-0")
comm_calc = MACEHALCalculator(
# MACE architecture to use for descriptor evaluation
mace_calc,
# Number of committee members to draw
committee_size=20,
# Weight of the prior in the linear system to find committee weights
prior_weight=0.1,
# Weights on energy and force observations when sampling new structures
energy_weight=1, forces_weight=100,
# Option to use a lower memory strategy to fit the linear system
lowmem=False,
# Lower memory overhead by evaluating descriptor forces in batches
# (Specific to MACE-based calculators)
batch_size=8,
# Seeding for random number generation
rng=np.random.RandomState(42))
comm_calc.resample_committee()  # Build the initial committee
WARNING:root:Default dtype float32 does not match model dtype float64, converting models to float32.
Using medium MPA-0 model as default MACE-MP model, to use previous (before 3.10) default model please specify 'medium' as model argument
Using Materials Project MACE for MACECalculator with /home/runner/.cache/mace/macempa0mediummodel
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Now, let’s generate some synthetic training data by applying isotropic strain to a bulk Fe cell:
[2]:
from ase.build import bulk
base_ats = bulk("Fe", cubic=True)
base_cell = base_ats.cell[:, :]
db_strains = np.linspace(0.9, 1.1, 10)
dataset = []
for strain in db_strains:
    ats = base_ats.copy()
    # Apply isotropic (volumetric) strain by scaling all lattice vectors
    cell = base_cell.copy() * strain
    ats.set_cell(cell, scale_atoms=True)
    dataset.append(ats)
# Also include the unstrained bulk, with a 10x higher weight
ats = base_ats.copy()
ats.info["total_weight"] = 10
dataset.append(ats)
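Because every lattice vector is multiplied by `strain`, the cell volume scales as `strain**3`. The short NumPy check below (independent of ase_uhal) shows the volume range spanned by the strained structures, roughly 73% to 133% of the equilibrium volume, and confirms that the 10-point grid does not contain the unstrained cell itself, so the equilibrium bulk only enters the dataset through the separately weighted copy.

```python
import numpy as np

db_strains = np.linspace(0.9, 1.1, 10)

# Scaling every lattice vector by `strain` scales the cell volume by strain**3
volume_ratios = db_strains**3
print(f"Volume range: {volume_ratios.min():.3f} to {volume_ratios.max():.3f} of V0")

# An even number of points means the midpoint 1.0 (the unstrained cell) is not on the grid
print("Unstrained cell in grid:", bool(np.isclose(db_strains, 1.0).any()))
```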
And now let’s use this dataset to produce a prior:
[3]:
from ase_uhal.distillation import distill_dataset
sqrt_prior = distill_dataset(
# Dataset to distill
dataset=dataset,
# Committee calculator to get descriptors from
calc=comm_calc,
# The distillation process will scan the atoms.info dict for this key,
# and use it as an overall weight multiplier (1 is used when the key is missing)
total_weight_key="total_weight"
)
# We could now save this prior to disk, to avoid regenerating it each time we want to use it for sampling
# np.save("Fe_Prior.npy", sqrt_prior)
# and then load it back with
# sqrt_prior = np.load("Fe_Prior.npy")
What gets returned is the square root of the Posterior covariance matrix, which also serves as the square root of the new Prior matrix. As ase_uhal uses zero-mean committees, no information about the Posterior mean is needed; computing the mean would additionally require the reference energies and forces.
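The internals of `distill_dataset` are not shown here, so the following is only a sketch under stated assumptions: a linear model in which the Prior contributes `prior_weight` times the identity to the precision matrix, and each structure adds a Gram term `w * Phi.T @ Phi` scaled by its `total_weight` (defaulting to 1). The function `distill_sketch`, the random descriptor blocks and this placement of `prior_weight` are hypothetical. The sketch shows that the square root of the Posterior covariance needs only descriptors, never energies or forces, that a weight of 10 behaves like 10 copies of a structure, and how zero-mean committee members can be drawn from such a square root.

```python
import numpy as np


def distill_sketch(descriptor_blocks, weights, prior_weight=0.1):
    """Hypothetical stand-in for dataset distillation (NOT ase_uhal's implementation).

    descriptor_blocks: list of (n_obs, n_feat) descriptor matrices, one per structure
    weights: per-structure weight multipliers
    Returns S such that S @ S.T is the Posterior covariance.
    """
    n_feat = descriptor_blocks[0].shape[1]
    precision = prior_weight * np.eye(n_feat)
    for Phi, w in zip(descriptor_blocks, weights):
        precision += w * Phi.T @ Phi
    L = np.linalg.cholesky(precision)  # precision = L @ L.T
    return np.linalg.inv(L).T          # S @ S.T = inv(L).T @ inv(L) = inv(precision)


rng = np.random.default_rng(1)
blocks = [rng.normal(size=(6, 5)) for _ in range(3)]
# Mimic Atoms.info dicts: only the last structure carries "total_weight"
infos = [{}, {}, {"total_weight": 10}]
weights = [info.get("total_weight", 1) for info in infos]

S = distill_sketch(blocks, weights)
posterior_cov = S @ S.T

# Only descriptors were used: no energies or forces enter the covariance
precision = 0.1 * np.eye(5) + sum(w * B.T @ B for B, w in zip(blocks, weights))
print(np.allclose(posterior_cov, np.linalg.inv(precision)))

# A weight of 10 is equivalent to including that structure 10 times with weight 1
S_copies = distill_sketch(blocks[:2] + [blocks[2]] * 10, [1] * 12)
print(np.allclose(posterior_cov, S_copies @ S_copies.T))

# Zero-mean committee: each column is S @ z with z ~ N(0, I), so its covariance is S @ S.T
members = S @ rng.normal(size=(5, 20000))
tol = 0.05 * np.abs(posterior_cov).max()
print(np.allclose(np.cov(members), posterior_cov, rtol=0.0, atol=tol))
```

The Cholesky route is just one convenient way to obtain a matrix square root; any `S` with `S @ S.T` equal to the covariance would produce committees with the same statistics.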
Using Priors in new Committee Calculators
We can now create a new committee calculator that uses this distilled Prior:
[4]:
new_comm_calc = MACEHALCalculator(
# MACE architecture to use for descriptor evaluation
mace_calc,
# Number of committee members to draw
committee_size=20,
# Weight of the prior in the linear system to find committee weights
prior_weight=0.1,
# Weights on energy and force observations when sampling new structures
energy_weight=1, forces_weight=100,
# Option to use a lower memory strategy to fit the linear system
lowmem=False,
# Lower memory overhead by evaluating descriptor forces in batches
# (Specific to MACE-based calculators)
batch_size=8,
# Seeding for random number generation
rng=np.random.RandomState(42),
# Now add in the sqrt of the prior
sqrt_prior=sqrt_prior)
new_comm_calc.resample_committee()
To show the impact of the Prior, let’s compare the maximum per-atom bias force predicted by both committee calculators for the most compressed structure in the dataset:
[5]:
ats = dataset[0].copy()  # Most compressed structure (strain = 0.9); copy to leave the dataset unchanged
# Very small rattle, to break symmetry
ats.rattle(1e-4)
bulk_bias_force_no_prior = np.max(np.linalg.norm(comm_calc.get_property("bias_forces", ats), axis=-1))
bulk_bias_force_prior = np.max(np.linalg.norm(new_comm_calc.get_property("bias_forces", ats), axis=-1))
print("Bias Forces for Strained Fe Bulk:")
print("No Prior: ", bulk_bias_force_no_prior)
print("With Prior:", bulk_bias_force_prior)
Bias Forces for Strained Fe Bulk:
No Prior: 0.0012520942
With Prior: 0.0045943237
The distilled Prior has increased the maximum bias force by a factor of roughly 3.7. Assuming the bias strength tau is kept constant, the committee with the distilled Prior will therefore push the sampling more strongly away from this strained bulk configuration, which is already part of our original dataset.
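To see why a constant tau matters, consider a HAL-style bias energy E_bias = -tau * sigma_E, where sigma_E is the standard deviation of the committee energies; this form is an assumption here, and ase_uhal's exact definition may differ. The bias force is then tau times the gradient of sigma_E. The toy sketch below uses a hypothetical one-dimensional descriptor, computes this force analytically for a random committee, checks it against a central finite difference, and confirms that it scales linearly with tau.

```python
import numpy as np


def descriptors(x):
    """Hypothetical toy 1D descriptor vector and its derivative with respect to x."""
    phi = np.array([x, x**2, np.sin(x)])
    dphi = np.array([1.0, 2 * x, np.cos(x)])
    return phi, dphi


def bias_force(x, committee, tau):
    """Bias force for the assumed bias energy E_bias = -tau * sigma(x)."""
    phi, dphi = descriptors(x)
    energies = committee @ phi        # one energy per committee member
    d_energies = committee @ dphi     # dE_k/dx for each member
    dev = energies - energies.mean()
    sigma = np.sqrt(np.mean(dev**2))
    dsigma_dx = np.mean(dev * d_energies) / sigma
    return tau * dsigma_dx            # F = -dE_bias/dx = tau * dsigma/dx


def sigma_of(x, committee):
    """Committee energy standard deviation at position x."""
    phi, _ = descriptors(x)
    return np.std(committee @ phi)


rng = np.random.default_rng(2)
committee = rng.normal(size=(20, 3))  # 20 members drawn from N(0, I)

x0, h = 0.7, 1e-6
f = bias_force(x0, committee, tau=1.0)
fd = (sigma_of(x0 + h, committee) - sigma_of(x0 - h, committee)) / (2 * h)
print(np.isclose(f, fd, rtol=1e-4))

# The bias force is linear in tau, so committees are only comparable at fixed tau
print(np.isclose(bias_force(x0, committee, tau=2.0), 2 * f))
```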