AlphaFold2 code reading

Posted by worldofcarp on Fri, 31 Dec 2021 05:16:26 +0100

2021SC@SDUSC

Preface: this post continues the code reading from the previous article.

atom37_to_torsion_angles

def atom37_to_torsion_angles(
    aatype: jnp.ndarray,  # (B, N)
    all_atom_pos: jnp.ndarray,  # (B, N, 37, 3)
    all_atom_mask: jnp.ndarray,  # (B, N, 37)
    placeholder_for_undefined=False,
) -> Dict[str, jnp.ndarray]:

This function computes the seven torsion angles of each residue and encodes each one as a (sine, cosine) pair.
The order of the seven torsion angles is [pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4].
Here, pre_omega is the omega torsion angle between a given amino acid and the previous amino acid.
The parameters of this function are:
aatype: amino acid type, given as an integer array.
all_atom_pos: atom37 representation of all atomic coordinates.
all_atom_mask: atom37 representation of the mask on all atomic coordinates.
placeholder_for_undefined: whether to replace masked (undefined) torsion angles with a fixed placeholder value.

aatype = jnp.minimum(aatype, 20)

The above code maps every aatype greater than 20 to 20, the index of the 'Unknown' residue type.

num_batch, num_res = aatype.shape

pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32)
prev_all_atom_pos = jnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1)

pad = jnp.zeros([num_batch, 1, 37], jnp.float32)
prev_all_atom_mask = jnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)

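The above code shifts the atom positions and masks by one residue along the residue axis, so that index i holds the data of residue i-1 and the first residue gets zeros. This prepares the previous-residue atoms needed for the backbone angles pre_omega and phi.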
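A minimal sketch of this pad-and-shift pattern, with NumPy standing in for jax.numpy. `shift_to_previous` is a hypothetical helper name, not part of AlphaFold; it prepends one zero "residue" and drops the last one, so entry i holds residue i-1.

```python
import numpy as np

def shift_to_previous(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: index i along axis 1 holds x[:, i-1], zeros at i=0."""
    pad = np.zeros_like(x[:, :1])  # one zero "residue" per batch element
    return np.concatenate([pad, x[:, :-1]], axis=1)

# Toy input: batch of 1, 3 residues, 37 atoms, 3 coordinates.
all_atom_pos = np.arange(1 * 3 * 37 * 3, dtype=np.float32).reshape(1, 3, 37, 3)
all_atom_mask = np.ones((1, 3, 37), dtype=np.float32)

prev_all_atom_pos = shift_to_previous(all_atom_pos)
prev_all_atom_mask = shift_to_previous(all_atom_mask)

# The first residue has no predecessor, so its "previous" mask is 0.
print(prev_all_atom_mask[0, :, 0])
```

Because the padded entries have mask 0, any torsion that needs the previous residue is automatically masked out at the N-terminus.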

pre_omega_atom_pos = jnp.concatenate(
    [prev_all_atom_pos[:, :, 1:3, :],  # prev CA, C
     all_atom_pos[:, :, 0:2, :]  # this N, CA
    ], axis=-2)
phi_atom_pos = jnp.concatenate(
    [prev_all_atom_pos[:, :, 2:3, :],  # prev C
     all_atom_pos[:, :, 0:3, :]  # this N, CA, C
    ], axis=-2)
psi_atom_pos = jnp.concatenate(
    [all_atom_pos[:, :, 0:3, :],  # this N, CA, C
     all_atom_pos[:, :, 4:5, :]  # this O
    ], axis=-2)

For each backbone torsion angle, the above code collects the four atom positions that define it. In the atom37 layout, indices 0, 1, 2 and 4 hold N, CA, C and O.
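To see which atoms each slice picks, the sketch below applies the same slices to atom labels instead of coordinates. It assumes the first five atom37 names are N, CA, C, CB, O (the ordering of `residue_constants.atom_types`); only those five matter for the backbone torsions, and the label arrays are purely illustrative.

```python
import numpy as np

# First five atom37 atom names (assumed from residue_constants.atom_types).
names = ["N", "CA", "C", "CB", "O"]
prev = np.array(["prev " + a for a in names])  # labels for residue i-1
this = np.array(["this " + a for a in names])  # labels for residue i

# Same slices as in the code above, concatenated along the atom axis.
pre_omega_atoms = np.concatenate([prev[1:3], this[0:2]])
phi_atoms = np.concatenate([prev[2:3], this[0:3]])
psi_atoms = np.concatenate([this[0:3], this[4:5]])

print(pre_omega_atoms.tolist())
print(phi_atoms.tolist())
print(psi_atoms.tolist())
```

Slice 3 (CB) is skipped for psi, which is why the code concatenates two separate slices, 0:3 and 4:5.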

pre_omega_mask = (
    jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1)  # prev CA, C
    * jnp.prod(all_atom_mask[:, :, 0:2], axis=-1))  # this N, CA
phi_mask = (
    prev_all_atom_mask[:, :, 2]  # prev C
    * jnp.prod(all_atom_mask[:, :, 0:3], axis=-1))  # this N, CA, C
psi_mask = (
    jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) *  # this N, CA, C
    all_atom_mask[:, :, 4])  # this O

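The above code collects the masks of these atoms. A torsion angle is only defined if all four of its atoms are present, so the individual atom masks are multiplied together.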
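Multiplying 0/1 masks acts as a logical AND. A NumPy sketch for phi on a toy four-residue chain (made-up data, only the first five atom slots N, CA, C, CB, O kept) where residue 1 is missing its C atom:

```python
import numpy as np

# Toy mask: batch 1, 4 residues, 5 atom slots (N, CA, C, CB, O).
all_atom_mask = np.ones((1, 4, 5), dtype=np.float32)
all_atom_mask[0, 1, 2] = 0.0  # residue 1 has no C atom

# Previous-residue mask, zeros for the first residue.
pad = np.zeros((1, 1, 5), dtype=np.float32)
prev_all_atom_mask = np.concatenate([pad, all_atom_mask[:, :-1]], axis=1)

# phi needs prev C and this N, CA, C; the product is 1 only if all are present.
phi_mask = prev_all_atom_mask[:, :, 2] * np.prod(all_atom_mask[:, :, 0:3], axis=-1)
print(phi_mask)
```

Residue 0 fails because it has no predecessor, residue 1 because its own C is missing, residue 2 because its predecessor's C is missing; only residue 3 has a defined phi.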

chi_atom_indices = get_chi_atom_indices()

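This gets, for every residue type, the atom37 indices of the four atoms that define each of its chi angles (an array of shape (21, 4, 4): residue types × chi angles × atoms).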

atom_indices = utils.batched_gather(
    params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0)

This selects, for each residue, the chi atom indices of its amino acid type, giving atom_indices of shape (B, N, 4, 4).

chis_atom_pos = utils.batched_gather(
    params=all_atom_pos, indices=atom_indices, axis=-2,
    batch_dims=2)

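This gathers the coordinates of those atoms, giving chis_atom_pos of shape (B, N, 4, 4, 3): four chi angles per residue and four atoms per angle.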
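`utils.batched_gather` works like `tf.gather` with `axis` and `batch_dims` arguments. The NumPy sketch below reproduces the two gathers above with plain fancy indexing. The `chi_atom_indices` table here is random, made-up data for 2 residue types, not the real output of `get_chi_atom_indices()`.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N = 2, 3

# Made-up chi-atom table: [residue_types=2, chis=4, atoms=4] of atom37 indices.
chi_atom_indices = rng.integers(0, 37, size=(2, 4, 4))
aatype = rng.integers(0, 2, size=(B, N))
all_atom_pos = rng.normal(size=(B, N, 37, 3))

# batched_gather(params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0):
# a plain row lookup per residue -> (B, N, 4, 4)
atom_indices = chi_atom_indices[aatype]

# batched_gather(params=all_atom_pos, indices=atom_indices, axis=-2, batch_dims=2):
# for every (b, n), take all_atom_pos[b, n, atom_indices[b, n]] -> (B, N, 4, 4, 3)
b = np.arange(B)[:, None, None, None]
n = np.arange(N)[None, :, None, None]
chis_atom_pos = all_atom_pos[b, n, atom_indices]

print(atom_indices.shape, chis_atom_pos.shape)
```

`batch_dims=2` means the first two axes (batch and residue) are paired element-wise between params and indices rather than combined.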

chi_angles_mask = list(residue_constants.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = jnp.asarray(chi_angles_mask)

The above code copies the chi angle mask table and appends an all-zero row for the UNKNOWN residue type (index 20), which has no chi angles.

chis_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype,
                                 axis=0, batch_dims=0)

This looks up the chi angle mask of each residue from its amino acid type, i.e. which chi angles exist for that residue.
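A small NumPy sketch of the table extension and lookup. The 20-row table is a stand-in for `residue_constants.chi_angles_mask` filled with placeholder values (only type 1 is given chi angles here). Appending the all-zero row means an aatype clipped to 20 by `jnp.minimum(aatype, 20)` gets no chi angles.

```python
import numpy as np

# Placeholder stand-in for residue_constants.chi_angles_mask (20 types x 4 chis).
chi_angles_mask = [[0.0, 0.0, 0.0, 0.0] for _ in range(20)]
chi_angles_mask[1] = [1.0, 1.0, 0.0, 0.0]  # made-up: type 1 has chi1 and chi2

chi_angles_mask = list(chi_angles_mask)        # copy, as in the original code
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])   # row 20: UNKNOWN residue
chi_angles_mask = np.asarray(chi_angles_mask)  # shape (21, 4)

aatype = np.minimum(np.array([[1, 0, 25]]), 20)  # 25 is clipped to 20 (UNKNOWN)
chis_mask = chi_angles_mask[aatype]              # shape (1, 3, 4)
print(chis_mask)
```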

chi_angle_atoms_mask = utils.batched_gather(
    params=all_atom_mask, indices=atom_indices, axis=-1,
    batch_dims=2)

This gathers the mask of each atom that defines a chi angle. It is used to restrict chis_mask to the chi angles for which ground-truth coordinates of all four defining atoms are available.

chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
chis_mask = chis_mask * (chi_angle_atoms_mask).astype(jnp.float32)

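This keeps a chi angle only if it exists for the residue type and all four of its atoms are present: the product over the last axis is 1 only when every atom mask is 1.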
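A NumPy sketch of this AND over the four defining atoms, with made-up masks for a single residue. chi1 exists and is complete, chi2 exists but misses its fourth atom, chi3 does not exist for the residue type (so its atom mask does not matter), and chi4 has no atoms.

```python
import numpy as np

# Which chi angles exist for this (made-up) residue type.
chis_mask = np.array([1.0, 1.0, 0.0, 0.0], dtype=np.float32)

# Gathered atom masks for each chi angle's four atoms: (4 chis, 4 atoms).
chi_angle_atoms_mask = np.array([
    [1, 1, 1, 1],  # chi1: all four atoms present
    [1, 1, 1, 0],  # chi2: fourth atom missing
    [1, 1, 1, 1],  # chi3: atoms present, but chi3 does not exist for this type
    [0, 0, 0, 0],  # chi4: no atoms
], dtype=np.float32)

chi_angle_atoms_mask = np.prod(chi_angle_atoms_mask, axis=-1)  # shape (4,)
chis_mask = chis_mask * chi_angle_atoms_mask.astype(np.float32)
print(chis_mask)
```

Only chi1 survives: chi2 fails the atom check and chi3 fails the residue-type check.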

torsions_atom_pos = jnp.concatenate(
    [pre_omega_atom_pos[:, :, None, :, :],
     phi_atom_pos[:, :, None, :, :],
     psi_atom_pos[:, :, None, :, :],
     chis_atom_pos
    ], axis=2)

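The atom positions of all seven torsion angles are stacked here: each backbone angle gets a new axis of size 1 and is joined with the four chi angles, giving shape (B, N, 7, 4, 3).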

torsion_angles_mask = jnp.concatenate(
    [pre_omega_mask[:, :, None],
     phi_mask[:, :, None],
     psi_mask[:, :, None],
     chis_mask
    ], axis=2)

The masks of all seven torsion angles are stacked the same way, giving shape (B, N, 7).
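A quick NumPy shape check of both concatenations, using zero/one placeholder arrays with the shapes described above (B = 2, N = 5 chosen arbitrarily):

```python
import numpy as np

B, N = 2, 5
pre_omega_atom_pos = np.zeros((B, N, 4, 3))
phi_atom_pos = np.zeros((B, N, 4, 3))
psi_atom_pos = np.zeros((B, N, 4, 3))
chis_atom_pos = np.zeros((B, N, 4, 4, 3))  # 4 chi angles x 4 atoms

# Backbone angles get a size-1 torsion axis, then join the chis on axis 2.
torsions_atom_pos = np.concatenate(
    [pre_omega_atom_pos[:, :, None], phi_atom_pos[:, :, None],
     psi_atom_pos[:, :, None], chis_atom_pos], axis=2)

pre_omega_mask = np.ones((B, N))
phi_mask = np.ones((B, N))
psi_mask = np.ones((B, N))
chis_mask = np.ones((B, N, 4))
torsion_angles_mask = np.concatenate(
    [pre_omega_mask[:, :, None], phi_mask[:, :, None],
     psi_mask[:, :, None], chis_mask], axis=2)

print(torsions_atom_pos.shape, torsion_angles_mask.shape)
```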

torsion_frames = r3.rigids_from_3_points(
    point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]),
    origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
    point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]))

The above code builds a local coordinate frame from the first three atoms of each torsion: the first atom lies on the xy plane, the second atom lies on the negative x axis, and the third atom is the origin.
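A NumPy sketch of the usual Gram-Schmidt construction, which is my understanding of what `r3.rigids_from_3_points` does. The x axis points from the negative-x point to the origin, the y axis is the part of (xy-plane point - origin) orthogonal to x, and z = x × y. `frame_from_3_points` is a hypothetical name; it returns a rotation matrix with the axes as columns plus the origin as translation. The coordinates are made up.

```python
import numpy as np

def frame_from_3_points(point_on_neg_x_axis, origin, point_on_xy_plane):
    """Hypothetical Gram-Schmidt frame: returns (rot, trans), axes as columns."""
    e0 = origin - point_on_neg_x_axis
    e0 = e0 / np.linalg.norm(e0)
    e1 = point_on_xy_plane - origin
    e1 = e1 - np.dot(e1, e0) * e0   # remove the x component
    e1 = e1 / np.linalg.norm(e1)
    e2 = np.cross(e0, e1)           # right-handed z axis
    return np.stack([e0, e1, e2], axis=-1), origin

# Atom 0 goes on the xy plane, atom 1 on the -x axis, atom 2 is the origin.
a0 = np.array([1.0, 1.5, 0.2])
a1 = np.array([0.0, 0.0, 0.0])
a2 = np.array([1.5, 0.0, 0.0])
rot, trans = frame_from_3_points(point_on_neg_x_axis=a1, origin=a2, point_on_xy_plane=a0)

# Local coordinates of a point p are rot.T @ (p - trans).
print(rot.T @ (a1 - trans))  # lies on the negative x axis
print(rot.T @ (a0 - trans))  # z component is 0: lies on the xy plane
```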

forth_atom_rel_pos = r3.rigids_mul_vecs(
    r3.invert_rigids(torsion_frames),
    r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))

The above code maps the fourth atom into this local frame by applying the inverse of the frame transform. Its y and z coordinates in this frame define the torsion angle.
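To check that the fourth atom's local y and z really encode the dihedral angle, the sketch below builds the frame with a hypothetical Gram-Schmidt helper, maps the fourth atom into it with the inverse transform rot.T @ (p - trans), and compares atan2(z, y) with the common atan2 dihedral formula. All coordinates are made up; `torsion_from_frame` and `dihedral_reference` are illustrative names.

```python
import numpy as np

def frame_from_3_points(point_on_neg_x_axis, origin, point_on_xy_plane):
    """Hypothetical Gram-Schmidt frame: returns (rot, trans), axes as columns."""
    e0 = origin - point_on_neg_x_axis
    e0 = e0 / np.linalg.norm(e0)
    e1 = point_on_xy_plane - origin
    e1 = e1 - np.dot(e1, e0) * e0
    e1 = e1 / np.linalg.norm(e1)
    return np.stack([e0, e1, np.cross(e0, e1)], axis=-1), origin

def torsion_from_frame(p0, p1, p2, p3):
    """Torsion p0-p1-p2-p3 (radians) via the local-frame method."""
    rot, trans = frame_from_3_points(point_on_neg_x_axis=p1, origin=p2,
                                     point_on_xy_plane=p0)
    local = rot.T @ (p3 - trans)           # inverse rigid transform of atom 4
    return np.arctan2(local[2], local[1])  # z plays sin, y plays cos

def dihedral_reference(p0, p1, p2, p3):
    """Common atan2 dihedral formula, used only as an independent check."""
    b0, b1, b2 = p0 - p1, p2 - p1, p3 - p2
    b1 = b1 / np.linalg.norm(b1)
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1
    return np.arctan2(np.dot(np.cross(b1, v), w), np.dot(v, w))

theta = np.deg2rad(60.0)
p0, p1, p2 = np.array([0.0, 1.0, 0.0]), np.zeros(3), np.array([1.0, 0.0, 0.0])
p3 = np.array([2.0, np.cos(theta), np.sin(theta)])
print(np.rad2deg(torsion_from_frame(p0, p1, p2, p3)))  # about 60.0
```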

torsion_angles_sin_cos = jnp.stack(
    [forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1)
torsion_angles_sin_cos /= jnp.sqrt(
    jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True)
    + 1e-8)

Here the (z, y) pair is normalized to unit length, which gives the sine and cosine of the torsion angle; the small constant 1e-8 prevents division by zero.
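A NumPy sketch of the normalization with made-up local coordinates. After dividing by the norm, each (z, y) pair lies on the unit circle, so atan2 recovers the angle; the 1e-8 term keeps a degenerate all-zero pair finite instead of producing NaN.

```python
import numpy as np

# Made-up local (y, z) coordinates of the fourth atom for three torsions.
y = np.array([1.2, -0.5, 0.0])
z = np.array([0.7, 0.5, 0.0])  # last entry: degenerate zero vector

torsion_angles_sin_cos = np.stack([z, y], axis=-1)  # [..., (sin, cos)]
torsion_angles_sin_cos /= np.sqrt(
    np.sum(np.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8)

angles = np.arctan2(torsion_angles_sin_cos[:, 0], torsion_angles_sin_cos[:, 1])
print(torsion_angles_sin_cos)
print(np.rad2deg(angles))
```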

torsion_angles_sin_cos *= jnp.asarray(
    [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]

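This multiplies both the sine and the cosine of psi (index 2; the trailing None broadcasts each factor over the last axis) by -1 and leaves the other angles unchanged, which shifts psi by π. This appears to compensate for psi being computed from the O atom of the same residue rather than from the N atom of the next residue: in a planar peptide group these two atoms lie on opposite sides of the CA-C bond, so the two dihedral angles differ by 180°.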
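A NumPy check of what this multiplication does, under idealized, made-up coordinates: O and the next residue's N are placed in one plane with the CA-C bond, on opposite sides of it. The dihedral measured with O then differs from the conventional psi (measured with the next N) by π, and applying the same (1, 1, 7, 1) factor as the code recovers the conventional value. `dihedral` is an illustrative helper using the common atan2 formula.

```python
import numpy as np

def dihedral(p0, p1, p2, p3):
    """Dihedral angle p0-p1-p2-p3 in radians (common atan2 formula)."""
    b0, b1, b2 = p0 - p1, p2 - p1, p3 - p2
    b1 = b1 / np.linalg.norm(b1)
    v = b0 - np.dot(b0, b1) * b1  # component of b0 perpendicular to the axis
    w = b2 - np.dot(b2, b1) * b1  # component of b2 perpendicular to the axis
    return np.arctan2(np.dot(np.cross(b1, v), w), np.dot(v, w))

# Idealized planar peptide group (made-up coordinates).
alpha = np.deg2rad(40.0)
d = np.array([0.0, np.cos(alpha), np.sin(alpha)])  # in-plane, perpendicular to CA-C
n = np.array([-0.5, 1.2, 0.8])
ca = np.zeros(3)
c = np.array([1.5, 0.0, 0.0])
o = c + np.array([0.6, 0.0, 0.0]) + 1.0 * d
n_next = c + np.array([0.7, 0.0, 0.0]) - 1.1 * d

psi_from_o = dihedral(n, ca, c, o)             # measured with O, as in the code
psi_conventional = dihedral(n, ca, c, n_next)  # textbook psi

# Same factor as the AlphaFold line: shape (1, 1, 7, 1), so it scales both
# sin and cos of psi (index 2) by -1.
factor = np.asarray([1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]
torsion_angles_sin_cos = np.zeros((1, 1, 7, 2))
torsion_angles_sin_cos[0, 0, 2] = [np.sin(psi_from_o), np.cos(psi_from_o)]
torsion_angles_sin_cos = torsion_angles_sin_cos * factor

psi_flipped = np.arctan2(*torsion_angles_sin_cos[0, 0, 2])
print(np.rad2deg([psi_from_o, psi_conventional, psi_flipped]))
```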

chi_is_ambiguous = utils.batched_gather(
    jnp.asarray(residue_constants.chi_pi_periodic), aatype)
mirror_torsion_angles = jnp.concatenate(
    [jnp.ones([num_batch, num_res, 3]),
     1.0 - 2.0 * chi_is_ambiguous], axis=-1)
alt_torsion_angles_sin_cos = (
    torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])

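The above code creates alternative torsion angles for residues whose side-chain atom naming is ambiguous. For chi angles flagged in residue_constants.chi_pi_periodic (side chains that look the same after a 180° rotation, so two atom names can be swapped), the factor 1 - 2 * chi_is_ambiguous is -1, so both sine and cosine are negated and the angle is shifted by π. All other angles, including the three backbone angles, keep a factor of 1.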
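A NumPy sketch of the mirror factors for one residue, with a made-up `chi_is_ambiguous` row that flags chi2 (torsion index 4) and made-up angles. Negating both sine and cosine is the same as adding π to the angle.

```python
import numpy as np

num_batch, num_res = 1, 1
# Made-up flags: only chi2 is pi-periodic for this residue.
chi_is_ambiguous = np.array([[[0.0, 1.0, 0.0, 0.0]]])  # (1, 1, 4)

mirror_torsion_angles = np.concatenate(
    [np.ones([num_batch, num_res, 3]),       # backbone angles: factor 1
     1.0 - 2.0 * chi_is_ambiguous], axis=-1)  # -1 for ambiguous chis -> (1, 1, 7)

angles = np.array([[[3.0, -1.0, 2.0, 0.5, -0.3, 1.2, -2.5]]])  # made-up radians
torsion_angles_sin_cos = np.stack([np.sin(angles), np.cos(angles)], axis=-1)
alt_torsion_angles_sin_cos = (
    torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])

alt_angles = np.arctan2(alt_torsion_angles_sin_cos[..., 0],
                        alt_torsion_angles_sin_cos[..., 1])
print(np.round(alt_angles, 3))  # only index 4 (chi2) is shifted by pi
```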

if placeholder_for_undefined:
  # Replace undefined torsion angles with the placeholder (sin, cos) = (1, 0).
  placeholder_torsions = jnp.stack([
      jnp.ones(torsion_angles_sin_cos.shape[:-1]),
      jnp.zeros(torsion_angles_sin_cos.shape[:-1])
  ], axis=-1)
  torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[
      ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
  alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[
      ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])

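This branch replaces every undefined torsion angle (mask 0, for example pre_omega and phi of the first residue, which has no predecessor) with the placeholder (sin, cos) = (1, 0), for both the regular and the alternative angles.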
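The blend x * mask + placeholder * (1 - mask) keeps defined angles and substitutes the placeholder wherever the mask is 0. A NumPy sketch with toy values: two residues, every angle set to 0.3 rad, and the first residue's pre_omega and phi marked undefined.

```python
import numpy as np

# Toy data: 1 batch, 2 residues, 7 torsions, every angle 0.3 rad.
torsion_angles_sin_cos = np.stack(
    [np.full((1, 2, 7), np.sin(0.3)), np.full((1, 2, 7), np.cos(0.3))], axis=-1)
torsion_angles_mask = np.ones((1, 2, 7))
torsion_angles_mask[0, 0, :2] = 0.0  # residue 0: pre_omega and phi undefined

placeholder_torsions = np.stack([
    np.ones(torsion_angles_sin_cos.shape[:-1]),
    np.zeros(torsion_angles_sin_cos.shape[:-1])
], axis=-1)
torsion_angles_sin_cos = (
    torsion_angles_sin_cos * torsion_angles_mask[..., None]
    + placeholder_torsions * (1 - torsion_angles_mask[..., None]))

print(torsion_angles_sin_cos[0, 0, :3])  # two placeholders, then psi unchanged
```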
