import copy
import os.path
import pickle
import random
from multiprocessing import Pool

import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, MolFromSmiles
from scipy.spatial.distance import pdist, squareform
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.utils import subgraph
from tqdm import tqdm

from datasets.constants import aa_to_cg_indices, amino_acid_smiles, cg_rdkit_indices
from datasets.parse_chi import aa_long2short, atom_order
from datasets.process_mols import new_extract_receptor_structure, get_lig_graph, generate_conformer
from utils.torsion import get_transformation_mask


def read_strings_from_txt(path):
    with open(path) as file:
        lines = file.readlines()
    return [line.rstrip() for line in lines]


def compute_num_ca_neighbors(coords, cg_coords, idx, is_valid_bb_node, max_dist=5, buffer_residue_num=7):
    """
    Counts the residues with heavy atoms within max_dist (Angstroms) of this sidechain,
    excluding residues within +/- buffer_residue_num of it in primary sequence.
    From Ben's code. Note: Gabriele removed the chain_index.
    """
    bb_coords = coords

    # Residues too close in primary sequence are never counted as neighbors.
    excluded_neighbors = [idx - x for x in reversed(range(0, buffer_residue_num + 1)) if (idx - x) >= 0]
    excluded_neighbors.extend([idx + x for x in range(1, buffer_residue_num + 1)])

    # All (residue, chemical-group atom) index pairs.
    e_idx = torch.stack([
        torch.arange(bb_coords.shape[0]).unsqueeze(-1).expand((-1, cg_coords.shape[0])).flatten(),
        torch.arange(cg_coords.shape[0]).unsqueeze(0).expand((bb_coords.shape[0], -1)).flatten()
    ])

    bb_coords_exp = bb_coords[e_idx[0]]
    cg_coords_exp = cg_coords[e_idx[1]].unsqueeze(1)

    # Residues with any heavy atom within max_dist of any chemical-group atom.
    bb_exp_idces, _ = (torch.cdist(bb_coords_exp, cg_coords_exp).squeeze(-1) < max_dist).nonzero(as_tuple=True)
    bb_idces_within_thresh = torch.unique(e_idx[0][bb_exp_idces])

    # Drop sequence neighbors and residues with incomplete backbones.
    bb_idces_within_thresh = bb_idces_within_thresh[
        ~torch.isin(bb_idces_within_thresh, torch.tensor(excluded_neighbors))
        & is_valid_bb_node[bb_idces_within_thresh]]

    return len(bb_idces_within_thresh)
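

# A minimal sketch (illustrative only, not part of the training pipeline) of how
# compute_num_ca_neighbors behaves on toy data: five residues spaced 3 A apart
# along x, with the chemical group taken from residue 2. The atom14 shapes
# (num_residues, 14, 3) mirror the convention used above.
def _demo_compute_num_ca_neighbors():
    coords = torch.zeros(5, 14, 3)
    coords[:, :, 0] = 3.0 * torch.arange(5).float().view(-1, 1)  # residue i sits at x = 3 * i
    cg_coords = coords[2, :3]  # pretend the first three atoms of residue 2 form the chemical group
    is_valid_bb_node = torch.ones(5, dtype=torch.bool)
    # Residues 1, 2 and 3 are within 5 A but also within the +/-1 sequence buffer,
    # so the expected neighbor count is 0.
    return compute_num_ca_neighbors(coords, cg_coords, idx=2, is_valid_bb_node=is_valid_bb_node,
                                    max_dist=5, buffer_residue_num=1)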


def identify_valid_vandermers(args):
    """
    Builds a tensor with the number of contacts for each residue; dividing by its sum
    turns it into sampling probabilities for chemical groups. Since every sidechain is
    used as a chemical group, the actual chemical groups are loaded at training time.
    """
    complex_graph, max_dist, buffer_residue_num = args

    coords, seq = complex_graph.coords, complex_graph.seq
    # A residue is a valid backbone node if its four backbone atoms are all resolved.
    is_valid_bb_node = (coords[:, :4].isnan().sum(dim=(1, 2)) == 0).bool()

    valid_cg_idces = []
    for idx, aa in enumerate(seq):
        if aa not in aa_to_cg_indices:
            valid_cg_idces.append(0)
        else:
            indices = aa_to_cg_indices[aa]
            cg_coordinates = coords[idx][indices]

            if torch.any(cg_coordinates.isnan()).item():
                valid_cg_idces.append(0)
                continue

            nbr_count = compute_num_ca_neighbors(coords, cg_coordinates, idx, is_valid_bb_node,
                                                 max_dist=max_dist, buffer_residue_num=buffer_residue_num)
            valid_cg_idces.append(nbr_count)

    return complex_graph.name, torch.tensor(valid_cg_idces)
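

# Illustrative sketch (an assumption, not used by the pipeline): the contact counts
# returned above become sampling probabilities once normalized, mirroring what
# PDBSidechain.get does further below.
def _demo_sample_residue_from_counts(counts):
    probs = counts.float().numpy()
    return int(np.random.choice(len(probs), p=probs / probs.sum()))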


def fast_identify_valid_vandermers(coords, seq, max_dist=5, buffer_residue_num=7):
    # Sentinel distance large enough that masked entries can never fall below max_dist.
    offset = 10000 + max_dist
    R = coords.shape[0]

    # Minimum heavy-atom distance between every pair of residues (atom14 layout).
    coords = coords.numpy().reshape(-1, 3)
    pdist_mat = squareform(pdist(coords))
    pdist_mat = pdist_mat.reshape((R, 14, R, 14))
    pdist_mat = np.nan_to_num(pdist_mat, nan=offset)
    pdist_mat = np.min(pdist_mat, axis=(1, 3))

    # Mask out each residue itself and its +/- buffer_residue_num sequence neighbors.
    pdist_mat = pdist_mat + np.diag(np.ones(len(seq)) * offset)
    for i in range(1, buffer_residue_num + 1):
        pdist_mat += np.diag(np.ones(len(seq) - i) * offset, k=i) + np.diag(np.ones(len(seq) - i) * offset, k=-i)

    nbr_count = np.sum(pdist_mat < max_dist, axis=1)
    return torch.tensor(nbr_count)
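

# Minimal sketch (illustrative only): running the vectorized contact counter on
# random, fully resolved atom14 coordinates. Only the length of seq matters here,
# so a dummy poly-alanine string is passed.
def _demo_fast_identify_valid_vandermers():
    torch.manual_seed(0)
    coords = torch.randn(8, 14, 3) * 4.0  # 8 residues, 14 atom slots each
    return fast_identify_valid_vandermers(coords, seq='A' * 8, max_dist=5, buffer_residue_num=2)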


def compute_cg_features(aa, aa_smile):
    """
    Given an amino acid and its SMILES string, returns a graph of the chemical group
    with stacked atom encodings. The order of the rows of the node feature tensor
    corresponds to the index at which the atoms appear in aa_to_cg_indices from constants.
    """
    aa_short = aa_long2short[aa]
    if aa_short not in aa_to_cg_indices:
        return None

    mol = Chem.MolFromSmiles(aa_smile)

    complex_graph = HeteroData()
    get_lig_graph(mol, complex_graph)

    # Keep only the chemical-group atoms and the subgraph they induce.
    atoms_to_keep = torch.tensor(list(cg_rdkit_indices[aa].keys())).long()
    complex_graph['ligand', 'ligand'].edge_index, complex_graph['ligand', 'ligand'].edge_attr = \
        subgraph(atoms_to_keep, complex_graph['ligand', 'ligand'].edge_index,
                 complex_graph['ligand', 'ligand'].edge_attr, relabel_nodes=True)
    complex_graph['ligand'].x = complex_graph['ligand'].x[atoms_to_keep]

    edge_mask, mask_rotate = get_transformation_mask(complex_graph)
    complex_graph['ligand'].edge_mask = torch.tensor(edge_mask)
    complex_graph['ligand'].mask_rotate = mask_rotate
    return complex_graph
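

# Illustrative usage sketch, mirroring how PDBSidechain.__init__ builds its
# cg_node_feature_lookup_dict below; the keys of amino_acid_smiles come from
# datasets.constants as imported above.
def _demo_compute_cg_features():
    aa, aa_smile = next(iter(amino_acid_smiles.items()))
    cg_graph = compute_cg_features(aa, aa_smile)
    # None is returned for amino acids without a chemical-group definition.
    return None if cg_graph is None else cg_graph['ligand'].x.shape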


class PDBSidechain(Dataset):
    def __init__(self, root, transform=None, cache_path='data/cache', split='train', limit_complexes=0,
                 receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, remove_hs=True, all_atoms=False,
                 atom_radius=5, atom_max_neighbors=None, sequences_to_embeddings=None,
                 knn_only_graph=True, multiplicity=1, vandermers_max_dist=5, vandermers_buffer_residue_num=7,
                 vandermers_min_contacts=5, remove_second_segment=False, merge_clusters=1, vandermers_extraction=True,
                 add_random_ligand=False):

        super(PDBSidechain, self).__init__(root, transform)
        assert remove_hs, "not implemented yet"
        self.root = root
        self.split = split
        self.limit_complexes = limit_complexes
        self.receptor_radius = receptor_radius
        self.knn_only_graph = knn_only_graph
        self.multiplicity = multiplicity
        self.c_alpha_max_neighbors = c_alpha_max_neighbors
        self.num_workers = num_workers
        self.sequences_to_embeddings = sequences_to_embeddings
        self.remove_second_segment = remove_second_segment
        self.merge_clusters = merge_clusters
        self.vandermers_extraction = vandermers_extraction
        self.add_random_ligand = add_random_ligand
        self.all_atoms = all_atoms
        self.atom_radius = atom_radius
        self.atom_max_neighbors = atom_max_neighbors

        if vandermers_extraction:
            self.cg_node_feature_lookup_dict = {aa_long2short[aa]: compute_cg_features(aa, aa_smile)
                                                for aa, aa_smile in amino_acid_smiles.items()}

        self.cache_path = os.path.join(cache_path, f'PDB3_limit{self.limit_complexes}_INDEX{self.split}'
                                                   f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
                                       + ('' if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
                                       + ('' if not self.knn_only_graph else '_knnOnly'))
        self.read_split()

        if not self.check_all_proteins():
            os.makedirs(self.cache_path, exist_ok=True)
            self.preprocess()

        self.vandermers_max_dist = vandermers_max_dist
        self.vandermers_buffer_residue_num = vandermers_buffer_residue_num
        self.vandermers_min_contacts = vandermers_min_contacts
        self.collect_proteins()

        filtered_proteins = []
        if vandermers_extraction:
            # Keep only proteins with at least one residue that has >= 10 contacts.
            for complex_graph in tqdm(self.protein_graphs):
                if complex_graph.name in self.vandermers and torch.any(self.vandermers[complex_graph.name] >= 10):
                    filtered_proteins.append(complex_graph)
            print(f"Computed vandermers and kept {len(filtered_proteins)} proteins out of {len(self.protein_graphs)}")
        else:
            filtered_proteins = self.protein_graphs

        second_filter = []
        for complex_graph in tqdm(filtered_proteins):
            if sequences_to_embeddings is None or complex_graph.orig_seq in sequences_to_embeddings:
                second_filter.append(complex_graph)
        print(f"Checked embeddings available and kept {len(second_filter)} proteins out of {len(filtered_proteins)}")

        self.protein_graphs = second_filter

        # Group proteins by sequence cluster; sampling at training time is per cluster.
        self.split_clusters = list(set([g.cluster for g in self.protein_graphs]))
        self.cluster_to_complexes = {c: [] for c in self.split_clusters}
        for p in self.protein_graphs:
            self.cluster_to_complexes[p['cluster']].append(p)
        self.split_clusters = [c for c in self.split_clusters if len(self.cluster_to_complexes[c]) > 0]
        print("Total elements in set", len(self.split_clusters) * self.multiplicity // self.merge_clusters)

        self.name_to_complex = {p.name: p for p in self.protein_graphs}
        self.define_probabilities()

        if self.add_random_ligand:
            with open('data/smiles_list.csv', 'r') as f:
                self.smiles_list = f.readlines()
            self.smiles_list = [s.split(',')[0] for s in self.smiles_list]

    def define_probabilities(self):
        if not self.vandermers_extraction:
            return

        if self.vandermers_min_contacts is not None:
            # Linear weighting above the minimum number of contacts.
            self.probabilities = torch.arange(1000) - self.vandermers_min_contacts + 1
            self.probabilities[:self.vandermers_min_contacts] = 0
        else:
            # Reweight so the sampled contact counts match the PDBBind distribution.
            with open('data/pdbbind_counts.pkl', 'rb') as f:
                pdbbind_counts = pickle.load(f)

            pdb_counts = torch.ones(1000)
            for contacts in self.vandermers.values():
                pdb_counts.index_add_(0, contacts, torch.ones(contacts.shape))
            print(pdbbind_counts[:30])
            print(pdb_counts[:30])

            self.probabilities = pdbbind_counts / pdb_counts
            self.probabilities[:7] = 0

    def len(self):
        return len(self.split_clusters) * self.multiplicity // self.merge_clusters

    def get(self, idx=None, protein=None, smiles=None):
        assert idx is not None or (protein is not None and smiles is not None), \
            "provide either idx or both protein and smiles"

        if protein is None or smiles is None:
            idx = idx % len(self.split_clusters)
            if self.merge_clusters > 1:
                idx = idx * self.merge_clusters
                idx = idx + random.randint(0, self.merge_clusters - 1)
                idx = min(idx, len(self.split_clusters) - 1)
            cluster = self.split_clusters[idx]
            protein_graph = copy.deepcopy(random.choice(self.cluster_to_complexes[cluster]))
        else:
            protein_graph = copy.deepcopy(self.name_to_complex[protein])

        if self.sequences_to_embeddings is not None:
            if len(protein_graph.orig_seq) != len(self.sequences_to_embeddings[protein_graph.orig_seq]):
                print('problem with ESM embeddings')
                return self.get(random.randint(0, self.len()))

            lm_embeddings = self.sequences_to_embeddings[protein_graph.orig_seq][protein_graph.to_keep]
            protein_graph['receptor'].x = torch.cat([protein_graph['receptor'].x, lm_embeddings], dim=1)

        if self.vandermers_extraction:
            vandermers_contacts = self.vandermers[protein_graph.name]
            vandermers_probs = self.probabilities[vandermers_contacts].numpy()

            if not np.any(vandermers_contacts.numpy() >= 10):
                print('no vandermers with >= 10 contacts, retrying with a new one')
                return self.get(random.randint(0, self.len()))

            # Sample the sidechain to extract and pose as the ligand.
            sidechain_idx = np.random.choice(np.arange(len(vandermers_probs)), p=vandermers_probs / np.sum(vandermers_probs))

            # Remove the sampled residue and its sequence neighbors from the receptor.
            residues_to_keep = np.ones(len(protein_graph.seq), dtype=bool)
            residues_to_keep[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
                             min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False

            if self.remove_second_segment:
                # Sample a second sidechain far enough from the first and remove its segment too.
                pos_idx = protein_graph['receptor'].pos[sidechain_idx]
                limit_closeness = 10
                far_enough = torch.sum((protein_graph['receptor'].pos - pos_idx[None, :]) ** 2, dim=-1) > limit_closeness ** 2
                vandermers_probs = vandermers_probs * far_enough.float().numpy()
                vandermers_probs[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
                                 min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = 0
                if np.all(vandermers_probs <= 0):
                    print('no second vandermer available, retrying with a new one')
                    return self.get(random.randint(0, self.len()))
                sc2_idx = np.random.choice(np.arange(len(vandermers_probs)), p=vandermers_probs / np.sum(vandermers_probs))

                residues_to_keep[max(0, sc2_idx - self.vandermers_buffer_residue_num):
                                 min(sc2_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False

            residues_to_keep = torch.from_numpy(residues_to_keep)
            protein_graph['receptor'].pos = protein_graph['receptor'].pos[residues_to_keep]
            protein_graph['receptor'].x = protein_graph['receptor'].x[residues_to_keep]
            protein_graph['receptor'].side_chain_vecs = protein_graph['receptor'].side_chain_vecs[residues_to_keep]
            protein_graph['receptor', 'rec_contact', 'receptor'].edge_index = \
                subgraph(residues_to_keep, protein_graph['receptor', 'rec_contact', 'receptor'].edge_index, relabel_nodes=True)[0]

            # Attach the extracted sidechain as the ligand graph.
            sidechain_aa = protein_graph.seq[sidechain_idx]
            ligand_graph = self.cg_node_feature_lookup_dict[sidechain_aa]
            ligand_graph['ligand'].pos = protein_graph.coords[sidechain_idx][protein_graph.mask[sidechain_idx]]

            for node_or_edge_type in ligand_graph.node_types + ligand_graph.edge_types:
                for key, value in ligand_graph[node_or_edge_type].items():
                    protein_graph[node_or_edge_type][key] = value

            # Center the complex on the receptor.
            protein_graph['ligand'].orig_pos = protein_graph['ligand'].pos.numpy()
            protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
            protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
            protein_graph['ligand'].pos = protein_graph['ligand'].pos - protein_center
            protein_graph.original_center = protein_center
            protein_graph['receptor_name'] = protein_graph.name
        else:
            protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
            protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
            protein_graph.original_center = protein_center
            protein_graph['receptor_name'] = protein_graph.name

        if self.add_random_ligand:
            if smiles is not None:
                mol = MolFromSmiles(smiles)
                try:
                    generate_conformer(mol)
                except Exception as e:
                    print("failed to generate the given ligand, returning None", e)
                    return None
            else:
                # Keep sampling SMILES until a conformer can be generated.
                success = False
                while not success:
                    smiles = random.choice(self.smiles_list)
                    mol = MolFromSmiles(smiles)
                    try:
                        success = not generate_conformer(mol)
                    except Exception as e:
                        print(e, "changing ligand")

            lig_graph = HeteroData()
            get_lig_graph(mol, lig_graph)

            edge_mask, mask_rotate = get_transformation_mask(lig_graph)
            lig_graph['ligand'].edge_mask = torch.tensor(edge_mask)
            lig_graph['ligand'].mask_rotate = mask_rotate
            lig_graph['ligand'].smiles = smiles
            lig_graph['ligand'].pos = lig_graph['ligand'].pos - torch.mean(lig_graph['ligand'].pos, dim=0, keepdim=True)

            for node_or_edge_type in lig_graph.node_types + lig_graph.edge_types:
                for key, value in lig_graph[node_or_edge_type].items():
                    protein_graph[node_or_edge_type][key] = value

        # Drop bookkeeping attributes that are no longer needed downstream.
        for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq', 'to_keep', 'chain_ids']:
            if hasattr(protein_graph, a):
                delattr(protein_graph, a)
            if hasattr(protein_graph['receptor'], a):
                delattr(protein_graph['receptor'], a)

        return protein_graph
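
    # Illustrative usage sketch (hypothetical chain name; assumes the dataset was
    # built with add_random_ligand=True):
    #     graph = dataset.get(protein='1abc_A', smiles='CCO')
    # Plain indexing (dataset[i]) instead samples a cluster, a protein from it and
    # a sidechain vandermer to pose as the ligand.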

    def read_split(self):
        # list.csv contains one row per chain with its sequence cluster assignment.
        df = pd.read_csv(self.root + "/list.csv")
        print("Loaded list CSV file")

        if self.split == "train":
            val_clusters = set(read_strings_from_txt(self.root + "/valid_clusters.txt"))
            test_clusters = set(read_strings_from_txt(self.root + "/test_clusters.txt"))
            clusters = df["CLUSTER"].unique()
            clusters = [int(c) for c in clusters if c not in val_clusters and c not in test_clusters]
        elif self.split == "val":
            clusters = [int(s) for s in read_strings_from_txt(self.root + "/valid_clusters.txt")]
        elif self.split == "test":
            clusters = [int(s) for s in read_strings_from_txt(self.root + "/test_clusters.txt")]
        else:
            raise ValueError("Split must be train, val or test")
        print(self.split, "clusters", len(clusters))
        clusters = set(clusters)

        # Keep one chain per PDB complex (first four characters of the chain ID).
        self.chains_in_cluster = []
        complexes_in_cluster = set()
        for chain, cluster in zip(df["CHAINID"], df["CLUSTER"]):
            if cluster not in clusters:
                continue
            if chain[:4] not in complexes_in_cluster:
                self.chains_in_cluster.append((chain, cluster))
                complexes_in_cluster.add(chain[:4])
        print("Filtered chains in cluster", len(self.chains_in_cluster))

        if self.limit_complexes > 0:
            self.chains_in_cluster = self.chains_in_cluster[:self.limit_complexes]

    def check_all_proteins(self):
        # Protein graphs are cached in batches of 10000 chains per pickle file.
        for i in range(len(self.chains_in_cluster) // 10000 + 1):
            if not os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
                return False
        return True

    def collect_proteins(self):
        self.protein_graphs = []
        self.vandermers = {}
        total_recovered = 0
        print(f'Loading {len(self.chains_in_cluster)} protein graphs.')
        list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
        random.shuffle(list_indices)
        for i in list_indices:
            with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'rb') as f:
                print(i)
                protein_batch = pickle.load(f)
                total_recovered += len(protein_batch)
                self.protein_graphs.extend(protein_batch)

            if not self.vandermers_extraction:
                continue

            # Reuse cached vandermer contact counts when available.
            if os.path.exists(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl')):
                with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'rb') as f:
                    vandermers = pickle.load(f)
                self.vandermers.update(vandermers)
                continue

            vandermers = {}
            if self.num_workers > 1:
                p = Pool(self.num_workers, maxtasksperchild=1)
                p.__enter__()
            with tqdm(total=len(protein_batch), desc=f'computing vandermers {i}') as pbar:
                map_fn = p.imap_unordered if self.num_workers > 1 else map
                arguments = zip(protein_batch, [self.vandermers_max_dist] * len(protein_batch),
                                [self.vandermers_buffer_residue_num] * len(protein_batch))
                for t in map_fn(identify_valid_vandermers, arguments):
                    if t is not None:
                        vandermers[t[0]] = t[1]
                    pbar.update()
            if self.num_workers > 1: p.__exit__(None, None, None)

            with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'wb') as f:
                pickle.dump(vandermers, f)
            self.vandermers.update(vandermers)

        print(f"Kept {len(self.protein_graphs)} proteins out of {len(self.chains_in_cluster)} total")
        return

    def preprocess(self):
        # Process chains in batches of 10000 and cache each batch to disk.
        list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
        random.shuffle(list_indices)
        for i in list_indices:
            if os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
                continue
            chains_names = self.chains_in_cluster[10000 * i:10000 * (i + 1)]
            protein_graphs = []
            if self.num_workers > 1:
                p = Pool(self.num_workers, maxtasksperchild=1)
                p.__enter__()
            with tqdm(total=len(chains_names),
                      desc=f'loading protein batch {i}/{len(self.chains_in_cluster) // 10000 + 1}') as pbar:
                map_fn = p.imap_unordered if self.num_workers > 1 else map
                for t in map_fn(self.load_chain, chains_names):
                    if t is not None:
                        protein_graphs.append(t)
                    pbar.update()
            if self.num_workers > 1: p.__exit__(None, None, None)

            with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'wb') as f:
                pickle.dump(protein_graphs, f)

        print("Finished preprocessing and saving protein graphs")

    def load_chain(self, c):
        chain, cluster = c
        if not os.path.exists(self.root + f"/pdb/{chain[1:3]}/{chain}.pt"):
            print("File not found", chain)
            return None

        data = torch.load(self.root + f"/pdb/{chain[1:3]}/{chain}.pt")
        complex_graph = HeteroData()
        complex_graph['name'] = chain
        orig_seq = data["seq"]
        coords = data["xyz"]
        mask = data["mask"].bool()

        # Drop residues whose backbone (first four atoms) is not fully resolved.
        to_keep = torch.logical_not(torch.any(torch.isnan(coords[:, :4, 0]), dim=1))
        coords = coords[to_keep]
        seq = ''.join(np.asarray(list(orig_seq))[to_keep.numpy()].tolist())
        mask = mask[to_keep]

        if len(coords) == 0:
            print("All coords were NaN", chain)
            return None

        try:
            new_extract_receptor_structure(seq, coords.numpy(), complex_graph=complex_graph, neighbor_cutoff=self.receptor_radius,
                                           max_neighbors=self.c_alpha_max_neighbors, knn_only_graph=self.knn_only_graph,
                                           all_atoms=self.all_atoms, atom_cutoff=self.atom_radius,
                                           atom_max_neighbors=self.atom_max_neighbors)
        except Exception as e:
            print("Error in extracting receptor", chain)
            print(e)
            return None

        if torch.any(torch.isnan(complex_graph['receptor'].pos)):
            print("NaN in pos receptor", chain)
            return None

        complex_graph.coords = coords
        complex_graph.seq = seq
        complex_graph.mask = mask
        complex_graph.cluster = cluster
        complex_graph.orig_seq = orig_seq
        complex_graph.to_keep = to_keep
        return complex_graph


if __name__ == "__main__":
    dataset = PDBSidechain(root="data/pdb_2021aug02_sample", split="train", multiplicity=1, limit_complexes=150)
    print(len(dataset))
    print(dataset[0])
    for p in dataset:
        print(p)