from ase.atoms import Atoms
from .torch_committee_calculator import TorchCommitteeCalculator
import numpy as np
from abc import ABCMeta, abstractmethod
class BaseMACECalculator(TorchCommitteeCalculator, metaclass=ABCMeta):
implemented_properties = ['energy', 'forces', 'stress', 'desc_energy', 'desc_forces', 'desc_stress',
'comm_energy', 'comm_forces', 'comm_stress', 'bias_energy', 'bias_forces', 'bias_stress']
def __init__(self, mace_calculator, committee_size, prior_weight,
num_layers=-1, invariants_only=True, batch_size=None, **kwargs):
'''
Parameters
----------
mace_calculator: mace.calculators.MACECalculator object
MACE architecture to use to define a MACE descriptor
committee_size: int
Number of members in the linear committee
prior_weight: float
Weight corresponding to the prior matrix in the linear system
num_layers: int (default: -1)
Number of layers in the MACE model to keep for descriptor evaluation
Default of -1 uses all but the readout layer (equivalent to MACECalculator.get_descriptors() default)
invariants_only: bool
Whether to only keep the invariants partition of the descriptor vector, see MACECalculator.get_descriptors
for more details
batch_size: int
Batch size to use for descriptor gradient evaluation. Lower batch size reduces overhead.
If batch_size > len(atoms), then len(atoms) is used as the batch size instead
**kwargs: Keyword Args
Extra keywork arguments fed to :class:`~ase_uhal.committee_calculators.TorchCommitteeCalculator`
'''
from mace.modules.utils import prepare_graph, get_edge_vectors_and_lengths, extract_invariant
from e3nn import o3
self.batch_size = batch_size
self.prepare_graph = prepare_graph
self.get_edge_vectors_and_lengths = get_edge_vectors_and_lengths
self.extract_invariant = extract_invariant
self.o3 = o3
self.mace_calc = mace_calculator
self.model = mace_calculator.models[0]
self.num_layers = num_layers
self.invariants_only = invariants_only
self.torch_device = self.model.atomic_numbers.get_device()
if self.torch_device < 0:
self.torch_device = "cpu"
num_interactions = int(self.model.num_interactions)
irreps_out = self.o3.Irreps(str(self.model.products[0].linear.irreps_out))
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
per_layer_features[-1] = (
num_invariant_features # Equivariant features not created for the last layer
)
if num_layers == -1:
num_layers = num_interactions
to_keep = np.sum(per_layer_features[:num_layers])
self.l_max = l_max
self.num_invariant_features = num_invariant_features
self.num_layers = num_layers
self.to_keep = to_keep
super().__init__(committee_size, prior_weight, **kwargs)
def _get_descriptor_length(self):
# Build an atoms object with a species which the model can handle
ats = Atoms(numbers=[self.model.atomic_numbers.detach().cpu().numpy()[0]], positions=[[0, 0, 0]])
return self._descriptor_base(*self._prep_atoms(ats)).shape[0]
def _prep_atoms(self, atoms):
'''
Convert ASE atoms object into a format suitable for MACE models
'''
batch = self.mace_calc._atoms_to_batch(atoms).to_dict()
ctx = self.prepare_graph(batch, compute_stress=True)
return ctx.positions, ctx.displacement, batch["node_attrs"], batch["edge_index"], batch["unit_shifts"], batch["cell"]
def _descriptor_base(self, positions, displacement, attrs, edge_index, unit_shifts, cell):
'''
Base MACE descriptor, based on results from self._prep_atoms
'''
symmetric_displacement = 0.5 * (
displacement + displacement.transpose(-1, -2)
) # From https://github.com/mir-group/nequip
positions = positions + self.torch.einsum(
"be,bec->bc", positions, symmetric_displacement
)
cell = cell.view(-1, 3, 3)
cell = cell + self.torch.matmul(cell, symmetric_displacement)
shifts = self.torch.einsum(
"be,bec->bc",
unit_shifts,
cell,
)
vectors, lengths = self.get_edge_vectors_and_lengths(
positions=positions,
edge_index=edge_index,
shifts=shifts
)
node_feats = self.model.node_embedding(attrs)
edge_attrs = self.model.spherical_harmonics(vectors)
edge_feats, cutoff = self.model.radial_embedding(
lengths, attrs, edge_index, self.model.atomic_numbers
)
feats = []
for i, (interaction, product) in enumerate(
zip(self.model.interactions, self.model.products)
):
node_attrs_slice = attrs
node_feats, sc = interaction(
node_attrs=node_attrs_slice,
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=edge_index,
cutoff=cutoff,
first_layer=(i == 0),
)
node_feats = product(
node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice
)
feats.append(node_feats)
node_feats_out = self.torch.cat(feats, dim=-1)
if self.invariants_only:
node_feats_out = self.extract_invariant(
node_feats_out,
num_layers=self.num_layers,
num_features=self.num_invariant_features,
l_max=self.l_max,
)
return node_feats_out[:, :self.to_keep].sum(dim=0)
def _desc_energy(self, *args):
return self._descriptor_base(*args)
def _comm_energy(self, *args):
return self.committee_weights @ self._descriptor_base(*args)
def _energy(self, *args):
comm_energy = self._comm_energy(*args)
return self.torch.mean(comm_energy)
@abstractmethod
def _bias_energy(self, *args):
pass
def _desc_forces(self, *args):
positions = args[0]
N = positions.shape[0]
M = self.n_desc
if self.batch_size is None or self.batch_size > N:
batch_size = N
else:
batch_size = self.batch_size
full_jac = self.torch.zeros(M, N, 3, dtype=positions.dtype, device=positions.device)
with self.torch.no_grad():
for start_idx in range(0, N, batch_size):
end_idx = min(start_idx + batch_size, N)
def f_partial(pos_section):
pos_copy = args[0].clone()
pos_copy[start_idx:end_idx] = pos_section
return self._descriptor_base(pos_copy, *args[1:])
pos_section = args[0][start_idx:end_idx, :]
full_jac[:, start_idx:end_idx, :] = self.torch.func.jacfwd(f_partial, argnums=0)(pos_section)
return full_jac
def _comm_forces(self, *args):
'''
Use Vector jacobian product to compute comm forces, to save memory overheads
'''
def f(positions):
return self._descriptor_base(positions, *args[1:])
_, comm_force_func = self.torch.func.vjp(f, args[0])
def g(weights):
return comm_force_func(weights)[0]
return self.torch.vmap(g)(self.committee_weights)
def _forces(self, *args):
comm_forces = self._comm_forces(*args)
return self.torch.mean(comm_forces, dim=0)
def _bias_forces(self, *args):
bias_energy = self._bias_energy(*args)
return self._take_derivative_scalar(bias_energy, args[0])
return self.torch.func.jacfwd(self._bias_energy, argnums=0)(*args)
def _desc_stress(self, *args):
desc_energy = self._descriptor_base(*args)
return self._take_derivative_vector(desc_energy, args[1])
return self.torch.func.jacfwd(self._descriptor_base, argnums=1)(*args)
def _comm_stress(self, *args):
def f(displacements):
return self._descriptor_base(args[0], displacements, *args[2:])
_, comm_stress_func = self.torch.func.vjp(f, args[1])
def g(weights):
return comm_stress_func(weights)[0]
return self.torch.vmap(g)(self.committee_weights)
def _stress(self, *args):
comm_stress = self._comm_stress(*args)
return self.torch.mean(comm_stress, dim=0)
def _bias_stress(self, *args):
bias_energy = self._bias_energy(*args)
return self._take_derivative_scalar(bias_energy, args[1])
return self.torch.func.jacfwd(self._bias_energy, argnums=1)(*args)
def calculate(self, atoms, properties, system_changes):
'''
Calculation for descriptor properties, committee properties, normal properties, and HAL properties
Descriptor properties use a "desc_" prefix, committee properties use "comm_", HAL (bias) properties use "hal_".
'''
super().calculate(atoms, properties, system_changes)
volume = atoms.get_volume()
struct = self._prep_atoms(atoms)
### Energy
if "desc_energy" in properties:
self.results["desc_energy"] = self._descriptor_base(*struct)
if "comm_energy" in properties:
self.results["comm_energy"] = self._comm_energy(*struct)
if "energy" in properties:
self.results["energy"] = self._energy(*struct)
if "bias_energy" in properties:
self.results["bias_energy"] = self._bias_energy(*struct)
### Forces
if "desc_forces" in properties:
self.results["desc_forces"] = -self._desc_forces(*struct)
if "comm_forces" in properties:
self.results["comm_forces"] = -self._comm_forces(*struct)
if "forces" in properties:
self.results["forces"] = -self._forces(*struct)
if "bias_forces" in properties:
self.results["bias_forces"] = -self._bias_forces(*struct)
### Stresses
if "desc_stress" in properties:
self.results["desc_stress"] = self._desc_stress(*struct)[:, 0, :, :] / volume
if "comm_stress" in properties:
self.results["comm_stress"] = self._comm_stress(*struct)[:, 0, :, :] / volume
if "stress" in properties:
self.results["stress"] = self._stress(*struct)[0, :, :] / volume
if "bias_stress" in properties:
self.results["bias_stress"] = self._bias_stress(*struct)[0, :, :] / volume
[docs]
class MACEHALCalculator(BaseMACECalculator):
name = "MACEHALCalculator"
def _bias_energy(self, *args):
comm_energy = self._comm_energy(*args)
return self.torch.std(comm_energy)