Week 1 - Machine Learning Focus - Gene-Disease Association Prediction

Part 1: Data Collection and Preprocessing (Common for Both Groups)

Step 1: Data Collection

  • OpenTargets Data:

    • Download from OpenTargets Platform
    • Focus on associations, use the JSON format, and download a few files to start. Coordinate with teammates so that everyone picks different files.
    • Example data format:
      {"diseaseId":"DOID_0050890","targetId":"ENSG00000004478","score":0.0022,"evidenceCount":1}
      {"diseaseId":"DOID_0050890","targetId":"ENSG00000005381","score":0.0022,"evidenceCount":1}
      {"diseaseId":"DOID_0050890","targetId":"ENSG00000006128","score":0.0022,"evidenceCount":1}
      
  • STRING Database:

    • Download from STRING Database
    • Select Homo sapiens (taxon ID 9606)
    • Files needed: “9606.protein.links.detailed.v11.5.txt.gz” and “9606.protein.info.v11.5.txt.gz”
    • Example data format:
      protein1             protein2             combined_score
      9606.ENSP00000000233 9606.ENSP00000356607 173
      9606.ENSP00000000233 9606.ENSP00000427567 154
      9606.ENSP00000000233 9606.ENSP00000253413 151
      

Step 2: Data Preprocessing

  • Load the downloaded files into pandas DataFrames, check for missing values, and filter out low-confidence interactions before building the graph.
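  • A minimal loading-and-cleaning sketch, assuming the OpenTargets association file is JSON Lines (one object per line, as in the example above) and the STRING links file is space-delimited; 400 is STRING's "medium confidence" cutoff:
    import pandas as pd
    
    # OpenTargets association files are JSON Lines: one JSON object per line
    opentargets_df = pd.read_json('path_to_opentargets_file', lines=True)
    # STRING links files are space-delimited, with a header row
    string_interactions = pd.read_csv('path_to_string_interactions', sep=' ')
    
    # Drop rows with missing identifiers
    opentargets_df = opentargets_df.dropna(subset=['diseaseId', 'targetId'])
    # Keep medium-or-higher confidence interactions (combined_score runs 0-1000)
    string_interactions = string_interactions[string_interactions['combined_score'] >= 400]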

Step 3: Graph Construction

  • Construct the graph using NetworkX.
    import networkx as nx
    import pandas as pd
    
    # OpenTargets association files are JSON Lines (one object per line)
    opentargets_df = pd.read_json('path_to_opentargets_file', lines=True)
    # STRING links files are space-delimited, not tab-delimited
    string_interactions = pd.read_csv('path_to_string_interactions', sep=' ')
    
    G = nx.Graph()
    
    # Add OpenTargets disease-gene edges, weighted by association score
    for _, row in opentargets_df.iterrows():
        G.add_edge(row['diseaseId'], row['targetId'], weight=row['score'], type='disease-gene')
    
    # Add STRING protein-protein edges; combined_score runs 0-1000, so rescale to 0-1
    for _, row in string_interactions.iterrows():
        G.add_edge(row['protein1'], row['protein2'], weight=row['combined_score']/1000, type='protein-protein')
    
    # Note: OpenTargets uses Ensembl gene IDs (ENSG...) while STRING uses protein IDs
    # (9606.ENSP...); map between them (e.g. via the 9606.protein.info file) so the
    # two edge sets actually share nodes.
    

Step 4: Feature Extraction

  • Extract node and edge features.
    # Node-level features: centrality, clustering, and PageRank
    nx.set_node_attributes(G, nx.degree_centrality(G), 'degree_centrality')
    nx.set_node_attributes(G, nx.clustering(G), 'clustering_coefficient')
    nx.set_node_attributes(G, nx.pagerank(G), 'pagerank')
    
    # Edge-level feature: number of common neighbors of the two endpoints
    def common_neighbors(G, u, v):
        return len(list(nx.common_neighbors(G, u, v)))
    
    for u, v, d in G.edges(data=True):
        d['common_neighbors'] = common_neighbors(G, u, v)
    

Step 5: Dataset Preparation

  • Create positive and negative samples, then split the data.
    import random
    from sklearn.model_selection import train_test_split
    
    edges = list(G.edges(data=True))
    # Caution: nx.non_edges enumerates O(n^2) pairs; for large graphs, sample negatives directly
    non_edges = list(nx.non_edges(G))
    
    # Positives: existing edges; negatives: an equal number of random non-edges
    positive_samples = [(u, v, 1, d) for u, v, d in edges]
    negative_samples = [(u, v, 0, {'weight': 0, 'common_neighbors': common_neighbors(G, u, v)}) 
                        for u, v in random.sample(non_edges, len(edges))]
    
    samples = positive_samples + negative_samples
    random.shuffle(samples)
    
    # Feature vector per pair: endpoint node features plus edge features.
    # Note: node features were computed on the full graph, so test edges leak some
    # information; for a stricter evaluation, recompute features on a training subgraph.
    X = [[G.nodes[u]['degree_centrality'], G.nodes[u]['clustering_coefficient'], G.nodes[u]['pagerank'],
          G.nodes[v]['degree_centrality'], G.nodes[v]['clustering_coefficient'], G.nodes[v]['pagerank'],
          d['weight'], d['common_neighbors']] for u, v, _, d in samples]
    y = [label for _, _, label, _ in samples]
    
    # 70/15/15 train/validation/test split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42)
    




Part 2: Traditional Machine Learning Approach (Group 1)

Step 1: Model Selection and Training

  • Choose models and train using GridSearchCV.
    from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
    
    models = {
        'Random Forest': RandomForestClassifier(),
        'Gradient Boosting': GradientBoostingClassifier(),
        # liblinear supports both l1 and l2 penalties (the default lbfgs solver is l2-only)
        'Logistic Regression': LogisticRegression(solver='liblinear')
    }
    
    param_grids = {
        'Random Forest': {'n_estimators': [100, 200, 300], 'max_depth': [5, 10, None]},
        'Gradient Boosting': {'n_estimators': [100, 200, 300], 'learning_rate': [0.01, 0.1, 0.3]},
        'Logistic Regression': {'C': [0.1, 1, 10], 'penalty': ['l1', 'l2']}
    }
    
    results = {}
    best_models = {}  # keep the fitted estimators for later analysis
    
    for name, model in models.items():
        grid_search = GridSearchCV(model, param_grids[name], cv=5, scoring='roc_auc')
        grid_search.fit(X_train, y_train)
        
        best_model = grid_search.best_estimator_
        best_models[name] = best_model
        y_pred = best_model.predict(X_test)
        y_pred_proba = best_model.predict_proba(X_test)[:, 1]
        
        results[name] = {
            'accuracy': accuracy_score(y_test, y_pred),
            'precision': precision_score(y_test, y_pred),
            'recall': recall_score(y_test, y_pred),
            'f1': f1_score(y_test, y_pred),
            'auc': roc_auc_score(y_test, y_pred_proba)
        }
    
    for name, metrics in results.items():
        print(f"{name}:")
        for metric, value in metrics.items():
            print(f"  {metric}: {value:.4f}")
    

Step 2: Feature Importance Analysis

  • Visualize feature importance.
    import matplotlib.pyplot as plt
    
    # Use the fitted estimator from the grid search (the entries in `models` were never fitted)
    rf_model = best_models['Random Forest']
    feature_importance = rf_model.feature_importances_
    feature_names = ['u_degree', 'u_clustering', 'u_pagerank', 'v_degree', 'v_clustering', 'v_pagerank', 'weight', 'common_neighbors']
    
    plt.figure(figsize=(10, 6))
    plt.bar(feature_names, feature_importance)
    plt.title('Feature Importance in Random Forest Model')
    plt.xlabel('Features')
    plt.ylabel('Importance')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
    




Part 3: Graph Neural Network Approach (Group 2)

Step 1: GNN Data Preparation

  • Prepare data for GNN using PyTorch Geometric.
    import torch
    from torch_geometric.data import Data
    from torch_geometric.utils import to_undirected
    
    # Node names are strings (disease/gene IDs), so map them to integer indices first
    node_index = {n: i for i, n in enumerate(G.nodes())}
    edge_index = torch.tensor([[node_index[u], node_index[v]] for u, v in G.edges()],
                              dtype=torch.long).t().contiguous()
    # PyG treats graphs as directed; add the reverse of each edge for an undirected graph
    edge_index = to_undirected(edge_index)
    
    x = torch.tensor([[G.nodes[n]['degree_centrality'], G.nodes[n]['clustering_coefficient'], G.nodes[n]['pagerank']] 
                      for n in G.nodes()], dtype=torch.float)
    
    data = Data(x=x, edge_index=edge_index)
    

Step 2: GNN Model Implementation

  • Implement a GCN model.
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv
    
    class GCN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)
    
        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
            x = self.conv2(x, edge_index)
            return x
    
    model = GCN(in_channels=3, hidden_channels=64, out_channels=32)
    

Step 3: Model Training

  • Train the GNN model as a link predictor: score each candidate edge by the dot product of its endpoint embeddings and apply a binary cross-entropy loss. The node pairs and labels reuse the shuffled samples from Part 1, Step 5.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()
    
    # Map the (u, v, label, features) samples to integer node pairs and label tensors
    pairs = torch.tensor([[node_index[u], node_index[v]] for u, v, _, _ in samples], dtype=torch.long)
    labels = torch.tensor([label for _, _, label, _ in samples], dtype=torch.float)
    
    # 70/30 train/test split (samples were already shuffled)
    n_train = int(0.7 * len(labels))
    train_pairs, test_pairs = pairs[:n_train], pairs[n_train:]
    train_labels, test_labels = labels[:n_train], labels[n_train:]
    
    def train():
        model.train()
        optimizer.zero_grad()
        z = model(data.x, data.edge_index)  # node embeddings
        # Edge logit = dot product of the two endpoint embeddings
        logits = (z[train_pairs[:, 0]] * z[train_pairs[:, 1]]).sum(dim=-1)
        loss = criterion(logits, train_labels)
        loss.backward()
        optimizer.step()
        return loss
    
    for epoch in range(200):
        loss = train()
        print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')
    
    # Note: message passing still sees all graph edges, including test positives;
    # for a stricter setup, drop test edges from data.edge_index before training.
    

Step 4: Model Evaluation

  • Evaluate the GNN model on the held-out test pairs.
    from sklearn.metrics import roc_auc_score
    
    model.eval()
    with torch.no_grad():
        z = model(data.x, data.edge_index)
        logits = (z[test_pairs[:, 0]] * z[test_pairs[:, 1]]).sum(dim=-1)
        pred = torch.sigmoid(logits)
        auc = roc_auc_score(test_labels.cpu(), pred.cpu())
        print(f'Test AUC: {auc:.4f}')
    

Step 5: Visualization

  • Visualize node embeddings using t-SNE.
    from sklearn.manifold import TSNE
    import matplotlib.pyplot as plt
    
    model.eval()  # disable dropout for deterministic embeddings
    with torch.no_grad():
        embeddings = model(data.x, data.edge_index).cpu().numpy()
    tsne = TSNE(n_components=2)
    node_embeddings_2d = tsne.fit_transform(embeddings)
    
    plt.figure(figsize=(10, 8))
    plt.scatter(node_embeddings_2d[:, 0], node_embeddings_2d[:, 1])
    plt.title('2D Visualization of Node Embeddings')
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.show()
    




Final Steps (Both Groups)

Step 1: Results Comparison

  1. Compare the performance metrics of traditional ML models and GNN.

    for name, metrics in results.items():
        print(f"{name} Performance Metrics:")
        for metric, value in metrics.items():
            print(f"  {metric}: {value:.4f}")
    
    # GNN performance
    print(f"GNN Test AUC: {auc:.4f}")
    
  2. Analyze which model performed better in terms of accuracy, precision, recall, F1 score, and AUC. Discuss potential reasons for the observed differences, considering the complexity of relationships captured by GNNs versus traditional ML models.

Step 2: Biological Interpretation

  • Analyze the top predictions from each method. Node pairs must be carried alongside the features so high-scoring predictions can be traced back to disease-gene pairs; the sketch below recovers them by re-splitting index arrays with the same random_state.
    import numpy as np
    
    # For traditional ML: recover the (u, v) pairs behind X_test.
    # (y_pred_proba holds the last model trained in the loop; recompute with
    # best_models['Random Forest'] to inspect a specific model.)
    pair_list = [(u, v) for u, v, _, _ in samples]
    idx_train, idx_rest = train_test_split(np.arange(len(samples)), test_size=0.3, random_state=42)
    idx_val, idx_test = train_test_split(idx_rest, test_size=0.5, random_state=42)
    test_pair_names = [pair_list[i] for i in idx_test]
    
    top_predictions = sorted(zip(test_pair_names, y_pred_proba), key=lambda t: t[1], reverse=True)[:10]
    for (u, v), score in top_predictions:
        print(f"Disease: {u}, Gene: {v}, Prediction Score: {score:.4f}")
    
    # For GNN: map integer node indices back to names
    nodes = list(G.nodes())
    top_gnn_predictions = sorted(zip(test_pairs.tolist(), pred.tolist()), key=lambda t: t[1], reverse=True)[:10]
    for (u_idx, v_idx), score in top_gnn_predictions:
        print(f"Pair: ({nodes[u_idx]}, {nodes[v_idx]}), Prediction Score: {score:.4f}")
    

Step 3: Incorporating Additional Features via NLP

  • Extract additional features using NLP (e.g., from scientific literature or clinical notes).
    • Use libraries like spaCy or NLTK to process text data and extract relevant features.
    import spacy
    import numpy as np
    from sklearn.feature_extraction.text import TfidfVectorizer
    
    nlp = spacy.load('en_core_web_sm')
    documents = [...]  # List of documents to process, one per sample in X
    
    # Optional spaCy preprocessing: lemmatize and drop stop words before TF-IDF
    def preprocess(text):
        doc = nlp(text)
        return ' '.join(tok.lemma_ for tok in doc if tok.is_alpha and not tok.is_stop)
    
    tfidf = TfidfVectorizer(max_features=1000)
    tfidf_matrix = tfidf.fit_transform(preprocess(d) for d in documents)
    
    # Add these features to the existing feature set (rows must align with X)
    additional_features = tfidf_matrix.toarray()
    X_enhanced = np.hstack((np.array(X), additional_features))
    
  1. Incorporate these additional features into the dataset and retrain the models.
    # Split enhanced data
    X_train, X_test, y_train, y_test = train_test_split(X_enhanced, y, test_size=0.3, random_state=42)
    X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42)
    
    # Retrain traditional ML models and GNN with the enhanced feature set
    # Follow the same steps as before for model training and evaluation
    

Step 4: Future Directions

  1. Explore incorporating more diverse features such as gene expression data or pathway information.
  2. Experiment with more advanced GNN architectures, like GraphSAGE or GAT.
  3. Develop an ensemble method combining traditional ML and GNN predictions to leverage the strengths of both approaches.

These suggestions could potentially enhance the robustness, performance, and interpretability of the gene-disease association prediction model:

  1. Data quality and preprocessing:

    • Add a step for handling missing values and outliers in the OpenTargets and STRING data.
    • Consider normalizing or scaling features before model training.
  2. Feature engineering:

    • Explore more complex network features like betweenness centrality or eigenvector centrality.
    • Consider creating interaction terms between features.
  3. Model selection and evaluation:

    • Include cross-validation for more robust performance estimation.
    • Add other metrics like Matthews Correlation Coefficient (MCC) for imbalanced datasets (a minimal example appears after this list).
    • Consider using SHAP (SHapley Additive exPlanations) values for more interpretable feature importance.
  4. GNN approach:

    • Experiment with other GNN architectures like GraphSAGE, GAT, or more recent ones like GraphTransformer.
    • Implement early stopping to prevent overfitting.
    • Use k-fold cross-validation for more reliable GNN performance estimation.
  5. Biological interpretation:

    • Include gene set enrichment analysis (GSEA) on top predictions to identify overrepresented pathways or functions.
    • Validate top predictions against recent literature or experimental data.
  6. NLP integration:

    • Consider using biomedical-specific language models like BioBERT or PubMedBERT for feature extraction.
    • Implement named entity recognition to extract specific biological entities from text.
  7. Ensemble methods:

    • Implement stacking or blending of different models (traditional ML, GNN, and NLP-based) for potentially improved performance.
  8. Explainability:

    • Implement techniques like LIME or SHAP for explaining individual predictions, which is crucial in biomedical applications.
  9. External validation:

    • Include a step to validate the model on an independent external dataset to assess generalizability.
  10. Time-based splitting:

    • If the data has a temporal component, consider using time-based splitting instead of random splitting to mimic real-world scenarios.
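
A minimal sketch for the MCC suggestion in item 3, assuming the fitted best_models dictionary and test split from Part 2:

    from sklearn.metrics import matthews_corrcoef
    
    # MCC uses all four confusion-matrix cells, so it stays informative under class imbalance
    y_pred = best_models['Random Forest'].predict(X_test)
    print(f"Random Forest MCC: {matthews_corrcoef(y_test, y_pred):.4f}")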

@ameliatthomas @ARUUKE_BAYAKMATOVA @Moh_Saiger @Prasun_Sharma @hahaharsini @ahmedsalim @Hyunji_An

@Hamza_Khan @DMot @Thuraya_Ayman

Are you available to meet with Sam this weekend to share results or ask questions about Part 1? If you have a preference, please let us know below (pick all time slots that may work):

  • Sat (08/03) morning Pacific
  • Sat afternoon Pacific
  • Sat evening Pacific
  • Sun (08/04) morning Pacific
  • Sun afternoon Pacific
  • Sun evening Pacific

Feel free to ask any questions or share preliminary results by replying to this thread.

Meeting is an open one, all are welcome to attend.

@ameliatthomas @ARUUKE_BAYAKMATOVA @Moh_Saiger @Prasun_Sharma @hahaharsini @ahmedsalim @Hyunji_An

@Hamza_Khan @DMot @Thuraya_Ayman

And a follow-up meeting with Anubhav to share results or ask questions about Part 2 and/or Part 3? If you have a preference, please let us know below (pick all time slots that may work):

  • Mon (08/05) late afternoon Pacific
  • Tue late afternoon Pacific
  • Wed late afternoon Pacific
  • Thurs late afternoon Pacific

Feel free to ask any questions or share preliminary results by replying to this thread.

Meeting is an open one, all are welcome to attend.

Confirming Sunday Noon Pacific Meeting Time for the ML/Bioinformatics Session

  • Please review the tasks and come prepared for an interactive code-along meeting.
  • If you are interested in this task but unavailable for the Sunday meeting, please reply to this topic with any exploratory data visualizations, coding results, and/or questions.
  • Participation in one of the above is mandatory to be part of the ML subteam.

Zoom Link: Zoom Meeting

If you are able to attend a Sunday 5pm Pacific meeting time, please reply to this post, and we will hold a second meeting.


Next steps based on 08/01 and 08/03 meetings:

Contributors: @moneuron @Moh_Saiger @Prasun_Sharma

Disease Selection and Data Collection:

  • Each student selects a disease. Please reply to this post with the disease you are picking; everyone works on a different disease so we can build up our final dataset.
    • Use OpenTargets to download associated proteins.
    • For StringDB data:
      • Option A: Use the full Homo sapiens PPI set.
      • Option B: Input disease-associated proteins and extend to ≥1000 genes.

  • Data Management:
    • Commit data directly to the main branch.
    • Code development on individual branches.

  • Network Creation:
    • Combine OpenTargets and StringDB data.
    • Nodes: Proteins and diseases (hint: extract proteins from StringDB only; ensure no protein duplication).
    • Edges: PPI (StringDB) and disease-protein associations (OpenTargets).
    • Visualize the network using tools like NetworkX or Cytoscape to gain insights into its structure (see the sketch below).
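
  • A minimal NetworkX visualization sketch, assuming the combined graph G built as above (layout and styling are just example choices):
    import matplotlib.pyplot as plt
    import networkx as nx
    
    # Spring layout tends to pull hubs and clusters apart visually
    pos = nx.spring_layout(G, seed=42)
    edge_widths = [2 * d.get('weight', 0.1) for _, _, d in G.edges(data=True)]
    
    plt.figure(figsize=(12, 10))
    nx.draw_networkx_nodes(G, pos, node_size=30)
    nx.draw_networkx_edges(G, pos, width=edge_widths, alpha=0.4)
    plt.axis('off')
    plt.show()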

Binary Classification Model for One Disease

  • Data preparation:
    • Positive class: Proteins associated with the chosen disease in OpenTargets
    • Negative class: Proteins in the network not associated with the disease
  • Features:
    • Network properties: degree, centrality, clustering coefficient
    • Binary feature: whether the protein interacts with any known disease proteins
  • Model:
    • Simple logistic regression or decision tree
    • Output: Binary prediction (associated/not associated)
  • Evaluate using accuracy, precision, recall, and F1-score (a minimal sketch follows below)
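
  • A minimal logistic-regression sketch for this level; X_prot and y_prot are hypothetical per-protein feature and label arrays built as described above:
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    
    # X_prot: per-protein features (degree, centrality, clustering, interacts-with-disease-protein)
    # y_prot: 1 if the protein is associated with the chosen disease in OpenTargets, else 0
    X_tr, X_te, y_tr, y_te = train_test_split(X_prot, y_prot, test_size=0.3,
                                              random_state=42, stratify=y_prot)
    
    clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
    print(classification_report(y_te, clf.predict(X_te)))  # precision, recall, F1 per class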

Integrate Scores/Confidence Levels

  • Enhance the model by incorporating confidence scores
  • Additional features:
    • OpenTargets overall score for disease-protein associations
    • StringDB combined score for protein-protein interactions
  • Model:
    • Random Forest or Gradient Boosting
    • Output: Probability of association (can be interpreted as confidence)
  • Evaluation:
    • Use metrics like ROC-AUC that account for prediction confidence
  • Analyze feature importance to understand the impact of different evidence types (see the sketch below)
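
  • A short Random Forest sketch with the score features included, reusing the split from the previous sketch; feature_names is a hypothetical list of the feature column names:
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import roc_auc_score
    
    rf = RandomForestClassifier(n_estimators=200, random_state=42).fit(X_tr, y_tr)
    proba = rf.predict_proba(X_te)[:, 1]  # probability of association, usable as a confidence
    print(f"ROC-AUC: {roc_auc_score(y_te, proba):.4f}")
    
    # Feature importance shows which evidence types drive the predictions
    for name, imp in zip(feature_names, rf.feature_importances_):
        print(f"{name}: {imp:.3f}")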

Generalize to Multiple Diseases

  • Expand dataset to include multiple diseases
  • Approaches:
    • Multi-label classification: predict association with multiple diseases simultaneously
    • One-vs-Rest: train separate binary classifiers for each disease
  • Features:
    • Same as Level 2, but calculated for each disease
    • Disease-specific features (e.g., disease ontology information)
  • Model:
    • More complex architectures like neural networks
    • Output: Probabilities for association with each disease (a One-vs-Rest sketch follows below)
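
  • A minimal One-vs-Rest sketch; Y_multi_tr is a hypothetical binary label matrix of shape (n_proteins, n_diseases), row-aligned with the protein features from the earlier sketches:
    from sklearn.multiclass import OneVsRestClassifier
    from sklearn.linear_model import LogisticRegression
    
    # One binary classifier per disease, all sharing the same protein features
    ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000))
    ovr.fit(X_tr, Y_multi_tr)
    
    probas = ovr.predict_proba(X_te)  # shape: (n_test_proteins, n_diseases)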

Biological Interpretation of Results (to be added)

I will not be able to make it to today's meeting, but I have made a graph solely visualizing gene-disease associations. I only added 15 genes to the graph. [Image: gene-disease association graph]


Recap of some of the things we discussed in the meeting:

  1. Read this article we talked about, or at the very least the “Preliminary” and “A road of machine learning-based approaches for the disease gene prediction” sections, along with Figure 1. This will help you understand the overall scope of our project and how Levels 1, 2, and 3 (discussed by Anya) play out here.

  2. Go to the OpenTargets Platform and search for a disease of interest, say “Breast Cancer” as we talked about in the meeting. Then try to use an API query to get the same results in Python [there’s a button on the same page for it, and the OpenTargets Platform API docs are here: link 1, link 2]. Alternatively, you can also export the data and load it in Python. From here you’ll get the gene/protein names associated with the disease.

  3. Go to StringDB and click on the multiple proteins option on the left. Enter your gene list from step 2, and generate a network. Then click on the More button to expand the network. Now do this using Python.

  4. Create a network from both datasets, load it in the NetworkX library, and integrate the two layers (see the compose sketch after this list). Refer to the NetworkX documentation for creating and integrating graph layers.

  5. Coding: For leads (or however you are planning to work as a team), create/migrate a repo at the STEM-Away mentor chains GitHub using the template repo. [Don’t work directly in the Main branch; it’s for production/deployment.]

  • If only one person is pushing code to that repo, create a dev branch, compile code from everyone, and push there.
  • If multiple people are pushing code - create a dev branch and other branches with your name. The lead can pull the code from other branches and merge in dev.
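
For step 4, a minimal sketch of integrating the two layers with NetworkX (G_ot and G_string are placeholder names for the OpenTargets and StringDB graphs):

    import networkx as nx
    
    # compose takes the union of nodes and edges; nodes with the same ID are merged
    G_combined = nx.compose(G_ot, G_string)
    print(G_combined.number_of_nodes(), G_combined.number_of_edges())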

Happy coding!

-Sam


This is very good, Moh. Did you use data from OpenTargets and NetworkX for visualizing the graph? This is a very good start.

Yes! Exactly that.


Hey there. Sorry I haven’t been able to join the meetings due to the time difference and because I have been traveling. I did try downloading the datasets and forming networks. I have a small disease-target network using around 30 diabetic nephropathy disease targets from OpenTargets. I wanted to incorporate a couple of Cytoscape/Gephi features, like edges with different thickness based on weight or nodes with different size based on score, though it’s not nearly as good as Moh’s. Thank you Sam for the reference paper and pointing to the API docs. I’m trying to figure out GitHub and the PPI network and hope to update soon! [Image: disease-target network practice]


This is a great start. Here are a few suggestions to enhance the aesthetics for you and everyone else:

  1. Consider using HGNC symbols for genes/proteins as node labels in a network graphic.
  2. If the weight values are very small, you can preprocess these scores using scaling techniques such as log transformation.
  3. If you have a weight or edge length property, you can experiment with different graph layouts to identify any cluster/hub formations. More detailed parameters can be calculated using network analysis later on.
  4. Also, consider adjusting node colors and edge thickness based on scores.

Fetched datasets for Breast Cancer, Infectious Disease, and Alzheimer’s Disease from the OpenTargets platform. I then used the STRING database to generate a network of proteins associated with the genes from the OpenTargets data. I created separate networks for the OpenTargets and STRING data using NetworkX. The graphs were constructed with from_pandas_edgelist and add_edge in a loop, then combined using the compose function.


I tried to create graphs from OpenTargets and STRING DB. Then I combined both graphs using nx.compose.

The OpenTargets graph shows the breast cancer disease associations with the targets. The combined score from all features for each gene/protein is shown on the edges. Only 10 targets are used to ease the visualization. [Image: open_target_data]

The STRING DB graph demonstrates protein-protein associations. The combined score is shown in the plot. Only 4 proteins associated with BRCA2 are considered. [Image: stringDB]

Both graphs are combined as follows: [Image: combined graphs]


This is my latest visualization for diabetes: [Images: gene-disease diabetes, PPI diabetes]


This is awesome. Everyone's visualizations keep getting better. Good job @ahmedsalim @ayahashim16 @Moh_Saiger @hahaharsini

Check out using Cytoscape in Python: py2cytoscape. It supports NetworkX, Pandas, and igraph formats.

1 Like