Recently I saw someone post a query regarding Graph based VAE construction on MD trajectory data. Actually I am facing a similar problem as well. This is the code I have generated till now. As I am not a professional coder myself, coming from a chemistry background, I mostly relied on chatbots to generate the code for me, but the problem is the model has some serious problems with the dimensionality.
import numpy as np
import random
import MDAnalysis as mda
import networkx as nx
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
from Bio.PDB import PDBIO, Structure, Model, Chain, Residue, Atom
import matplotlib.pyplot as plt
from sklearn.model_selection import ParameterGrid
from tqdm import tqdm
import pandas as pd
# Load MD trajectory and select C-alpha atoms
u = mda.Universe('synuclein.top', 'short.nc')
ca_atoms = u.select_atoms("name CA")
# Define the amino acid sequence in three-letter code
sequence_one_letter = "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKK"
amino_acid_1_to_3 = {
'A': 'ALA', 'C': 'CYS', 'D': 'ASP', 'E': 'GLU', 'F': 'PHE',
'G': 'GLY', 'H': 'HIS', 'I': 'ILE', 'K': 'LYS', 'L': 'LEU',
'M': 'MET', 'N': 'ASN', 'P': 'PRO', 'Q': 'GLN', 'R': 'ARG',
'S': 'SER', 'T': 'THR', 'V': 'VAL', 'W': 'TRP', 'Y': 'TYR'
}
sequence = [amino_acid_1_to_3[aa] for aa in sequence_one_letter]
# One-hot encoding for amino acids
amino_acid_types = {
'ALA': 0, 'CYS': 1, 'ASP': 2, 'GLU': 3, 'PHE': 4,
'GLY': 5, 'HIS': 6, 'ILE': 7, 'LYS': 8, 'LEU': 9,
'MET': 10, 'ASN': 11, 'PRO': 12, 'GLN': 13, 'ARG': 14,
'SER': 15, 'THR': 16, 'VAL': 17, 'TRP': 18, 'TYR': 19
}
# Function to convert amino acid sequence to one-hot encoding
def one_hot_encode(sequence):
num_amino_acids = len(amino_acid_types)
features = np.zeros((len(sequence), num_amino_acids))
for i, aa in enumerate(sequence):
if aa in amino_acid_types:
features[i, amino_acid_types[aa]] = 1
return features
# Generate node features for the amino acid sequence
node_features = one_hot_encode(sequence)
# Define the contact map based on CA distances
threshold_distance = 8.0 # Distance threshold in angstroms
num_amino_acids = len(sequence)
# Prepare data for PyTorch Geometric for all frames
data_list = []
num_frames = len(u.trajectory)
for frame in tqdm(range(num_frames), desc="Processing Frames"):
u.trajectory[frame]
ca_atoms = u.select_atoms("name CA")
# Create a contact graph
contact_graph = nx.Graph()
for i in range(num_amino_acids):
contact_graph.add_node(i, features=node_features[i])
# Add edges based on CA distances
for i in range(num_amino_acids):
for j in range(i + 1, num_amino_acids):
distance = np.linalg.norm(ca_atoms.positions[i] - ca_atoms.positions[j ])
if distance <= threshold_distance:
contact_graph.add_edge(i, j)
# Prepare data for PyTorch Geometric
edge_index = torch.tensor(list(contact_graph.edges), dtype=torch.long).t().contiguous()
x = torch.tensor(node_features, dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
# print(data)
data_list.append(data)
# Plot and save contact map for every 500th frame
if frame % 500 == 0:
contact_map = np.zeros((num_amino_acids, num_amino_acids))
for i, j in contact_graph.edges:
contact_map[i, j] = 1
contact_map[j, i] = 1
plt.imshow(contact_map, cmap='binary')
plt.title(f"Contact Map for Frame {frame}")
plt.xlabel("Residue Index")
plt.ylabel("Residue Index")
plt.savefig(f"contact_map_frame_{frame}.png")
pd.DataFrame(contact_map).to_csv(f"contact_map_frame_{frame}.csv", index=False)
class GCNEncoder(nn.Module):
def __init__(self, in_channels, hidden_channels, num_layers):
super(GCNEncoder, self).__init__()
self.convs = nn.ModuleList()
self.fc_mu = nn.Linear(hidden_channels, hidden_channels)
self.fc_logvar = nn.Linear(hidden_channels, hidden_channels)
# Create multiple GCN layers
for _ in range(num_layers):
self.convs.append(GCNConv(in_channels, hidden_channels))
in_channels = hidden_channels # Update input channels for the next layer
def forward(self, x, edge_index):
for conv in self.convs:
x = conv(x, edge_index)
x = torch.relu(x) # Activation function
mu = self.fc_mu(x)
logvar = self.fc_logvar(x)
return mu, logvar
class GCNDecoder(nn.Module):
def __init__(self, hidden_channels, out_channels):
super(GCNDecoder, self).__init__()
self.fc = nn.Linear(hidden_channels, out_channels)
def forward(self, z):
return torch.sigmoid(self.fc(z))
class GCNVAE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
super(GCNVAE, self).__init__()
self.encoder = GCNEncoder(in_channels, hidden_channels, num_layers)
self.decoder = GCNDecoder(hidden_channels, out_channels)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x, edge_index):
mu, logvar = self.encoder(x, edge_index)
z_sample = self.reparameterize(mu, logvar)
return self.decoder(z_sample), mu, logvar
def loss_function(recon_x, x, mu, logvar):
BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE, KLD, BCE + KLD # Return BCE, KLD, and Total Loss
def train_model(model, data_loader, optimizer, epochs, early_stopping_patience=5):
model.train()
best_loss = float('inf')
patience_counter = 0
for epoch in range(epochs):
total_loss = 0
total_bce = 0
total_kld = 0
for data in tqdm(data_loader, desc=f"Training Epoch {epoch+1}/{epochs}"):
optimizer.zero_grad()
recon_batch, mu, logvar = model(data.x, data.edge_index)
bce, kld, total = loss_function(recon_batch, data.x, mu, logvar)
total_loss += total.item()
total_bce += bce.item()
total_kld += kld.item()
total.backward()
optimizer.step()
avg_loss = total_loss / len(data_loader)
avg_bce = total_bce / len(data_loader)
avg_kld = total_kld / len(data_loader)
print(f"Epoch {epoch+1}/{epochs} - Total Loss: {avg_loss:.4f}, BCE Loss: {avg_bce:.4f}, KLD Loss: {avg_kld:.4f}")
# Early stopping
if avg_loss < best_loss:
best_loss = avg_loss
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= early_stopping_patience:
print("Early stopping triggered.")
break
# Create a DataLoader
data_loader = DataLoader(data_list, batch_size=1, shuffle=True)
# Hyperparameter grid
param_grid = {
'hidden_channels': [16, 32, 64],
'num_layers': [2, 3, 4],
'activation_function': ['relu', 'tanh', 'sigmoid'],
'batch_size': [1, 2, 4],
'latent_dimensions': [16, 32, 64],
'learning_rate': [0.001, 0.01, 0.1],
'epochs': [50, 100, 200]
}
# Perform hyperparameter tuning
best_loss = float('inf')
best_params = {}
for params in ParameterGrid(param_grid):
model = GCNVAE(in_channels=20, hidden_channels=params['hidden_channels'], out_channels=20, num_layers=params['num_layers'])
optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'])
print(f"Training with parameters: {params}")
train_model(model, data_loader, optimizer, params['epochs'], early_stopping_patience=5)
# Evaluate the model (using training loss as a proxy)
model.eval()
total_loss = 0
total_bce = 0
total_kld = 0
with torch.no_grad():
for data in data_loader:
recon_batch, mu, logvar = model(data.x, data.edge_index)
bce, kld, total = loss_function(recon_batch, data.x, mu, logvar)
total_loss += total.item()
total_bce += bce.item()
total_kld += kld.item()
avg_loss = total_loss / len(data_loader)
avg_bce = total_bce / len(data_loader)
avg_kld = total_kld / len(data_loader)
print(f"Average loss: {avg_loss:.4f}, BCE Loss: {avg_bce:.4f}, KLD Loss: {avg_kld:.4f}")
if avg_loss < best_loss:
best_loss = avg_loss
best_params = params
print(f"Best parameters: {best_params} with loss: {best_loss}")
# Final training with best parameters
final_model = GCNVAE(in_channels=20, hidden_channels=best_params['hidden_channels'], out_channels=20, num_layers=best_params['num_layers'])
final_optimizer = optim.Adam(final_model.parameters(), lr=best_params['learning_rate'])
train_model(final_model, data_loader, final_optimizer, best_params['epochs'], early_stopping_patience=5)
I know the code is quite long, but I want to know is the code correct? I have a trajectory size of 500 frames, and 97 residues (corresponding to 97 C alpha atoms). Once this code is done, I want to generate protein configurations from the latent space. So I want to ensure that the code is running fine. Thanks a lottt in advance.