```{python}
first_time = False

if first_time:
    # First uninstall the conflicting packages
    %pip uninstall -y -q torch torchvision
    %pip uninstall -y numpy pandas scipy scikit-learn anndata cell-gears datasets dcor
    #%pip install -q torchvision==0.16.2 torch==2.1.2
    %pip install -q torch==2.3.0 torchvision==0.16.2
    %pip install -q scgpt scanpy gdown
    # Then install them in the correct order with specific versions
    %pip install numpy==1.23.5
    %pip install pandas==1.5.3   # This version is compatible with anndata 0.10.9
    %pip install scipy==1.10.1   # This version is >1.8 as required by anndata
    %pip install scikit-learn==1.2.2
    %pip install anndata==0.10.9
    %pip install cell-gears==0.0.2
    %pip install dcor==0.6
    %pip install datasets==2.3.0
    # First uninstall both packages
    %pip uninstall -y torch torchtext
    # Then install compatible versions
    %pip install torch==2.1.2 torchtext==0.16.2
```
scGPT quickstart - Jupyter notebook
https://virtualcellmodels.cziscience.com/quickstart/scgpt-quickstart
Quick Start: scGPT
This quick start will guide you through using the scGPT model, trained on 33 million cells (including data from the CZ CELLxGENE Census), to generate embeddings for single-cell transcriptomic data analysis.
Learning Goals
By the end of this tutorial, you will understand how to:
Access and prepare the scGPT model for use.
Generate embeddings to analyze and compare your dataset against the CZ CELLxGENE Census.
Visualize the results using a UMAP, colored by cell type.
Pre-requisites and Requirements
Before starting, ensure you are familiar with:
Python and AnnData
Single-cell data analysis (see this tutorial for a primer on the subject)
You can run this tutorial locally (tested on an M3 MacBook with 32 GiB memory) or in Google Colab using a T4 instance. Environment setup will be covered in a later section.
Overview
This notebook provides a step-by-step guide to:
Setting up your environment
Downloading the necessary model checkpoints and h5ad dataset
Performing model inference to create embeddings
Visualizing the results with UMAP
Setup
Let’s start by setting up dependencies. The released version of scGPT requires PyTorch 2.1.2, so we will remove the existing PyTorch installation and replace it with the required one. If you want to run this on another environment, this step might not be necessary.
**Note: with PyTorch 2.1.2 alone I got an error about the torch version, so I first installed torch 2.3.0 with torchvision 0.16.2 and then reinstalled torch 2.1.2.**
```{bash}
conda create -n scgpt python=3.9
conda activate scgpt
```
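Before continuing, it can be worth confirming that the pinned versions are the ones actually active in the environment. A minimal sanity check (my addition, not part of the original quickstart):

```{python}
# Confirm the pinned versions are active (expected values per the install
# cell above: torch 2.1.2, numpy 1.23.5, pandas 1.5.3, anndata 0.10.9).
import torch, numpy, pandas, anndata
for mod in (torch, numpy, pandas, anndata):
    print(mod.__name__, mod.__version__)
```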
We can install the rest of our dependencies and import the relevant libraries.
```{python}
# Import required packages
import os
import multiprocessing

# Set MPS fallback for unimplemented operations
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Monkey-patch os.sched_getaffinity for macOS
if not hasattr(os, 'sched_getaffinity'):
    def sched_getaffinity(pid):
        return set(range(multiprocessing.cpu_count()))
    os.sched_getaffinity = sched_getaffinity

import warnings
import urllib.request
from pathlib import Path

import scgpt as scg
import scanpy as sc
import numpy as np
import pandas as pd
import torch

# Check for MPS availability
device = (
    torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")
print("Note: Some operations may fall back to CPU due to MPS limitations")

warnings.filterwarnings("ignore")
```
/Users/haekyungim/miniconda3/envs/scgpt/lib/python3.9/site-packages/scgpt/model/model.py:21: UserWarning: flash_attn is not installed
warnings.warn("flash_attn is not installed")
/Users/haekyungim/miniconda3/envs/scgpt/lib/python3.9/site-packages/scgpt/model/multiomic_model.py:19: UserWarning: flash_attn is not installed
warnings.warn("flash_attn is not installed")
/Users/haekyungim/miniconda3/envs/scgpt/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Using device: mps
Note: Some operations may fall back to CPU due to MPS limitations
```{python}
# Define the base working directory
WORKDIR = "/Users/haekyungim/Library/CloudStorage/Box-Box/LargeFiles/imlab-data/data-Github/web-data/web-GENE-46100/scgpt"

# Convert to Path objects for better path handling
WORKDIR = Path(WORKDIR)
DATA_DIR = WORKDIR / "data"
MODEL_DIR = WORKDIR / "model"
```
Download Model Checkpoints and Data
Let’s download the checkpoints from the scGPT repository.
"ignore", ResourceWarning)
warnings.simplefilter("ignore", category=ImportWarning)
warnings.filterwarnings(
# Use gdown with the recursive flag to download the folder
# Replace the folder ID with the ID of your folder
= '1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y'
folder_id
# Check if model files already exist
if not (MODEL_DIR / "args.json").exists():
print("Downloading model checkpoint...")
!gdown --folder {folder_id} -O {MODEL_DIR}
else:
print("Model files already exist in", MODEL_DIR)
Model files already exist in /Users/haekyungim/Library/CloudStorage/Box-Box/LargeFiles/imlab-data/data-Github/web-data/web-GENE-46100/scgpt/model
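To confirm what gdown fetched, you can list the checkpoint folder; the scGPT_human release is expected to contain args.json, best_model.pt, and vocab.json (a quick inspection I added, not part of the original):

```{python}
# List the checkpoint files downloaded into MODEL_DIR.
for p in sorted(MODEL_DIR.iterdir()):
    print(p.name, f"({p.stat().st_size / 1e6:.1f} MB)")
```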
We will now download an H5AD dataset from CZ CELLxGENE. To reduce memory utilization, we will also subset the data to the top 3000 highly variable genes using scanpy's highly_variable_genes function.
= "https://datasets.cellxgene.cziscience.com/f50deffa-43ae-4f12-85ed-33e45040a1fa.h5ad"
uri = DATA_DIR / "source.h5ad"
source_path
# Check if file exists before downloading
if not source_path.exists():
print(f"Downloading dataset to {source_path}...")
=str(source_path))
urllib.request.urlretrieve(uri, filenameelse:
print(f"Dataset already exists at {source_path}")
# Read the data
= sc.read_h5ad(source_path)
adata
= "sample"
batch_key = 3000
N_HVG
=N_HVG, flavor='seurat_v3')
sc.pp.highly_variable_genes(adata, n_top_genes= adata[:, adata.var['highly_variable']] adata_hvg
Dataset already exists at /Users/haekyungim/Library/CloudStorage/Box-Box/LargeFiles/imlab-data/data-Github/web-data/web-GENE-46100/scgpt/data/source.h5ad
We can now use embed_data to generate the embeddings. Note that gene_col needs to point to the column where the gene names (not Ensembl IDs!) are stored. For CZ CELLxGENE datasets, they are in the feature_name column.
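A quick way to verify this before running inference is to inspect the AnnData object (a sketch I added; it simply confirms that feature_name is present in .var):

```{python}
# Confirm the gene-symbol column used by embed_data exists.
print(adata)
print(adata.var["feature_name"].head())
```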
```{python}
# Monkey patch get_batch_cell_embeddings to force single-process data loading
import types
from scgpt.tasks.cell_emb import get_batch_cell_embeddings as original_get_batch_cell_embeddings
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from scgpt.data_collator import DataCollator
import numpy as np
from tqdm import tqdm

# Define Dataset class at module level
class CellEmbeddingDataset(Dataset):
    def __init__(self, count_matrix, gene_ids, batch_ids=None, vocab=None, model_configs=None):
        self.count_matrix = count_matrix
        self.gene_ids = gene_ids
        self.batch_ids = batch_ids
        self.vocab = vocab
        self.model_configs = model_configs

    def __len__(self):
        return len(self.count_matrix)

    def __getitem__(self, idx):
        row = self.count_matrix[idx]
        nonzero_idx = np.nonzero(row)[0]
        values = row[nonzero_idx]
        genes = self.gene_ids[nonzero_idx]
        # append <cls> token at the beginning
        genes = np.insert(genes, 0, self.vocab["<cls>"])
        values = np.insert(values, 0, self.model_configs["pad_value"])
        genes = torch.from_numpy(genes).long()
        values = torch.from_numpy(values).float()
        output = {
            "id": idx,
            "genes": genes,
            "expressions": values,
        }
        if self.batch_ids is not None:
            output["batch_labels"] = self.batch_ids[idx]
        return output


def patched_get_batch_cell_embeddings(
    adata,
    cell_embedding_mode: str = "cls",
    model=None,
    vocab=None,
    max_length=1200,
    batch_size=64,
    model_configs=None,
    gene_ids=None,
    use_batch_labels=False,
) -> np.ndarray:
    """
    Patched version of get_batch_cell_embeddings that uses the module-level
    Dataset class and forces num_workers=0.
    """
    count_matrix = adata.X
    count_matrix = (
        count_matrix if isinstance(count_matrix, np.ndarray) else count_matrix.toarray()
    )

    # gene vocabulary ids
    if gene_ids is None:
        gene_ids = np.array(adata.var["id_in_vocab"])
        assert np.all(gene_ids >= 0)

    if use_batch_labels:
        batch_ids = np.array(adata.obs["batch_id"].tolist())

    if cell_embedding_mode == "cls":
        dataset = CellEmbeddingDataset(
            count_matrix,
            gene_ids,
            batch_ids if use_batch_labels else None,
            vocab=vocab,
            model_configs=model_configs,
        )
        collator = DataCollator(
            do_padding=True,
            pad_token_id=vocab[model_configs["pad_token"]],
            pad_value=model_configs["pad_value"],
            do_mlm=False,
            do_binning=True,
            max_length=max_length,
            sampling=True,
            keep_first_n_tokens=1,
        )
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=SequentialSampler(dataset),
            collate_fn=collator,
            drop_last=False,
            num_workers=0,  # Force single worker
            pin_memory=True,
        )

        cell_embeddings = np.zeros(
            (len(dataset), model_configs["embsize"]), dtype=np.float32
        )
        with torch.no_grad():
            # Disable autocast since it's not supported on MPS;
            # use the global device variable instead of getting it from model
            count = 0
            for data_dict in tqdm(data_loader, desc="Embedding cells"):
                input_gene_ids = data_dict["gene"].to(device)
                src_key_padding_mask = input_gene_ids.eq(
                    vocab[model_configs["pad_token"]]
                )
                embeddings = model._encode(
                    input_gene_ids,
                    data_dict["expr"].to(device),
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=data_dict["batch_labels"].to(device)
                    if use_batch_labels
                    else None,
                )
                embeddings = embeddings[:, 0, :]  # get the <cls> position embedding
                embeddings = embeddings.cpu().numpy()
                cell_embeddings[count : count + len(embeddings)] = embeddings
                count += len(embeddings)
        cell_embeddings = cell_embeddings / np.linalg.norm(
            cell_embeddings, axis=1, keepdims=True
        )
    else:
        raise ValueError(f"Unknown cell embedding mode: {cell_embedding_mode}")
    return cell_embeddings


# Replace the original function with our patched version
import scgpt.tasks.cell_emb
scgpt.tasks.cell_emb.get_batch_cell_embeddings = patched_get_batch_cell_embeddings

os.environ['PYTHONWARNINGS'] = 'ignore'
```
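This attribute-level patch is enough to redirect embed_data because, in the scGPT version used here, embed_data lives in the same scgpt.tasks.cell_emb module and looks up get_batch_cell_embeddings through that module's namespace at call time; if your installed version structures these functions differently, the patch may need adjusting.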
```{python}
model_dir = MODEL_DIR  # / "scGPT_human"
gene_col = "feature_name"
cell_type_key = "cell_type"

embedding_file = DATA_DIR / "ref_embed_adata.h5ad"

if embedding_file.exists():
    print(f"Loading existing embeddings from {embedding_file}")
    ref_embed_adata = sc.read_h5ad(str(embedding_file))
else:
    print("Computing new embeddings...")
    ref_embed_adata = scg.tasks.embed_data(
        adata_hvg,
        model_dir,
        gene_col=gene_col,
        obs_to_save=cell_type_key,
        batch_size=64,
        return_new_adata=True,
        device=device,  # Pass the device to embed_data
    )
    print(f"Saving embeddings to {embedding_file}")
    ref_embed_adata.write(str(embedding_file))
```
Loading existing embeddings from /Users/haekyungim/Library/CloudStorage/Box-Box/LargeFiles/imlab-data/data-Github/web-data/web-GENE-46100/scgpt/data/ref_embed_adata.h5ad
Our scGPT embeddings are stored in the .X attribute of the returned AnnData object and have a dimensionality of 512.
```{python}
ref_embed_adata.X.shape
```
(11103, 512)
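Because the patched embedding routine L2-normalizes each cell vector, dot products between rows of .X are cosine similarities. A quick check under that assumption (my addition):

```{python}
# Row norms should all be ~1.0 if the embeddings were L2-normalized.
norms = np.linalg.norm(ref_embed_adata.X, axis=1)
print(norms.min(), norms.max())
```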
We can now calculate neighbors based on scGPT embeddings.
="X")
sc.pp.neighbors(ref_embed_adata, use_rep sc.tl.umap(ref_embed_adata)
We will put our calculated UMAP and embeddings in our original adata object with our original annotations.
"X_scgpt"] = ref_embed_adata.X
adata.obsm["X_umap"] = ref_embed_adata.obsm["X_umap"] adata.obsm[
We can also switch our .var index, which is currently set to Ensembl IDs, to gene symbols, allowing us to plot gene expression more easily.
```{python}
# Add the current index ('ensembl_id') as a new column
adata.var['ensembl_id'] = adata.var.index

# Set the new index to the 'feature_name' column
adata.var.set_index('feature_name', inplace=True)

# Add a copy of the gene symbols back to the var dataframe
adata.var['gene_symbol'] = adata.var.index
```
We can now plot a UMAP, coloring it by cell type to visualize our embeddings. Below, we color by both the standard cell type labels provided by CZ CELLxGENE and the original cell type annotations from the authors. The embeddings generated by scGPT effectively capture the structure of the data, closely aligning with the original author annotations.
```{python}
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    #sc.pp.neighbors(ref_embed_adata, use_rep="X")
    #sc.tl.umap(ref_embed_adata)
    sc.pl.umap(adata, color=["cell_type", "annotation_res0.34_new2"], wspace=0.6)
```
We can also take a look at some markers of the major cell types represented in the dataset.
```{python}
sc.pl.umap(
    adata,
    color=['cell_type', 'MKI67', 'LYZ', 'RBP2', 'MUC2', 'CHGA', 'TAGLN', 'ELAVL3'],
    frameon=False,
    use_raw=False,
    legend_fontsize="xx-small",
    legend_loc="none",
)
```
References
Please refer to the following papers for information about:
scGPT: Toward building a foundation model for single-cell multi-omics using generative AI
Cui, H., Wang, C., Maan, H. et al. scGPT: toward building a foundation model for single-cell multi-omics using generative AI. Nat Methods 21, 1470–1480 (2024). https://doi.org/10.1038/s41592-024-02201-0
The dataset used in this tutorial
Moerkens, R., Mooiweer, J., Ramírez-Sánchez, A. D., Oelen, R., Franke, L., Wijmenga, C., Barrett, R. J., Jonkers, I. H., & Withoff, S. (2024). An iPSC-derived small intestine-on-chip with self-organizing epithelial, mesenchymal, and neural cells. Cell Reports, 43(7). https://doi.org/10.1016/j.celrep.2024.114247
CZ CELLxGENE Discover and Census
CZ CELLxGENE Discover: A single-cell data platform for scalable exploration, analysis and modeling of aggregated data CZI Single-Cell Biology, et al. bioRxiv 2023.10.30; doi: https://doi.org/10.1101/2023.10.30.563174