2021SC@SDUSC
Preface: this article continues from the previous one.
atom37_to_torsion_angles
def atom37_to_torsion_angles(
    aatype: jnp.ndarray,  # (B, N)
    all_atom_pos: jnp.ndarray,  # (B, N, 37, 3)
    all_atom_mask: jnp.ndarray,  # (B, N, 37)
    placeholder_for_undefined=False,
) -> Dict[str, jnp.ndarray]:
This function computes the seven torsion angles of each residue and encodes each one as a (sine, cosine) pair.
The order of the seven torsion angles is [pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]
Here, pre_omega denotes the omega torsion angle between a given amino acid and the previous amino acid.
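A small illustration of the (sine, cosine) encoding and the resulting (B, N, 7, 2) layout may help. This is a NumPy sketch with made-up angles; `jax.numpy` offers the same calls. One common motivation for this encoding is that angles near +π and -π are far apart as raw numbers but close as (sin, cos) pairs. That reasoning is offered here as an assumption about why the encoding is useful, not a quote from the source.

```python
import numpy as np

# Hypothetical torsion angles in radians, shape (B=1, N=2, 7), ordered
# [pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4].
angles = np.array([[[np.pi, -1.0, 2.4, -1.1, 3.0, 0.0, 0.0],
                    [np.pi, -1.2, 2.2, 1.0, -3.0, 0.5, 0.0]]])

# Encode every angle as a (sin, cos) pair: shape (B, N, 7, 2).
sin_cos = np.stack([np.sin(angles), np.cos(angles)], axis=-1)

# atan2(sin, cos) recovers the angle in (-pi, pi].
decoded = np.arctan2(sin_cos[..., 0], sin_cos[..., 1])

# chi_2 = 3.0 vs -3.0: far apart as raw numbers, close as (sin, cos) pairs.
raw_gap = abs(angles[0, 0, 4] - angles[0, 1, 4])                   # 6.0
encoded_gap = np.linalg.norm(sin_cos[0, 0, 4] - sin_cos[0, 1, 4])  # about 0.28
```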
The parameters of this function are explained as follows:
Parameters:
aatype: amino acid type, given as an integer array.
all_atom_pos: atom37 representation of all atomic coordinates.
all_atom_mask: atom37 representation of the mask on all atomic coordinates.
placeholder_for_undefined: flag denoting whether masked (undefined) torsion angles should be overwritten with a fixed placeholder value
aatype = jnp.minimum(aatype, 20)
The above code segment maps aatype > 20 to 'Unknown' (20)
num_batch, num_res = aatype.shape

pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32)
prev_all_atom_pos = jnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1)

pad = jnp.zeros([num_batch, 1, 37], jnp.float32)
prev_all_atom_mask = jnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)
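The above code segment prepares the backbone angle computation. It shifts the atom positions and masks by one residue, so that index i holds the atoms of residue i-1. The first residue has no predecessor, so it is padded with zeros.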
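To see the effect of the shift concretely, here is a runnable NumPy version of the same padding-and-shift step on toy data. The shapes and values are invented for illustration, and NumPy stands in for `jax.numpy`.

```python
import numpy as np

# Toy atom37 data: batch of 1, 3 residues, 37 atoms, xyz coordinates.
num_batch, num_res = 1, 3
all_atom_pos = np.arange(num_batch * num_res * 37 * 3, dtype=np.float32).reshape(
    num_batch, num_res, 37, 3)
all_atom_mask = np.ones((num_batch, num_res, 37), dtype=np.float32)

# Shift by one residue along axis 1; residue 0 gets an all-zero "previous" residue.
pad = np.zeros([num_batch, 1, 37, 3], np.float32)
prev_all_atom_pos = np.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1)
pad = np.zeros([num_batch, 1, 37], np.float32)
prev_all_atom_mask = np.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)
```

Because the padded mask of residue 0 is all zeros, any torsion that needs the previous residue is automatically masked out there.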
pre_omega_atom_pos = jnp.concatenate(
    [prev_all_atom_pos[:, :, 1:3, :],  # prev CA, C
     all_atom_pos[:, :, 0:2, :]  # this N, CA
    ], axis=-2)
phi_atom_pos = jnp.concatenate(
    [prev_all_atom_pos[:, :, 2:3, :],  # prev C
     all_atom_pos[:, :, 0:3, :]  # this N, CA, C
    ], axis=-2)
psi_atom_pos = jnp.concatenate(
    [all_atom_pos[:, :, 0:3, :],  # this N, CA, C
     all_atom_pos[:, :, 4:5, :]  # this O
    ], axis=-2)
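For each backbone torsion angle, this collects the four atom positions that define it, giving arrays of shape (B, N, 4, 3). In the atom37 order, indices 0, 1, 2 and 4 are N, CA, C and O.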
pre_omega_mask = (
    jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1)  # prev CA, C
    * jnp.prod(all_atom_mask[:, :, 0:2], axis=-1))  # this N, CA
phi_mask = (
    prev_all_atom_mask[:, :, 2]  # prev C
    * jnp.prod(all_atom_mask[:, :, 0:3], axis=-1))  # this N, CA, C
psi_mask = (
    jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) *  # this N, CA, C
    all_atom_mask[:, :, 4])  # this O
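The above code snippet collects the masks of these atoms. Because the masks are multiplied, a torsion angle is marked as present only if all four of its atoms are present. Each resulting mask has shape (B, N).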
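The mask products can be checked on a toy example. The NumPy sketch below uses three residues and pretends the O atom of the last residue is missing. The data is invented for illustration.

```python
import numpy as np

# atom37 indices used below: N=0, CA=1, C=2, O=4.
num_batch, num_res = 1, 3
all_atom_mask = np.ones((num_batch, num_res, 37), dtype=np.float32)
all_atom_mask[0, 2, 4] = 0.0  # pretend the O atom of the last residue is missing

pad = np.zeros([num_batch, 1, 37], np.float32)
prev_all_atom_mask = np.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)

# A torsion is usable only if all four defining atoms exist: multiply the masks.
pre_omega_mask = (np.prod(prev_all_atom_mask[:, :, 1:3], axis=-1)
                  * np.prod(all_atom_mask[:, :, 0:2], axis=-1))
phi_mask = prev_all_atom_mask[:, :, 2] * np.prod(all_atom_mask[:, :, 0:3], axis=-1)
psi_mask = np.prod(all_atom_mask[:, :, 0:3], axis=-1) * all_atom_mask[:, :, 4]
```

In this example, pre_omega and phi are masked out for the first residue because it has no predecessor. Psi is masked out for the last residue because its O atom is missing.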
chi_atom_indices = get_chi_atom_indices()
This collects the atoms for the chi angles. get_chi_atom_indices() returns a table of atom indices with shape [restypes, chis=4, atoms=4].
atom_indices = utils.batched_gather(
    params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0)
This selects, for every residue, the atom indices used to compute its chis according to its aatype. Shape: [batch, num_res, chis=4, atoms=4].
chis_atom_pos = utils.batched_gather(
    params=all_atom_pos, indices=atom_indices, axis=-2, batch_dims=2)
This gathers the positions of those atoms. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
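The two gathers can be reproduced with plain NumPy indexing. This sketch assumes utils.batched_gather behaves like tf.gather with the given axis and batch_dims. Under that assumption, batch_dims=2 means every (batch, residue) pair indexes its own 37 atoms. The index table and coordinates below are random stand-ins, not the real get_chi_atom_indices() output.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, num_restypes = 2, 5, 21

# Hypothetical stand-ins: a random [restypes, chis=4, atoms=4] index table
# and random atom37 coordinates.
chi_atom_indices = rng.integers(0, 37, size=(num_restypes, 4, 4))
aatype = rng.integers(0, num_restypes, size=(B, N))
all_atom_pos = rng.normal(size=(B, N, 37, 3))

# Gather along axis 0 with no batch dims is plain fancy indexing.
atom_indices = chi_atom_indices[aatype]  # (B, N, 4, 4)

# Gather along the atom axis with batch_dims=2: broadcast batch and residue
# indices against atom_indices so each (b, n) uses its own index block.
b = np.arange(B)[:, None, None, None]
n = np.arange(N)[None, :, None, None]
chis_atom_pos = all_atom_pos[b, n, atom_indices]  # (B, N, 4, 4, 3)
```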
chi_angles_mask = list(residue_constants.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = jnp.asarray(chi_angles_mask)
The above code segment copies the chi angle mask and appends an all-zero row for the UNKNOWN residue, giving shape [restypes, 4].
chis_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype,
                                 axis=0, batch_dims=0)
This computes the chi angle mask, i.e. which chi angles exist for each residue according to its aatype. Shape: [batch, num_res, chis=4].
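Appending the UNKNOWN row and then looking it up by aatype amounts to a table lookup. The NumPy sketch below uses a hypothetical three-type table rather than the real residue_constants.chi_angles_mask.

```python
import numpy as np

# Hypothetical table standing in for residue_constants.chi_angles_mask
# (1.0 = that chi angle exists for the residue type).
chi_angles_mask = [
    [0.0, 0.0, 0.0, 0.0],  # type 0: no side-chain chis
    [1.0, 0.0, 0.0, 0.0],  # type 1: one chi
    [1.0, 1.0, 1.0, 1.0],  # type 2: four chis
]
unknown = len(chi_angles_mask)  # index that the appended UNKNOWN row will get
chi_angles_mask = np.asarray(chi_angles_mask + [[0.0, 0.0, 0.0, 0.0]])

aatype = np.array([[2, 1, unknown, 0]])  # (B=1, N=4)
chis_mask = chi_angles_mask[aatype]      # (B, N, chis=4)
```

The UNKNOWN residue simply gets no chi angles.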
chi_angle_atoms_mask = utils.batched_gather(
    params=all_atom_mask, indices=atom_indices, axis=-1, batch_dims=2)
Next, chis_mask is restricted to those chis for which the ground truth coordinates of all four defining atoms are available. This line gathers the mask of the chi angle atoms, with shape [batch, num_res, chis=4, atoms=4].
chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
chis_mask = chis_mask * (chi_angle_atoms_mask).astype(jnp.float32)
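This checks whether all four chi angle atoms are set by taking the product over the atom axis. It then removes from chis_mask any chi whose atoms are missing.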
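Here is a toy NumPy check of that combination. A chi survives only if the residue type defines it and all four of its atoms are present. The values are invented for illustration.

```python
import numpy as np

# Toy shapes: B=1, N=2, chis=4, atoms=4.
chis_mask = np.array([[[1., 1., 0., 0.],
                       [1., 0., 0., 0.]]])     # chis defined by the residue types
chi_angle_atoms_mask = np.ones((1, 2, 4, 4))    # per-atom availability
chi_angle_atoms_mask[0, 0, 1, 3] = 0.0          # 4th atom of chi_2 (residue 0) missing

# Defined AND all four atoms present.
chis_mask = chis_mask * np.prod(chi_angle_atoms_mask, axis=-1).astype(np.float32)
```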
torsions_atom_pos = jnp.concatenate(
    [pre_omega_atom_pos[:, :, None, :, :],
     phi_atom_pos[:, :, None, :, :],
     psi_atom_pos[:, :, None, :, :],
     chis_atom_pos
    ], axis=2)
All torsion angle atom positions are stacked here. Shape: (B, N, torsions=7, atoms=4, xyz=3).
torsion_angles_mask = jnp.concatenate(
    [pre_omega_mask[:, :, None],
     phi_mask[:, :, None],
     psi_mask[:, :, None],
     chis_mask
    ], axis=2)
The masks for all torsion angles are stacked here. Shape: (B, N, torsions=7).
torsion_frames = r3.rigids_from_3_points(
    point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]),
    origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
    point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]))
The above code snippet builds a frame from the first three atoms of each torsion. The first atom is a point on the x-y plane, the second lies on the negative x-axis, and the third is the origin.
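A frame with these properties can be built by Gram-Schmidt orthogonalization. The NumPy sketch below assumes that is what r3.rigids_from_3_points does internally. The helper name frame_from_3_points is hypothetical.

```python
import numpy as np

def frame_from_3_points(point_on_neg_x_axis, origin, point_on_xy_plane):
    """Hypothetical Gram-Schmidt frame; columns of rot are the unit axes e0, e1, e2."""
    e0 = origin - point_on_neg_x_axis
    e0 = e0 / np.linalg.norm(e0)          # x-axis points away from the 2nd atom
    e1 = point_on_xy_plane - origin
    e1 = e1 - np.dot(e1, e0) * e0         # remove the component along x
    e1 = e1 / np.linalg.norm(e1)          # y-axis points toward the 1st atom
    e2 = np.cross(e0, e1)                 # right-handed z-axis
    return np.stack([e0, e1, e2], axis=-1), origin

# Usage: express points in the local frame via rot.T @ (p - origin).
a0 = np.array([0.0, 1.0, 0.0])   # 1st atom: point on the x-y plane
a1 = np.array([0.0, 0.0, 0.0])   # 2nd atom: point on the negative x-axis
a2 = np.array([1.5, 0.0, 0.0])   # 3rd atom: origin
rot, t = frame_from_3_points(a1, a2, a0)
local_a1 = rot.T @ (a1 - t)      # lies on the negative x-axis
local_a0 = rot.T @ (a0 - t)      # lies in the x-y plane, on the +y side
```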
forth_atom_rel_pos = r3.rigids_mul_vecs(
    r3.invert_rigids(torsion_frames),
    r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))
The above code computes the position of the fourth atom in this frame by applying the inverse frame transform. Its y and z coordinates define the torsion angle.
torsion_angles_sin_cos = jnp.stack(
    [forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1)
torsion_angles_sin_cos /= jnp.sqrt(
    jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True)
    + 1e-8)
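Here the (z, y) pair is normalized to unit length, which yields the sine and cosine of the torsion angle. The 1e-8 term guards against division by zero. Shape: (B, N, torsions=7, sincos=2).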
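As a sanity check, the frame-based (sin, cos) can be compared with the textbook atan2 dihedral formula on random points. Both helper names below are hypothetical. The frame construction again assumes Gram-Schmidt, as in r3.rigids_from_3_points, and NumPy replaces the r3 vector types.

```python
import numpy as np

def torsion_sin_cos(a0, a1, a2, a3, eps=1e-8):
    """(sin, cos) of the torsion a0-a1-a2-a3 via a local frame (hypothetical helper)."""
    e0 = (a2 - a1) / np.linalg.norm(a2 - a1)
    e1 = (a0 - a2) - np.dot(a0 - a2, e0) * e0
    e1 = e1 / np.linalg.norm(e1)
    e2 = np.cross(e0, e1)
    v = a3 - a2                            # 4th atom relative to the origin (3rd atom)
    y, z = np.dot(v, e1), np.dot(v, e2)    # its y and z coordinates in the frame
    sc = np.array([z, y])                  # (sin, cos) before normalization
    return sc / np.sqrt(np.sum(sc ** 2) + eps)

def dihedral_atan2(a0, a1, a2, a3):
    """Textbook dihedral angle, for comparison."""
    b1, b2, b3 = a1 - a0, a2 - a1, a3 - a2
    n1, n2 = np.cross(b1, b2), np.cross(b2, b3)
    return np.arctan2(np.linalg.norm(b2) * np.dot(b1, n2), np.dot(n1, n2))

rng = np.random.default_rng(1)
pts = rng.normal(size=(4, 3))
sin_cos = torsion_sin_cos(*pts)
angle = dihedral_atan2(*pts)
```

Taking (z, y) rather than (y, z) puts the sine first, which matches the (sin, cos) order of the last axis.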
torsion_angles_sin_cos *= jnp.asarray(
    [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]
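This step multiplies both the sine and the cosine of psi (index 2) by -1, which rotates that angle by π. The AlphaFold source comments it as mirroring psi "because we computed it from the Oxygen-atom". The standard psi is defined with the next residue's N. In the planar peptide group, the carbonyl O lies on the opposite side of the CA-C bond from that N, so the dihedral N-CA-C-O differs from psi by π.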
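The π offset between the O-based and the N-based psi can be checked numerically. This NumPy sketch uses an idealized toy backbone in which the carbonyl carbon is trigonal planar. The coordinates are illustrative, not real bond geometry.

```python
import numpy as np

def dihedral(a0, a1, a2, a3):
    """Textbook dihedral angle a0-a1-a2-a3 in radians."""
    b1, b2, b3 = a1 - a0, a2 - a1, a3 - a2
    n1, n2 = np.cross(b1, b2), np.cross(b2, b3)
    return np.arctan2(np.linalg.norm(b2) * np.dot(b1, n2), np.dot(n1, n2))

x_hat = np.array([1.0, 0.0, 0.0])
n_atom = np.array([-0.5, 1.3, 0.4])
ca = np.array([0.0, 0.0, 0.0])
c = 1.5 * x_hat
# O and the next residue's N lie in one plane with CA and C,
# on opposite sides of the CA-C axis (120 degrees apart around C).
u = np.array([0.0, np.cos(0.7), np.sin(0.7)])   # in-plane, perpendicular to CA-C
o = c + 1.23 * (0.5 * x_hat + 0.866 * u)
n_next = c + 1.33 * (0.5 * x_hat - 0.866 * u)

psi_from_o = dihedral(n_atom, ca, c, o)
psi_true = dihedral(n_atom, ca, c, n_next)

# Negating both sin and cos is a rotation by pi.
flipped = np.array([-np.sin(psi_from_o), -np.cos(psi_from_o)])
```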
chi_is_ambiguous = utils.batched_gather(
    jnp.asarray(residue_constants.chi_pi_periodic), aatype)
mirror_torsion_angles = jnp.concatenate(
    [jnp.ones([num_batch, num_res, 3]),
     1.0 - 2.0 * chi_is_ambiguous], axis=-1)
alt_torsion_angles_sin_cos = (
    torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])
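The above code creates alternative torsion angles for side chains whose atom names are ambiguous, for example where two chemically equivalent atoms such as the carboxylate oxygens of Asp and Glu could swap names. For every chi angle flagged in chi_pi_periodic, the sine and cosine are multiplied by -1, i.e. the angle is shifted by π. The three backbone angles are left unchanged.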
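Here is a NumPy sketch of the mirror vector on toy data. The two-row chi_pi_periodic table is hypothetical, not the real residue_constants values. Every angle is set to 0.3 rad so the π shift is easy to see.

```python
import numpy as np

# Hypothetical flags standing in for residue_constants.chi_pi_periodic:
# 1.0 marks a chi angle that is only defined up to a rotation by pi.
chi_pi_periodic = np.array([
    [0.0, 0.0, 0.0, 0.0],   # type 0: no symmetric side chain
    [0.0, 1.0, 0.0, 0.0],   # type 1: chi_2 is pi-periodic
])
aatype = np.array([[0, 1]])                     # (B=1, N=2)
num_batch, num_res = aatype.shape

chi_is_ambiguous = chi_pi_periodic[aatype]      # (B, N, 4)
mirror_torsion_angles = np.concatenate(
    [np.ones([num_batch, num_res, 3]),          # backbone angles untouched
     1.0 - 2.0 * chi_is_ambiguous], axis=-1)     # 1 -> -1 for ambiguous chis

torsion_angles_sin_cos = np.zeros((num_batch, num_res, 7, 2))
torsion_angles_sin_cos[..., 0] = np.sin(0.3)
torsion_angles_sin_cos[..., 1] = np.cos(0.3)
alt = torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None]
```

For residue 1, chi_2 sits at torsion index 4, and its alternative angle becomes 0.3 - π.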
if placeholder_for_undefined:
  placeholder_torsions = jnp.stack([
      jnp.ones(torsion_angles_sin_cos.shape[:-1]),
      jnp.zeros(torsion_angles_sin_cos.shape[:-1])
  ], axis=-1)
  torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[
      ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
  alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[
      ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
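When placeholder_for_undefined is set, undefined torsion angles are replaced by a placeholder whose last-axis entries are (1, 0). These are the torsions with mask 0, such as pre_omega and phi of the first residue. The replacement is applied to both the regular and the alternative outputs.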
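The replacement is a mask-weighted blend, shown here in NumPy on a toy array with only two torsions. That size is chosen for brevity; the real array has seven.

```python
import numpy as np

# Toy data: B=1, N=1, 2 torsions, (sin, cos); the 2nd torsion is undefined.
torsion_angles_sin_cos = np.array([[[[0.6, 0.8], [0.0, 0.0]]]])
torsion_angles_mask = np.array([[[1.0, 0.0]]])

placeholder_torsions = np.stack([
    np.ones(torsion_angles_sin_cos.shape[:-1]),
    np.zeros(torsion_angles_sin_cos.shape[:-1])], axis=-1)

# Keep defined torsions, swap undefined ones for the placeholder.
m = torsion_angles_mask[..., None]
result = torsion_angles_sin_cos * m + placeholder_torsions * (1 - m)
```

Because the mask is 0 or 1, the blend selects one of the two values exactly. Writing it as arithmetic rather than a branch keeps it vectorized over all residues and torsions.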