# Enformer Architecture

## Enformer Architecture Overview

### What is Enformer?
Enformer is a deep learning model designed for predicting gene expression and other functional genomics signals directly from DNA sequence. It was introduced by DeepMind and Calico, and is notable for its ability to model long-range interactions in DNA using transformer architectures.
### Key Components in the Code
#### 1. Enformer Class
The main model class (Enformer) inherits from snt.Module (Sonnet, a neural network library).
##### Inputs and Outputs
- Inputs: DNA sequence (as a one-hot encoded tensor)
- Outputs: Predictions for multiple genomic tracks (e.g., gene expression, chromatin accessibility) for human and mouse
##### Constructor Arguments
- channels: Model width (number of convolutional filters)
- num_transformer_layers: Number of transformer blocks
- num_heads: Number of attention heads in each transformer block
- pooling_type: Type of pooling ('attention' or 'max')
- name: Module name
##### Model Structure
- Stem: Initial convolution and pooling layers to process the input
- Convolutional Tower: Series of convolutional blocks with an increasing number of filters, interleaved with pooling
- Transformer Blocks: Stack of transformer layers to capture long-range dependencies
- Crop Layer: Crops the output to a fixed target length
- Final Pointwise Layer: Final convolution and activation
- Heads: Separate output layers for human and mouse predictions (a shape trace through these stages follows below)
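To make the data flow concrete, here is a rough shape trace through the trunk with the default hyperparameters (channels = 1536, batch size B). The intermediate shapes are inferred from the pooling and crop settings in the listing further below rather than quoted from the paper, so treat them as a sketch:

```{python}
# Approximate tensor shapes through the trunk (defaults: channels=1536, batch size B)
# input            (B, 196_608,    4)  one-hot encoded DNA
# stem             (B,  98_304,  768)  conv width 15 + residual 1x1 block, pool /2
# conv tower (x6)  (B,   1_536, 1536)  per block: conv width 5 + residual 1x1 block, pool /2
# transformer x11  (B,   1_536, 1536)  relative-position multi-head attention + MLP
# crop             (B,     896, 1536)  keep the central 896 bins of 128 bp each
# final pointwise  (B,     896, 3072)  1x1 conv to channels*2, dropout, GELU
# human head       (B,     896, 5313)  linear + softplus (mouse head: 1643 tracks)
```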
#### 2. Supporting Modules
- Sequential: Custom sequential container that passes is_training flag to layers that accept it
- Residual: Residual connection wrapper (adds input to output of a submodule)
- SoftmaxPooling1D: Custom pooling layer that uses softmax-weighted pooling (see the sketch after this list)
- TargetLengthCrop1D: Crops the sequence to a target length (centered)
- gelu: Gaussian Error Linear Unit activation function
- one_hot_encode: Utility to convert DNA sequence strings to one-hot encoded numpy arrays
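To illustrate the idea behind SoftmaxPooling1D, below is a minimal, self-contained numpy sketch of softmax-weighted pooling. It fixes the pooling logits to a scaled copy of the inputs, mirroring the identity initialization used in the listing below (per_channel=True, w_init_scale=2.0); the real module learns these logits with an snt.Linear layer:

```{python}
import numpy as np

def softmax_pool_1d(x, pool_size=2, w=2.0):
    """Simplified softmax-weighted pooling over non-overlapping windows.

    Each position in a window gets a per-channel logit w * x; the window is
    then softmax-weighted and summed. With w = 0 this is average pooling,
    and as w grows it approaches max pooling.
    """
    length, channels = x.shape
    x = x.reshape(length // pool_size, pool_size, channels)
    logits = w * x                                            # identity-initialized "linear layer"
    weights = np.exp(logits - logits.max(axis=1, keepdims=True))
    weights = weights / weights.sum(axis=1, keepdims=True)    # softmax over each window
    return (x * weights).sum(axis=1)

x = np.array([[1.0, 0.0],
              [3.0, 4.0],
              [2.0, 2.0],
              [0.0, 8.0]])
print(softmax_pool_1d(x, w=0.0))   # average pooling: [[2. 2.], [1. 5.]]
print(softmax_pool_1d(x, w=10.0))  # close to max pooling: ~[[3. 4.], [2. 8.]]
```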
#### 3. Forward Pass
The model processes the input sequence through:

1. The stem
2. Convolutional tower
3. Transformer stack
4. Cropping
5. Final pointwise layer
The resulting embedding is passed to each output head (for human and mouse), producing the final predictions.
### Why is Enformer Special?
- Long-range modeling: Uses transformers to capture dependencies across the full 196,608 bp input sequence
- Multi-task: Predicts thousands of genomic tracks at once (5,313 human and 1,643 mouse)
- Flexible pooling: Uses attention-based (softmax-weighted) pooling, sketched in the numpy example above, to better aggregate information across adjacent positions
### Example Usage
- Input: One-hot encoded DNA sequence of length 196,608 (with 4 channels for A, C, G, T)
- Output: Dictionary with predictions for human and mouse tracks (see the usage sketch below)
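A minimal usage sketch (marked eval: false, since it depends on the implementation cell further below and on the attention_module file from the official repository); the random input sequence is purely illustrative:

```{python}
#| eval: false
import numpy as np
import tensorflow as tf

# Build the model with the published default hyperparameters.
model = Enformer(channels=1536, num_heads=8, num_transformer_layers=11,
                 pooling_type='attention')

# One-hot encode a (here random) 196,608 bp sequence and add a batch dimension.
sequence = ''.join(np.random.choice(list('ACGT'), size=SEQUENCE_LENGTH))
batch = np.expand_dims(one_hot_encode(sequence), axis=0)  # shape (1, 196608, 4)

predictions = model(tf.constant(batch, dtype=tf.float32), is_training=False)
print(predictions['human'].shape)  # (1, 896, 5313)
print(predictions['mouse'].shape)  # (1, 896, 1643)
```

Alternatively, predict_on_batch wraps the same call in a tf.function with a fixed input signature, which is convenient for SavedModel export.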
### Summary Table
| Component | Purpose |
|---|---|
| Stem | Initial feature extraction |
| Conv Tower | Hierarchical feature extraction |
| Transformer | Long-range dependency modeling |
| Crop | Fix output length |
| Final Pointwise | Final feature transformation |
| Heads | Task-specific outputs (human/mouse) |
## Enformer Class Implementation
The following code shows the TensorFlow implementation of the Enformer model, downloaded as enformer.py from DeepMind's deepmind-research repository and annotated with Gemini.
```{python}
# Copyright 2021 DeepMind Technologies Limited
# Licensed under the Apache License, Version 2.0
"""Tensorflow implementation of Enformer model.
"Effective gene expression prediction from sequence by integrating long-range interactions"
Authors: Žiga Avsec1, Vikram Agarwal2,4, Daniel Visentin1,4, Joseph R. Ledsam1,3,
Agnieszka Grabska-Barwinska1, Kyle R. Taylor1, Yannis Assael1, John Jumper1,
Pushmeet Kohli1, David R. Kelley2*
1 DeepMind, London, UK
2 Calico Life Sciences, South San Francisco, CA, USA
3 Google, Tokyo, Japan
4 These authors contributed equally.
* correspondence: avsec@google.com, pushmeet@google.com, drk@calicolabs.com
"""
import inspect
from typing import Any, Callable, Dict, Optional, Text, Union, Iterable
import attention_module
import numpy as np
import sonnet as snt
import tensorflow as tf
# Model configuration constants
SEQUENCE_LENGTH = 196_608 # Input DNA sequence length
BIN_SIZE = 128 # Output bin size
TARGET_LENGTH = 896 # Target output length
class Enformer(snt.Module):
"""Main model for predicting genomic signals from DNA sequence."""
def __init__(self,
channels: int = 1536, # Model width/dimensionality
num_transformer_layers: int = 11, # Number of transformer blocks
num_heads: int = 8, # Number of attention heads
pooling_type: str = 'attention', # Pooling type ('attention' or 'max')
name: str = 'enformer'): # Module name
super().__init__(name=name)
# Output channels for different species
heads_channels = {'human': 5313, 'mouse': 1643}
dropout_rate = 0.4
# Validate model configuration
assert channels % num_heads == 0, ('channels must be divisible by num_heads')
# Multi-head attention configuration
whole_attention_kwargs = {
'attention_dropout_rate': 0.05,
'initializer': None,
'key_size': 64,
'num_heads': num_heads,
'num_relative_position_features': channels // num_heads,
'positional_dropout_rate': 0.01,
'relative_position_functions': [
'positional_features_exponential',
'positional_features_central_mask',
'positional_features_gamma'
],
'relative_positions': True,
'scaling': True,
'value_size': channels // num_heads,
'zero_initialize': True
}
# Build model trunk
with tf.name_scope('trunk'):
# Convolutional block helper
def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
return Sequential(lambda: [
snt.distribute.CrossReplicaBatchNorm(
create_scale=True,
create_offset=True,
scale_init=snt.initializers.Ones(),
moving_mean=snt.ExponentialMovingAverage(0.9),
moving_variance=snt.ExponentialMovingAverage(0.9)),
gelu,
snt.Conv1D(filters, width, w_init=w_init, **kwargs)
], name=name)
# Initial stem
stem = Sequential(lambda: [
snt.Conv1D(channels // 2, 15),
Residual(conv_block(channels // 2, 1, name='pointwise_conv_block')),
pooling_module(pooling_type, pool_size=2),
], name='stem')
# Convolutional tower
filter_list = exponential_linspace_int(
start=channels // 2,
end=channels,
num=6,
divisible_by=128
)
conv_tower = Sequential(lambda: [
Sequential(lambda: [
conv_block(num_filters, 5),
Residual(conv_block(num_filters, 1, name='pointwise_conv_block')),
pooling_module(pooling_type, pool_size=2),
], name=f'conv_tower_block_{i}')
for i, num_filters in enumerate(filter_list)
], name='conv_tower')
# Transformer blocks
def transformer_mlp():
return Sequential(lambda: [
snt.LayerNorm(axis=-1, create_scale=True, create_offset=True),
snt.Linear(channels * 2),
snt.Dropout(dropout_rate),
tf.nn.relu,
snt.Linear(channels),
snt.Dropout(dropout_rate)
], name='mlp')
transformer = Sequential(lambda: [
Sequential(lambda: [
Residual(Sequential(lambda: [
snt.LayerNorm(axis=-1, create_scale=True, create_offset=True,
scale_init=snt.initializers.Ones()),
attention_module.MultiheadAttention(**whole_attention_kwargs,
name=f'attention_{i}'),
snt.Dropout(dropout_rate)
], name='mha')),
Residual(transformer_mlp())
], name=f'transformer_block_{i}')
for i in range(num_transformer_layers)
], name='transformer')
# Final layers
crop_final = TargetLengthCrop1D(TARGET_LENGTH, name='target_input')
final_pointwise = Sequential(lambda: [
conv_block(channels * 2, 1),
snt.Dropout(dropout_rate / 8),
gelu
], name='final_pointwise')
# Combine trunk modules
self._trunk = Sequential([
stem,
conv_tower,
transformer,
crop_final,
final_pointwise
], name='trunk')
# Output heads
with tf.name_scope('heads'):
self._heads = {
head: Sequential(
lambda: [snt.Linear(num_channels), tf.nn.softplus],
name=f'head_{head}')
for head, num_channels in heads_channels.items()
}
@property
def trunk(self):
return self._trunk
@property
def heads(self):
return self._heads
def __call__(self, inputs: tf.Tensor, is_training: bool) -> Dict[str, tf.Tensor]:
"""Forward pass through the model."""
trunk_embedding = self.trunk(inputs, is_training=is_training)
return {
head: head_module(trunk_embedding, is_training=is_training)
for head, head_module in self.heads.items()
}
@tf.function(input_signature=[
tf.TensorSpec([None, SEQUENCE_LENGTH, 4], tf.float32)])
def predict_on_batch(self, x):
"""Prediction method for SavedModel."""
return self(x, is_training=False)
class TargetLengthCrop1D(snt.Module):
"""Crops sequence to target length."""
def __init__(self, target_length: Optional[int], name: str = 'target_length_crop'):
super().__init__(name=name)
self._target_length = target_length
def __call__(self, inputs):
if self._target_length is None:
return inputs
trim = (inputs.shape[-2] - self._target_length) // 2
if trim < 0:
raise ValueError('inputs longer than target length')
elif trim == 0:
return inputs
else:
return inputs[..., trim:-trim, :]
class Sequential(snt.Module):
"""Sequential container with is_training support."""
def __init__(self,
layers: Optional[Union[Callable[[], Iterable[snt.Module]],
Iterable[Callable[..., Any]]]] = None,
name: Optional[Text] = None):
super().__init__(name=name)
if layers is None:
self._layers = []
else:
if hasattr(layers, '__call__'):
layers = layers()
self._layers = [layer for layer in layers if layer is not None]
def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):
outputs = inputs
for mod in self._layers:
if accepts_is_training(mod):
outputs = mod(outputs, is_training=is_training, **kwargs)
else:
outputs = mod(outputs, **kwargs)
return outputs
def pooling_module(kind, pool_size):
"""Returns appropriate pooling module."""
if kind == 'attention':
return SoftmaxPooling1D(pool_size=pool_size, per_channel=True,
w_init_scale=2.0)
elif kind == 'max':
return tf.keras.layers.MaxPool1D(pool_size=pool_size, padding='same')
else:
raise ValueError(f'Invalid pooling kind: {kind}.')
class SoftmaxPooling1D(snt.Module):
"""Softmax-weighted pooling operation."""
def __init__(self,
pool_size: int = 2,
per_channel: bool = False,
w_init_scale: float = 0.0,
name: str = 'softmax_pooling'):
super().__init__(name=name)
self._pool_size = pool_size
self._per_channel = per_channel
self._w_init_scale = w_init_scale
self._logit_linear = None
@snt.once
def _initialize(self, num_features):
self._logit_linear = snt.Linear(
output_size=num_features if self._per_channel else 1,
with_bias=False,
w_init=snt.initializers.Identity(self._w_init_scale))
def __call__(self, inputs):
_, length, num_features = inputs.shape
self._initialize(num_features)
inputs = tf.reshape(
inputs,
(-1, length // self._pool_size, self._pool_size, num_features))
return tf.reduce_sum(
inputs * tf.nn.softmax(self._logit_linear(inputs), axis=-2),
axis=-2)
class Residual(snt.Module):
"""Residual connection wrapper."""
def __init__(self, module: snt.Module, name='residual'):
super().__init__(name=name)
self._module = module
def __call__(self, inputs: tf.Tensor, is_training: bool, *args,
**kwargs) -> tf.Tensor:
return inputs + self._module(inputs, is_training, *args, **kwargs)
def gelu(x: tf.Tensor) -> tf.Tensor:
"""Gaussian Error Linear Unit activation function."""
return tf.nn.sigmoid(1.702 * x) * x
def one_hot_encode(sequence: str,
alphabet: str = 'ACGT',
neutral_alphabet: str = 'N',
neutral_value: Any = 0,
dtype=np.float32) -> np.ndarray:
"""One-hot encodes DNA sequence."""
def to_uint8(string):
return np.frombuffer(string.encode('ascii'), dtype=np.uint8)
hash_table = np.zeros((np.iinfo(np.uint8).max, len(alphabet)), dtype=dtype)
hash_table[to_uint8(alphabet)] = np.eye(len(alphabet), dtype=dtype)
hash_table[to_uint8(neutral_alphabet)] = neutral_value
hash_table = hash_table.astype(dtype)
return hash_table[to_uint8(sequence)]
def exponential_linspace_int(start, end, num, divisible_by=1):
"""Generates exponentially spaced integers."""
def _round(x):
return int(np.round(x / divisible_by) * divisible_by)
base = np.exp(np.log(end / start) / (num - 1))
return [_round(start * base**i) for i in range(num)]
def accepts_is_training(module):
"""Checks if module accepts is_training parameter."""
return 'is_training' in list(inspect.signature(module.__call__).parameters)
```
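As a quick sanity check on the conv tower configuration, the filter progression that exponential_linspace_int yields for the defaults (start=768, end=1536, num=6, divisible_by=128) can be recomputed with plain numpy, so this cell runs even without the model dependencies:

```{python}
import numpy as np

# Recompute the conv tower filter sizes for the default Enformer configuration.
start, end, num, div = 768, 1536, 6, 128
base = np.exp(np.log(end / start) / (num - 1))
print([int(np.round(start * base**i / div) * div) for i in range(num)])
# -> [768, 896, 1024, 1152, 1280, 1536]
```

The tower therefore widens from 768 to 1,536 channels over its six blocks, with each step rounded to a multiple of 128.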