```{python}
# Environment setup and package installation
first_time = False
if first_time:
    # First uninstall the conflicting packages
    %pip uninstall -y -q torch torchvision
    %pip uninstall -y numpy pandas scipy scikit-learn anndata cell-gears datasets dcor
    #%pip install -q torchvision==0.16.2 torch==2.1.2
    %pip install -q torch==2.3.0 torchvision==0.16.2
    %pip install -q scgpt scanpy gdown
    # Then install them in the correct order with specific versions
    %pip install numpy==1.23.5
    %pip install pandas==1.5.3  # This version is compatible with anndata 0.10.9
    %pip install scipy==1.10.1  # This version is >1.8 as required by anndata
    %pip install scikit-learn==1.2.2
    %pip install anndata==0.10.9
    %pip install cell-gears==0.0.2
    %pip install dcor==0.6
    %pip install datasets==2.3.0
    # First uninstall both packages
    %pip uninstall -y torch torchtext
    # Then install compatible versions
    %pip install torch==2.1.2 torchtext==0.16.2
```
scgpt quickstart - qmd version
This tutorial will guide you through using scGPT, a foundation model for single-cell data analysis. We’ll learn how to process single-cell data and generate meaningful embeddings that can be used for various downstream analyses.
Downloaded from https://virtualcellmodels.cziscience.com/quickstart/scgpt-quickstart
Quick Start: scGPT
This quick start will guide you through using the scGPT model, trained on 33 million cells (including data from the CZ CELLxGENE Census), to generate embeddings for single-cell transcriptomic data analysis.
Learning Goals
By the end of this tutorial, you will understand how to:
- Access and prepare the scGPT model for use
- Generate embeddings to analyze and compare your dataset against the CZ CELLxGENE Census
- Visualize the results using a UMAP, colored by cell type
Pre-requisites and Requirements
Important: Before starting, make sure you have:

1. Python 3.9 installed
2. A basic understanding of single-cell data analysis
3. At least 32 GB of RAM (tested on an M3 MacBook)
Before starting, ensure you are familiar with:
- Python and AnnData
- Single-cell data analysis (see this tutorial for a primer on the subject)
You can run this tutorial locally (tested on an M3 MacBook with 32 GiB memory) or in Google Colab using a T4 instance. Environment setup will be covered in a later section.
AnnData Structure
Key Concept: AnnData is the fundamental data structure for single-cell analysis in Python. Understanding its structure is necessary for working with scGPT and other tools.
The AnnData object is a Python-based data structure used by scanpy, scGPT, and other single-cell analysis tools. Here are its key components:
Main Attributes

- `.X`: Main data matrix (cells × genes)
- `.obs`: Cell annotations (a pandas DataFrame with metadata for each cell)
- `.var`: Gene annotations (a pandas DataFrame with metadata for each gene)

Embedding Storage

- `.obsm`: Cell embeddings (e.g., PCA, UMAP coordinates)
- `.varm`: Gene embeddings (e.g., gene loadings)
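To make these slots concrete, here is a minimal, self-contained sketch that builds a toy AnnData object and stores a per-cell embedding. It is illustrative only and not part of the tutorial pipeline; the names `toy` and `X_toy_embedding` are made up for this example.

```{python}
# Illustrative only: a tiny AnnData object showing where each attribute lives.
# The names toy and X_toy_embedding are hypothetical, not used later.
import numpy as np
import pandas as pd
import anndata as ad

toy = ad.AnnData(
    X=np.random.poisson(1.0, size=(5, 3)).astype(np.float32),  # 5 cells × 3 genes
    obs=pd.DataFrame({"cell_type": ["a", "a", "b", "b", "b"]},
                     index=[f"cell_{i}" for i in range(5)]),
    var=pd.DataFrame(index=[f"gene_{j}" for j in range(3)]),
)
toy.obsm["X_toy_embedding"] = np.random.randn(5, 2)  # cell embeddings go in .obsm

print(toy)          # summary: n_obs × n_vars plus registered slots
print(toy.X.shape)  # (n_cells, n_genes)
```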
Overview
This notebook provides a step-by-step guide to:
- Setting up your environment
- Downloading the necessary model checkpoints and h5ad dataset
- Performing model inference to create embeddings
- Visualizing the results with UMAP
Setup
Environment Setup Note:

1. We’ll create a new conda environment for this tutorial
2. We need specific versions of PyTorch and other dependencies
3. The setup process might take a few minutes
Let’s start by setting up the dependencies. The released version of scGPT requires PyTorch 2.1.2, so we will remove the existing PyTorch installation and replace it with the required one. If you are running in a different environment, this step might not be necessary.
Note: I encountered an error with the torch version; installing torch 2.3.0 with torchvision 0.16.2 first and then reinstalling torch 2.1.2 resolved it.
```{bash}
# Create and activate a new conda environment
conda create -n scgpt python=3.9
conda activate scgpt
```
Package Installation Note:

- We install specific versions of packages to ensure compatibility
- The order of installation matters
- Some packages are uninstalled first to avoid conflicts
```{python}
# Import libraries
import os
import multiprocessing

# Set MPS fallback for unimplemented operations
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Monkey-patch os.sched_getaffinity for macOS
if not hasattr(os, 'sched_getaffinity'):
    def sched_getaffinity(pid):
        return set(range(multiprocessing.cpu_count()))
    os.sched_getaffinity = sched_getaffinity

import warnings
import urllib.request
from pathlib import Path

import scgpt as scg
import scanpy as sc
import numpy as np
import pandas as pd
import torch

# Check for MPS availability
device = (
    torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")
print("Note: Some operations may fall back to CPU due to MPS limitations")

warnings.filterwarnings("ignore")
```
Library Import Note:

- We import all necessary libraries for data processing and model inference
- Special handling is needed for macOS MPS (Metal Performance Shaders)
- Warnings are suppressed for cleaner output
```{python}
# Define the base working directory
WORKDIR = "/Users/haekyungim/Library/CloudStorage/Box-Box/LargeFiles/imlab-data/data-Github/web-data/web-GENE-46100/scgpt"

# Convert to Path objects for better path handling
WORKDIR = Path(WORKDIR)
DATA_DIR = WORKDIR / "data"
MODEL_DIR = WORKDIR / "model"

# Make sure the data and model directories exist before downloading into them
DATA_DIR.mkdir(parents=True, exist_ok=True)
MODEL_DIR.mkdir(parents=True, exist_ok=True)
```
Directory Setup Note:

- We create a structured directory for data and model files
- Using Path objects ensures cross-platform compatibility
Download and Prepare Data
Data Download Note:

- We’ll download both the model and a sample dataset
- The dataset will be processed to reduce memory usage
- We’ll use highly variable genes for better analysis
"ignore", ResourceWarning)
warnings.simplefilter("ignore", category=ImportWarning)
warnings.filterwarnings(
# Use gdown with the recursive flag to download the folder
= '1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y'
folder_id
# Check if model files already exist
if not (MODEL_DIR / "args.json").exists():
print("Downloading model checkpoint...")
!gdown --folder {folder_id} -O {MODEL_DIR}
else:
print("Model files already exist in", MODEL_DIR)
Model Download Note:

- The model is downloaded from Google Drive
- We check whether the files already exist to avoid re-downloading
- The model is about 200 MB in size
= "https://datasets.cellxgene.cziscience.com/f50deffa-43ae-4f12-85ed-33e45040a1fa.h5ad"
uri = DATA_DIR / "source.h5ad"
source_path
# Check if file exists before downloading
if not source_path.exists():
print(f"Downloading dataset to {source_path}...")
=str(source_path))
urllib.request.urlretrieve(uri, filenameelse:
print(f"Dataset already exists at {source_path}")
# Read the data
= sc.read_h5ad(source_path)
adata
= "sample"
batch_key = 3000 # Number of highly variable genes to select
N_HVG
# Select highly variable genes to reduce memory usage
=N_HVG, flavor='seurat_v3')
sc.pp.highly_variable_genes(adata, n_top_genes= adata[:, adata.var['highly_variable']] adata_hvg
Data Processing Note:

- We download a sample dataset from CZ CELLxGENE
- The dataset is reduced to 3000 highly variable genes
- This reduction helps with memory usage and computation speed
Model Inference and Embedding Generation
Key Concept: In this section, we’ll:

1. Set up the model for inference
2. Generate cell embeddings using scGPT
3. Process these embeddings for downstream analysis
Note: The embedding generation process might take several minutes depending on your hardware.
```{python}
# Monkey patch get_batch_cell_embeddings to force single processor
import types
from scgpt.tasks.cell_emb import get_batch_cell_embeddings as original_get_batch_cell_embeddings
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from scgpt.data_collator import DataCollator
import numpy as np
from tqdm import tqdm

# Define Dataset class at module level
class CellEmbeddingDataset(Dataset):
    """
    Custom dataset class for cell embedding generation.
    Processes single-cell data into a format suitable for the scGPT model.
    """
    def __init__(self, count_matrix, gene_ids, batch_ids=None, vocab=None, model_configs=None):
        self.count_matrix = count_matrix
        self.gene_ids = gene_ids
        self.batch_ids = batch_ids
        self.vocab = vocab
        self.model_configs = model_configs

    def __len__(self):
        return len(self.count_matrix)

    def __getitem__(self, idx):
        # Process a single cell's data
        row = self.count_matrix[idx]
        nonzero_idx = np.nonzero(row)[0]
        values = row[nonzero_idx]
        genes = self.gene_ids[nonzero_idx]
        # append <cls> token at the beginning
        genes = np.insert(genes, 0, self.vocab["<cls>"])
        values = np.insert(values, 0, self.model_configs["pad_value"])
        genes = torch.from_numpy(genes).long()
        values = torch.from_numpy(values).float()
        output = {
            "id": idx,
            "genes": genes,
            "expressions": values,
        }
        if self.batch_ids is not None:
            output["batch_labels"] = self.batch_ids[idx]
        return output


def patched_get_batch_cell_embeddings(
    adata,
    cell_embedding_mode: str = "cls",
    model=None,
    vocab=None,
    max_length=1200,
    batch_size=64,
    model_configs=None,
    gene_ids=None,
    use_batch_labels=False,
) -> np.ndarray:
    """
    Patched version of get_batch_cell_embeddings that uses the module-level Dataset class
    and forces num_workers=0 for better compatibility.
    """
    # Convert data to appropriate format
    count_matrix = adata.X
    count_matrix = (
        count_matrix if isinstance(count_matrix, np.ndarray) else count_matrix.toarray()
    )

    # Get gene vocabulary ids
    if gene_ids is None:
        gene_ids = np.array(adata.var["id_in_vocab"])
        assert np.all(gene_ids >= 0)

    if use_batch_labels:
        batch_ids = np.array(adata.obs["batch_id"].tolist())

    if cell_embedding_mode == "cls":
        # Set up dataset and data loader
        dataset = CellEmbeddingDataset(
            count_matrix,
            gene_ids,
            batch_ids if use_batch_labels else None,
            vocab=vocab,
            model_configs=model_configs,
        )
        collator = DataCollator(
            do_padding=True,
            pad_token_id=vocab[model_configs["pad_token"]],
            pad_value=model_configs["pad_value"],
            do_mlm=False,
            do_binning=True,
            max_length=max_length,
            sampling=True,
            keep_first_n_tokens=1,
        )
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=SequentialSampler(dataset),
            collate_fn=collator,
            drop_last=False,
            num_workers=0,  # Force single worker for compatibility
            pin_memory=True,
        )

        # Generate embeddings
        cell_embeddings = np.zeros(
            (len(dataset), model_configs["embsize"]), dtype=np.float32
        )
        with torch.no_grad():
            count = 0
            for data_dict in tqdm(data_loader, desc="Embedding cells"):
                # Process each batch of cells
                input_gene_ids = data_dict["gene"].to(device)
                src_key_padding_mask = input_gene_ids.eq(
                    vocab[model_configs["pad_token"]]
                )
                embeddings = model._encode(
                    input_gene_ids,
                    data_dict["expr"].to(device),
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=data_dict["batch_labels"].to(device)
                    if use_batch_labels
                    else None,
                )

                # Extract CLS token embeddings and normalize
                embeddings = embeddings[:, 0, :]  # get the <cls> position embedding
                embeddings = embeddings.cpu().numpy()
                cell_embeddings[count : count + len(embeddings)] = embeddings
                count += len(embeddings)
        cell_embeddings = cell_embeddings / np.linalg.norm(
            cell_embeddings, axis=1, keepdims=True
        )
    else:
        raise ValueError(f"Unknown cell embedding mode: {cell_embedding_mode}")
    return cell_embeddings


# Replace the original function with our patched version
import scgpt.tasks.cell_emb
scgpt.tasks.cell_emb.get_batch_cell_embeddings = patched_get_batch_cell_embeddings

os.environ['PYTHONWARNINGS'] = 'ignore'
```
Technical Note: The code above:

- Creates a custom dataset class for processing single-cell data
- Implements a patched version of the embedding function for better compatibility
- Handles data batching and normalization
- Uses the CLS token for cell representation
```{python}
model_dir = MODEL_DIR
gene_col = "feature_name"
cell_type_key = "cell_type"

embedding_file = DATA_DIR / "ref_embed_adata.h5ad"

if embedding_file.exists():
    print(f"Loading existing embeddings from {embedding_file}")
    ref_embed_adata = sc.read_h5ad(str(embedding_file))
else:
    print("Computing new embeddings...")
    ref_embed_adata = scg.tasks.embed_data(
        adata_hvg,
        model_dir,
        gene_col=gene_col,
        obs_to_save=cell_type_key,
        batch_size=64,
        return_new_adata=True,
        device=device,
    )
    print(f"Saving embeddings to {embedding_file}")
    ref_embed_adata.write(str(embedding_file))
```
Embedding Generation Note:

- We first check whether embeddings already exist to avoid recomputing
- The embedding process generates a 512-dimensional representation for each cell
- Results are saved for future use
```{python}
# Check the shape of our embeddings
ref_embed_adata.X.shape
```
Shape Explanation:

- The output shows (number of cells × 512)
- 512 is the embedding dimension of scGPT
- Each cell is represented by a 512-dimensional vector
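Because the patched embedding function divides each cell vector by its L2 norm, a quick sanity check is to confirm both the dimensionality and the unit norms. This check is my addition, not part of the original quickstart:

```{python}
# Optional sanity check: embeddings should be (n_cells, 512) and roughly
# unit-norm, since each row was divided by its L2 norm during generation.
import numpy as np

emb = ref_embed_adata.X
print(emb.shape)                 # expected: (n_cells, 512)
norms = np.linalg.norm(emb, axis=1)
print(norms.min(), norms.max())  # both expected to be ~1.0
```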
Dimensionality Reduction and Visualization
Key Concept: In this section, we’ll:

1. Compute cell-cell similarities using the embeddings
2. Generate a UMAP visualization
3. Compare different cell type annotations
Note: UMAP helps us visualize high-dimensional data in 2D while preserving important relationships.
```{python}
# Compute cell-cell similarities
sc.pp.neighbors(ref_embed_adata, use_rep="X")
# Generate UMAP coordinates
sc.tl.umap(ref_embed_adata)
```
Neighborhood Graph Note:

- We compute a neighborhood graph based on the embeddings
- This graph is used to generate the UMAP visualization
```{python}
# Transfer embeddings and UMAP to the original adata object
adata.obsm["X_scgpt"] = ref_embed_adata.X
adata.obsm["X_umap"] = ref_embed_adata.obsm["X_umap"]
```
Data Transfer Note:

- We transfer the embeddings and UMAP coordinates to our original dataset
- This allows us to use the original annotations with our new embeddings
- The embeddings are stored in `.obsm` for easy access
```{python}
# Add the current index ('ensembl_id') as a new column
adata.var['ensembl_id'] = adata.var.index

# Set the new index to the 'feature_name' column
adata.var.set_index('feature_name', inplace=True)

# Add a copy of the gene symbols back to the var dataframe
adata.var['gene_symbol'] = adata.var.index
```
Gene Name Processing Note:

- We reorganize the gene annotations for easier access
- This allows us to use gene symbols in visualizations
- Both Ensembl IDs and gene symbols are preserved
Visualization
Key Concept: In this section, we’ll:

1. Visualize the UMAP colored by different cell type annotations
2. Examine marker gene expression
3. Compare different annotation schemes
```{python}
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    sc.pl.umap(adata, color=["cell_type", "annotation_res0.34_new2"], wspace=0.6)
```
UMAP Visualization Note:

- The plot shows cells colored by two different annotation schemes
- This helps us compare different cell type classifications
- The spatial organization reflects cell type relationships
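Beyond the side-by-side UMAPs, a contingency table is a simple way to see how the two annotation schemes correspond. This is an optional extra, not part of the original quickstart; it reuses the same two `adata.obs` columns plotted above:

```{python}
# Optional: cross-tabulate the two annotation columns plotted above.
# Rows are cell_type labels; columns are the alternative annotation scheme.
import pandas as pd

ct = pd.crosstab(adata.obs["cell_type"], adata.obs["annotation_res0.34_new2"])
print(ct)
```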
```{python}
# Visualize marker genes for different cell types
sc.pl.umap(adata,
           color=['cell_type', 'MKI67', 'LYZ', 'RBP2', 'MUC2', 'CHGA', 'TAGLN', 'ELAVL3'],
           frameon=False,
           use_raw=False,
           legend_fontsize="xx-small",
           legend_loc="none")
```
Marker Gene Note: We visualize the expression of key marker genes. These genes help identify different cell types:

- MKI67: Proliferating cells
- LYZ: Myeloid cells
- RBP2: Enterocytes
- MUC2: Goblet cells
- CHGA: Enteroendocrine cells
- TAGLN: Smooth muscle cells
- ELAVL3: Neurons
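As a compact complement to the per-gene UMAP panels, the same markers can be summarized per cell type in a dotplot. This is an optional addition, not in the original tutorial; it reuses the gene list from the plot above:

```{python}
# Optional: summarize the marker genes per cell type in a single dotplot.
marker_genes = ['MKI67', 'LYZ', 'RBP2', 'MUC2', 'CHGA', 'TAGLN', 'ELAVL3']
sc.pl.dotplot(adata, marker_genes, groupby="cell_type", use_raw=False)
```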
References
Key Papers:

1. scGPT paper: foundation model for single-cell multi-omics
2. Dataset paper: iPSC-derived small intestine-on-chip
3. CZ CELLxGENE platform paper
Please refer to the following papers for more information:
scGPT: Toward building a foundation model for single-cell multi-omics using generative AI. Cui, H., Wang, C., Maan, H., et al. Nature Methods 21, 1470–1480 (2024). https://doi.org/10.1038/s41592-024-02201-0
The dataset used in this tutorial: Moerkens, R., et al. Cell Reports 43(7) (2024). https://doi.org/10.1016/j.celrep.2024.114247
CZ CELLxGENE Discover and Census: CZI Single-Cell Biology, et al. bioRxiv 2023.10.30 (2023). https://doi.org/10.1101/2023.10.30.563174