scGPT Quickstart (qmd version)

GENE 46100 notebook

Author: Haky Im, based on the cziscience quickstart page

Published: May 13, 2025

This tutorial will guide you through using scGPT, a foundation model for single-cell data analysis. We’ll learn how to process single-cell data and generate meaningful embeddings that can be used for various downstream analyses.

Downloaded from https://virtualcellmodels.cziscience.com/quickstart/scgpt-quickstart

Quick Start: scGPT

This quick start will guide you through using the scGPT model, trained on 33 million cells (including data from the CZ CELLxGENE Census), to generate embeddings for single-cell transcriptomic data analysis.

Learning Goals

By the end of this tutorial, you will understand how to:

  • Access and prepare the scGPT model for use
  • Generate embeddings to analyze and compare your dataset against the CZ CELLxGENE Census
  • Visualize the results using a UMAP, colored by cell type

Pre-requisites and Requirements

Important: Before starting, make sure you have:

  1. Python 3.9 installed
  2. A basic understanding of single-cell data analysis
  3. At least 32 GB of RAM (tested on an M3 MacBook)

Before starting, ensure you are familiar with:

  • Python and AnnData
  • Single-cell data analysis (see this tutorial for a primer on the subject)

You can run this tutorial locally (tested on an M3 MacBook with 32 GiB memory) or in Google Colab using a T4 instance. Environment setup will be covered in a later section.

AnnData schema

AnnData Structure

Key Concept: AnnData is the fundamental data structure for single-cell analysis in Python. Understanding its structure is necessary for working with scGPT and other tools.

The AnnData object is a Python-based data structure used by scanpy, scGPT, and other single-cell analysis tools. Here are its key components:

Main Attributes

  • .X: Main data matrix (cells × genes)
  • .obs: Cell annotations (pandas dataframe with metadata for each cell)
  • .var: Gene annotations (pandas dataframe with metadata for each gene)

Embedding Storage

  • .obsm: Cell embeddings (e.g., PCA, UMAP coordinates)
  • .varm: Gene embeddings (e.g., gene loadings)
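
As a concrete illustration, here is a minimal sketch that builds a tiny AnnData object from synthetic data (not the tutorial dataset) and shows where each attribute lives:

```{python}
# Toy AnnData object illustrating the main attributes (synthetic data)
import numpy as np
import pandas as pd
import anndata as ad

X = np.random.poisson(1.0, size=(3, 4)).astype(np.float32)  # 3 cells x 4 genes
obs = pd.DataFrame({"cell_type": ["T cell", "B cell", "T cell"]},
                   index=["cell1", "cell2", "cell3"])
var = pd.DataFrame(index=["geneA", "geneB", "geneC", "geneD"])
toy = ad.AnnData(X=X, obs=obs, var=var)
toy.obsm["X_pca"] = np.random.randn(3, 2)  # a per-cell embedding

print(toy)                      # summary of dimensions and annotations
print(toy.obs)                  # cell metadata
print(toy.obsm["X_pca"].shape)  # (3 cells, 2 dimensions)
```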

Overview

This notebook provides a step-by-step guide to:

  • Setting up your environment
  • Downloading the necessary model checkpoints and h5ad dataset
  • Performing model inference to create embeddings
  • Visualizing the results with UMAP

Setup

Environment Setup Note:

  1. We’ll create a new conda environment for this tutorial
  2. We need specific versions of PyTorch and other dependencies
  3. The setup process might take a few minutes

Let’s start by setting up dependencies. The released version of scGPT requires PyTorch 2.1.2, so we will remove any existing PyTorch installation and replace it with the required version. If you are running in a different environment, this step may not be necessary.

Note: I encountered an error with the torch version. Installing torch 2.3.0 with torchvision 0.16.2 first, and then reinstalling torch 2.1.2 (alongside torchtext 0.16.2), resolved it.

```{bash}
# Create and activate a new conda environment
conda create -n scgpt python=3.9
conda activate scgpt
```
```{python}
# Environment setup and package installation
first_time = False
if first_time:
    # First uninstall the conflicting packages
    %pip uninstall -y -q torch torchvision
    %pip uninstall -y numpy pandas scipy scikit-learn anndata cell-gears datasets dcor
    #%pip install -q torchvision==0.16.2 torch==2.1.2
    %pip install -q torch==2.3.0 torchvision==0.16.2
    %pip install -q scgpt scanpy gdown

    # Then install them in the correct order with specific versions
    %pip install numpy==1.23.5
    %pip install pandas==1.5.3  # This version is compatible with anndata 0.10.9
    %pip install scipy==1.10.1  # This version is >1.8 as required by anndata
    %pip install scikit-learn==1.2.2
    %pip install anndata==0.10.9
    %pip install cell-gears==0.0.2
    %pip install dcor==0.6
    %pip install datasets==2.3.0

    # First uninstall both packages
    %pip uninstall -y torch torchtext

    # Then install compatible versions
    %pip install torch==2.1.2 torchtext==0.16.2
```

Package Installation Note:

  • We install specific versions of packages to ensure compatibility
  • The order of installation matters
  • Some packages are uninstalled first to avoid conflicts
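
As a quick sanity check after installation, a minimal sketch like the following (the package list simply mirrors the ones pinned above) prints the versions that actually got installed:

```{python}
# Sanity check: print installed versions of the pinned packages
from importlib.metadata import version, PackageNotFoundError

for pkg in ["torch", "torchtext", "numpy", "pandas", "anndata", "scgpt"]:
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```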

```{python}
# Import libraries
import os
import multiprocessing

# Set MPS fallback for unimplemented operations
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Monkey-patch os.sched_getaffinity for macOS
if not hasattr(os, 'sched_getaffinity'):
    def sched_getaffinity(pid):
        return set(range(multiprocessing.cpu_count()))
    os.sched_getaffinity = sched_getaffinity

import warnings
import urllib.request
from pathlib import Path

import scgpt as scg
import scanpy as sc
import numpy as np
import pandas as pd
import torch

# Check for MPS availability
device = (
    torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")
print("Note: Some operations may fall back to CPU due to MPS limitations")

warnings.filterwarnings("ignore")
```

Library Import Note:

  • We import all necessary libraries for data processing and model inference
  • Special handling for macOS MPS (Metal Performance Shaders)
  • Warning suppression for cleaner output

```{python}
# Define the base working directory
WORKDIR = "/Users/haekyungim/Library/CloudStorage/Box-Box/LargeFiles/imlab-data/data-Github/web-data/web-GENE-46100/scgpt"
# Convert to Path objects for better path handling
WORKDIR = Path(WORKDIR)
DATA_DIR = WORKDIR / "data"
MODEL_DIR = WORKDIR / "model"
```

Directory Setup Note:

  • We define a structured directory layout for data and model files
  • Using Path objects for cross-platform compatibility
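
The code above only defines the paths. If the folders do not exist yet, a minimal sketch like this (assuming you have write access to WORKDIR) creates them before anything is downloaded:

```{python}
# Create the data and model directories if they don't already exist
for d in (DATA_DIR, MODEL_DIR):
    d.mkdir(parents=True, exist_ok=True)
```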

Download and Prepare Data

Data Download Note:

  • We’ll download both the model and a sample dataset
  • The dataset will be processed to reduce memory usage
  • We’ll use highly variable genes for better analysis

```{python}
warnings.simplefilter("ignore", ResourceWarning)
warnings.filterwarnings("ignore", category=ImportWarning)

# Use gdown with the recursive flag to download the folder
folder_id = '1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y'

# Check if model files already exist
if not (MODEL_DIR / "args.json").exists():
    print("Downloading model checkpoint...")
    !gdown --folder {folder_id} -O {MODEL_DIR}
else:
    print("Model files already exist in", MODEL_DIR)

Model Download Note:

  • The model is downloaded from Google Drive
  • We check if files exist to avoid re-downloading
  • The model is about 200 MB in size
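
To confirm the download completed, a quick listing of MODEL_DIR (a minimal sketch; the exact file names depend on the checkpoint) shows what was fetched:

```{python}
# List the downloaded checkpoint files and their sizes
for f in sorted(MODEL_DIR.iterdir()):
    print(f"{f.name}: {f.stat().st_size / 1e6:.1f} MB")
```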

uri = "https://datasets.cellxgene.cziscience.com/f50deffa-43ae-4f12-85ed-33e45040a1fa.h5ad"
source_path = DATA_DIR / "source.h5ad"

# Check if file exists before downloading
if not source_path.exists():
    print(f"Downloading dataset to {source_path}...")
    urllib.request.urlretrieve(uri, filename=str(source_path))
else:
    print(f"Dataset already exists at {source_path}")

# Read the data
adata = sc.read_h5ad(source_path)

batch_key = "sample"
N_HVG = 3000  # Number of highly variable genes to select

# Select highly variable genes to reduce memory usage
sc.pp.highly_variable_genes(adata, n_top_genes=N_HVG, flavor='seurat_v3')
adata_hvg = adata[:, adata.var['highly_variable']]
```

Data Processing Note:

  • We download a sample dataset from CZ CELLxGENE
  • The dataset is reduced to 3000 highly variable genes
  • This reduction helps with memory usage and computation speed
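
A quick before/after shape check (a minimal sketch using the objects defined above) makes the reduction concrete:

```{python}
# Compare dimensions before and after HVG selection
print(f"Full dataset: {adata.shape[0]} cells x {adata.shape[1]} genes")
print(f"HVG subset:   {adata_hvg.shape[0]} cells x {adata_hvg.shape[1]} genes")  # expect 3000 genes
```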

Model Inference and Embedding Generation

Key Concept: In this section, we’ll:

  1. Set up the model for inference
  2. Generate cell embeddings using scGPT
  3. Process these embeddings for downstream analysis

Note: The embedding generation process might take several minutes depending on your hardware.

```{python}
# Monkey patch get_batch_cell_embeddings to force single processor
import types
from scgpt.tasks.cell_emb import get_batch_cell_embeddings as original_get_batch_cell_embeddings
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from scgpt.data_collator import DataCollator
import numpy as np
from tqdm import tqdm

# Define Dataset class at module level
class CellEmbeddingDataset(Dataset):
    """
    Custom dataset class for cell embedding generation.
    Processes single-cell data into a format suitable for the scGPT model.
    """
    def __init__(self, count_matrix, gene_ids, batch_ids=None, vocab=None, model_configs=None):
        self.count_matrix = count_matrix
        self.gene_ids = gene_ids
        self.batch_ids = batch_ids
        self.vocab = vocab
        self.model_configs = model_configs

    def __len__(self):
        return len(self.count_matrix)

    def __getitem__(self, idx):
        # Process a single cell's data
        row = self.count_matrix[idx]
        nonzero_idx = np.nonzero(row)[0]
        values = row[nonzero_idx]
        genes = self.gene_ids[nonzero_idx]
        # append <cls> token at the beginning
        genes = np.insert(genes, 0, self.vocab["<cls>"])
        values = np.insert(values, 0, self.model_configs["pad_value"])
        genes = torch.from_numpy(genes).long()
        values = torch.from_numpy(values).float()
        output = {
            "id": idx,
            "genes": genes,
            "expressions": values,
        }
        if self.batch_ids is not None:
            output["batch_labels"] = self.batch_ids[idx]
        return output

def patched_get_batch_cell_embeddings(
    adata,
    cell_embedding_mode: str = "cls",
    model=None,
    vocab=None,
    max_length=1200,
    batch_size=64,
    model_configs=None,
    gene_ids=None,
    use_batch_labels=False,
) -> np.ndarray:
    """
    Patched version of get_batch_cell_embeddings that uses the module-level Dataset class
    and forces num_workers=0 for better compatibility.
    """
    # Convert data to appropriate format
    count_matrix = adata.X
    count_matrix = (
        count_matrix if isinstance(count_matrix, np.ndarray) else count_matrix.toarray()
    )

    # Get gene vocabulary ids
    if gene_ids is None:
        gene_ids = np.array(adata.var["id_in_vocab"])
        assert np.all(gene_ids >= 0)

    if use_batch_labels:
        batch_ids = np.array(adata.obs["batch_id"].tolist())

    if cell_embedding_mode == "cls":
        # Set up dataset and data loader
        dataset = CellEmbeddingDataset(
            count_matrix, 
            gene_ids, 
            batch_ids if use_batch_labels else None,
            vocab=vocab,
            model_configs=model_configs
        )
        collator = DataCollator(
            do_padding=True,
            pad_token_id=vocab[model_configs["pad_token"]],
            pad_value=model_configs["pad_value"],
            do_mlm=False,
            do_binning=True,
            max_length=max_length,
            sampling=True,
            keep_first_n_tokens=1,
        )
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=SequentialSampler(dataset),
            collate_fn=collator,
            drop_last=False,
            num_workers=0,  # Force single worker for compatibility
            pin_memory=True,
        )

        # Generate embeddings
        cell_embeddings = np.zeros(
            (len(dataset), model_configs["embsize"]), dtype=np.float32
        )
        with torch.no_grad():
            count = 0
            for data_dict in tqdm(data_loader, desc="Embedding cells"):
                # Process each batch of cells
                input_gene_ids = data_dict["gene"].to(device)
                src_key_padding_mask = input_gene_ids.eq(
                    vocab[model_configs["pad_token"]]
                )
                embeddings = model._encode(
                    input_gene_ids,
                    data_dict["expr"].to(device),
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=data_dict["batch_labels"].to(device)
                    if use_batch_labels
                    else None,
                )

                # Extract CLS token embeddings and normalize
                embeddings = embeddings[:, 0, :]  # get the <cls> position embedding
                embeddings = embeddings.cpu().numpy()
                cell_embeddings[count : count + len(embeddings)] = embeddings
                count += len(embeddings)
        cell_embeddings = cell_embeddings / np.linalg.norm(
            cell_embeddings, axis=1, keepdims=True
        )
    else:
        raise ValueError(f"Unknown cell embedding mode: {cell_embedding_mode}")
    return cell_embeddings

# Replace the original function with our patched version
import scgpt.tasks.cell_emb
scgpt.tasks.cell_emb.get_batch_cell_embeddings = patched_get_batch_cell_embeddings

os.environ['PYTHONWARNINGS'] = 'ignore'
```

Technical Note: The code above:

  • Creates a custom dataset class for processing single-cell data
  • Implements a patched version of the embedding function for better compatibility
  • Handles data batching and normalization
  • Uses the CLS token for cell representation
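
To confirm the monkey-patch is active before running inference, a one-line check (a minimal sketch) compares the module attribute against our function:

```{python}
# Verify that scgpt will use the patched embedding function
assert scgpt.tasks.cell_emb.get_batch_cell_embeddings is patched_get_batch_cell_embeddings
print("Patch applied successfully")
```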

```{python}
model_dir = MODEL_DIR
gene_col = "feature_name"
cell_type_key = "cell_type"

embedding_file = DATA_DIR / "ref_embed_adata.h5ad"

if embedding_file.exists():
    print(f"Loading existing embeddings from {embedding_file}")
    ref_embed_adata = sc.read_h5ad(str(embedding_file))
else:
    print("Computing new embeddings...")
    ref_embed_adata = scg.tasks.embed_data(
        adata_hvg,
        model_dir,
        gene_col=gene_col,
        obs_to_save=cell_type_key,
        batch_size=64,
        return_new_adata=True,
        device=device,
    )
    print(f"Saving embeddings to {embedding_file}")
    ref_embed_adata.write(str(embedding_file))
```

Embedding Generation Note:

  • We first check if embeddings already exist to avoid recomputing
  • The embedding process generates a 512-dimensional representation for each cell
  • Results are saved for future use

```{python}
# Check the shape of our embeddings
ref_embed_adata.X.shape
```

Shape Explanation:

  • The output shows (number of cells × 512)
  • 512 is the embedding dimension of scGPT
  • Each cell is represented by a 512-dimensional vector
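
Because the embeddings are L2-normalized, dot products between rows are cosine similarities; here is a minimal sketch using the first few cells as an example:

```{python}
# Embeddings are L2-normalized, so dot products equal cosine similarities
emb = ref_embed_adata.X
sim = emb[:5] @ emb[:5].T  # pairwise similarities among the first 5 cells
print(np.round(sim, 3))    # diagonal entries should be ~1.0
```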

Dimensionality Reduction and Visualization

Key Concept: In this section, we’ll:

  1. Compute cell-cell similarities using the embeddings
  2. Generate a UMAP visualization
  3. Compare different cell type annotations

Note: UMAP helps us visualize high-dimensional data in 2D while preserving important relationships.

```{python}
# Compute cell-cell similarities
sc.pp.neighbors(ref_embed_adata, use_rep="X")
# Generate UMAP coordinates
sc.tl.umap(ref_embed_adata)
```

Neighborhood Graph Note:

  • We compute a neighborhood graph based on the embeddings
  • This graph is used to generate the UMAP visualization
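
scanpy stores the graph in .obsp and the 2-D coordinates in .obsm; a quick inspection (a minimal sketch) confirms both were written:

```{python}
# The neighbor graph lives in .obsp; UMAP coordinates live in .obsm
print(ref_embed_adata.obsp["connectivities"].shape)  # (n_cells, n_cells), sparse
print(ref_embed_adata.obsm["X_umap"].shape)          # (n_cells, 2)
```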

```{python}
# Transfer embeddings and UMAP to original adata object
adata.obsm["X_scgpt"] = ref_embed_adata.X
adata.obsm["X_umap"] = ref_embed_adata.obsm["X_umap"]
```

Data Transfer Note:

  • We transfer the embeddings and UMAP coordinates to our original dataset
  • This allows us to use the original annotations with our new embeddings
  • The embeddings are stored in .obsm for easy access
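
A one-line check (a minimal sketch) shows both new entries on the original object:

```{python}
# Confirm the new entries are present on the original AnnData object
print([k for k in adata.obsm.keys() if k in ("X_scgpt", "X_umap")])
```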

```{python}
# Add the current index ('ensembl_id') as a new column
adata.var['ensembl_id'] = adata.var.index

# Set the new index to the 'feature_name' column
adata.var.set_index('feature_name', inplace=True)

# Add a copy of the gene symbols back to the var dataframe
adata.var['gene_symbol'] = adata.var.index
```

Gene Name Processing Note:

  • We reorganize the gene annotations for easier access
  • This allows us to use gene symbols in visualizations
  • Both Ensembl IDs and gene symbols are preserved
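
With feature_name as the index, genes can now be looked up by symbol. A minimal sketch using LYZ, one of the marker genes plotted below:

```{python}
# Look up a gene by symbol; both identifiers are retained as columns
print(adata.var.loc["LYZ", ["ensembl_id", "gene_symbol"]])
```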

Visualization

Key Concept: In this section, we’ll:

  1. Visualize the UMAP colored by different cell type annotations
  2. Examine marker gene expression
  3. Compare different annotation schemes

```{python}
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    sc.pl.umap(adata, color=["cell_type", "annotation_res0.34_new2"], wspace=0.6)
```

UMAP Visualization Note:

  • The plot shows cells colored by two different annotation schemes
  • This helps us compare different cell type classifications
  • The spatial organization reflects cell type relationships

```{python}
# Visualize marker genes for different cell types
sc.pl.umap(adata,
           color=['cell_type', 'MKI67', 'LYZ', 'RBP2', 'MUC2', 'CHGA', 'TAGLN', 'ELAVL3'],
           frameon=False,
           use_raw=False,
           legend_fontsize="xx-small",
           legend_loc="none")
```

Marker Gene Note:

  • We visualize expression of key marker genes
  • These genes help identify different cell types:
      • MKI67: Proliferating cells
      • LYZ: Myeloid cells
      • RBP2: Enterocytes
      • MUC2: Goblet cells
      • CHGA: Enteroendocrine cells
      • TAGLN: Smooth muscle cells
      • ELAVL3: Neurons
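
To keep a copy of any of these plots, scanpy's save parameter writes the figure to sc.settings.figdir; a minimal sketch (the directory and file-name suffix here are arbitrary choices, not part of the original tutorial):

```{python}
# Optionally save a UMAP plot to disk; scanpy appends the suffix to "umap"
sc.settings.figdir = WORKDIR / "figures"
sc.pl.umap(adata, color="cell_type", save="_scgpt_cell_type.png")
```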

References

Key Papers:

  1. scGPT paper: Foundation model for single-cell multi-omics
  2. Dataset paper: iPSC-derived small intestine-on-chip
  3. CZ CELLxGENE platform paper

Please refer to the following papers for more information:

  1. scGPT: Toward building a foundation model for single-cell multi-omics using generative AI. Cui, H., Wang, C., Maan, H., et al. Nature Methods 21, 1470–1480 (2024). https://doi.org/10.1038/s41592-024-02201-0

  2. The dataset used in this tutorial: Moerkens, R., et al. Cell Reports 43(7) (2024). https://doi.org/10.1016/j.celrep.2024.114247

  3. CZ CELLxGENE Discover and Census: CZI Single-Cell Biology, et al. bioRxiv 2023.10.30 (2023). https://doi.org/10.1101/2023.10.30.563174

© HakyImLab and Listed Authors - CC BY 4.0 License