{ "cells": [ { "cell_type": "markdown", "id": "8896a6c7", "metadata": {}, "source": [ "# Distilling Existing Datasets into Priors\n", "\n", "In Bayesian Regression, the Prior distribution encodes our knowledge about the model weights before the model has been conditioned on any training data. The Prior and the training data are then combined to form the Posterior distribution, which describes model weights that are compatible with both the training data and the Prior.\n", "\n", "Conditioning a Posterior distribution on a second set of training data is identical to using that Posterior as the Prior when conditioning on the second dataset. In this way, we can \"distill\" existing datasets into informative Prior distributions for the committee models used to select new structures.\n", "\n", "To start, let's construct a committee calculator in the same way as in the previous example:" ] }, { "cell_type": "code", "execution_count": 1, "id": "f33c266e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/u2180064/.local/lib/python3.10/site-packages/e3nn/o3/_wigner.py:10: UserWarning: Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.\n", " _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/u2180064/.local/lib/python3.10/site-packages/matplotlib/projections/__init__.py:63: UserWarning: Unable to import Axes3D. 
This may be due to multiple versions of \"\n", "hwloc/linux: Ignoring PCI device with non-16bit domain.\n", "Pass --enable-32bits-pci-domain to configure to support such devices\n", "(warning: it would break the library ABI, don't enable unless really needed).\n", "hwloc/linux: Ignoring PCI device with non-16bit domain.\n", "Pass --enable-32bits-pci-domain to configure to support such devices\n", "(warning: it would break the library ABI, don't enable unless really needed).\n", "/home/u2180064/.local/lib/python3.10/site-packages/mace/calculators/mace.py:197: UserWarning: Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.\n", " torch.load(f=model_path, map_location=device)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using medium MPA-0 model as default MACE-MP model, to use previous (before 3.10) default model please specify 'medium' as model argument\n", "Using Materials Project MACE for MACECalculator with /home/u2180064/.cache/mace/macempa0mediummodel\n", "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. 
Use float64 for geometry optimization.\n", "Using head default out of ['default']\n", "Default dtype float32 does not match model dtype float64, converting models to float32.\n" ] } ], "source": [ "# Imports\n", "import numpy as np\n", "from mace.calculators import mace_mp\n", "from ase_uhal.committee_calculators import MACEHALCalculator\n", "from ase_uhal.bias_calculators import HALBiasCalculator\n", "\n", "# Standard MACE-MPA-0 medium model calculator (from mace-torch)\n", "mace_calc = mace_mp(\"medium-mpa-0\")\n", "\n", "comm_calc = MACEHALCalculator(\n", " # MACE architecture to use for descriptor evaluation\n", " mace_calc, \n", " # Number of committee members to draw\n", " committee_size=20,\n", " # Weight of the prior in the linear system to find committee weights\n", " prior_weight=0.1,\n", " # Weights on energy and force observations when sampling new structures\n", " energy_weight=1, forces_weight=100,\n", " # Option to use a lower memory strategy to fit the linear system\n", " lowmem=False,\n", " # Lower memory overhead by evaluating descriptor forces in batches \n", " # (Specific to MACE-based calculators)\n", " batch_size=8,\n", " # Seeding for random number generation\n", " rng=np.random.RandomState(42))\n", "\n", "comm_calc.resample_committee() # Build the initial committee" ] }, { "cell_type": "markdown", "id": "030b1e67", "metadata": {}, "source": [ "Now, let's create some initial training data by applying volumetric strains to a bulk Fe system:" ] }, { "cell_type": "code", "execution_count": 2, "id": "8c07bd2b", "metadata": {}, "outputs": [], "source": [ "from ase.build import bulk\n", "\n", "base_ats = bulk(\"Fe\", cubic=True)\n", "base_cell = base_ats.cell[:, :]\n", "\n", "db_strains = np.linspace(0.9, 1.1, 10)\n", "\n", "dataset = []\n", "\n", "for strain in db_strains:\n", " ats = base_ats.copy()\n", " # Apply volumetric strain by scaling all lattice vectors equally\n", " cell = base_cell.copy() * strain\n", " ats.set_cell(cell, scale_atoms=True)\n", "\n", " dataset.append(ats)\n", "\n", "# Also 
include the unstrained bulk, with a 10x higher weight\n", "ats = base_ats.copy()\n", "ats.info[\"total_weight\"] = 10\n", "dataset.append(ats)" ] }, { "cell_type": "markdown", "id": "e437665f", "metadata": {}, "source": [ "Now let's distill this dataset into a prior:" ] }, { "cell_type": "code", "execution_count": 3, "id": "b7d6d1ec", "metadata": {}, "outputs": [], "source": [ "from ase_uhal.distillation import distill_dataset\n", "\n", "sqrt_prior = distill_dataset(\n", " # Dataset to distill\n", " dataset=dataset,\n", " # Committee calculator to get descriptors from\n", " calc=comm_calc,\n", " # The distillation process will scan the atoms.info dict for this key, \n", " # and use it as an overall weight multiplier (1 is used when the key is missing)\n", " total_weight_key=\"total_weight\"\n", ")\n", "\n", "# We could now save this prior to disk, to avoid regenerating it each time we want to run the sampling\n", "# np.save(\"Fe_Prior\", sqrt_prior) # np.save appends the .npy extension\n", "# and then load it back with\n", "# sqrt_prior = np.load(\"Fe_Prior.npy\")" ] }, { "cell_type": "markdown", "id": "74d93916", "metadata": {}, "source": [ "The returned `sqrt_prior` is the square root of the Posterior covariance matrix, which is also the square root of the new Prior matrix. 
As ase_uhal uses zero-mean committees, we do not need any information about the Posterior mean (which would additionally require reference energies and forces).\n", "\n", "## Using Priors in new Committee Calculators\n", "\n", "We can now create a new committee calculator that uses this distilled Prior:" ] }, { "cell_type": "code", "execution_count": 4, "id": "25abc137", "metadata": {}, "outputs": [], "source": [ "new_comm_calc = MACEHALCalculator(\n", " # MACE architecture to use for descriptor evaluation\n", " mace_calc, \n", " # Number of committee members to draw\n", " committee_size=20,\n", " # Weight of the prior in the linear system to find committee weights\n", " prior_weight=0.1,\n", " # Weights on energy and force observations when sampling new structures\n", " energy_weight=1, forces_weight=100,\n", " # Option to use a lower memory strategy to fit the linear system\n", " lowmem=False,\n", " # Lower memory overhead by evaluating descriptor forces in batches \n", " # (Specific to MACE-based calculators)\n", " batch_size=8,\n", " # Seeding for random number generation\n", " rng=np.random.RandomState(42),\n", " # Square root of the distilled prior\n", " sqrt_prior=sqrt_prior)\n", "\n", "new_comm_calc.resample_committee()" ] }, { "cell_type": "markdown", "id": "97205b9f", "metadata": {}, "source": [ "To show the impact of the Prior, let's compare the largest per-atom bias force from each committee calculator on the most compressed structure in the dataset:" ] }, { "cell_type": "code", "execution_count": 5, "id": "8fb88baa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Bias Forces for Strained Fe Bulk:\n", "No Prior: 0.0012763528\n", "With Prior: 0.010700304\n" ] } ], "source": [ "# Copy so that the rattle below does not modify the dataset in place\n", "ats = dataset[0].copy()\n", "# Very small rattle, to break symmetry\n", "ats.rattle(1e-4)\n", "\n", "bulk_bias_force_no_prior = np.max(np.linalg.norm(comm_calc.get_property(\"bias_forces\", ats), axis=-1))\n", "bulk_bias_force_prior = np.max(np.linalg.norm(new_comm_calc.get_property(\"bias_forces\", ats), axis=-1))\n", "\n", "print(\"Bias Forces 
for Strained Fe Bulk:\")\n", "print(\"No Prior: \", bulk_bias_force_no_prior)\n", "print(\"With Prior:\", bulk_bias_force_prior)" ] }, { "cell_type": "markdown", "id": "7d418f6e", "metadata": {}, "source": [ "The distilled prior increases the largest bias force on this structure by a factor of roughly 8. This means that, assuming the bias strength $\\tau$ is held constant, the committee with the distilled prior will more strongly avoid exploring this strained bulk configuration, which is already included in our original dataset." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }