# %pip install dm-sonnet tqdm
Enformer training
This notebook requires read access to the Basenji2 training data, which it streams directly from GCS (see Steps below).
Copyright 2021 DeepMind Technologies Limited
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
This colab showcases training of the Enformer model published in
“Effective gene expression prediction from sequence by integrating long-range interactions”
Žiga Avsec, Vikram Agarwal, Daniel Visentin, Joseph R. Ledsam, Agnieszka Grabska-Barwinska, Kyle R. Taylor, Yannis Assael, John Jumper, Pushmeet Kohli, David R. Kelley
Steps
- Set up a `tf.data.Dataset` that reads the Basenji2 data directly from GCS:
`gs://basenji_barnyard/data`
- Train the model for a few steps, alternating training on human and mouse data batches
- Evaluate the model on human and mouse genomes
Setup
Start the colab kernel with GPU: Runtime -> Change runtime type -> GPU
Install dependencies
# Get enformer source code
# !wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/attention_module.py
# !wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/enformer.py
Import
import tensorflow as tf
# Make sure the GPU is enabled: fail fast if TensorFlow lists no GPU device.
assert tf.config.list_physical_devices('GPU'), 'Start the colab kernel with GPU: Runtime -> Change runtime type -> GPU'
# Easier debugging of OOM.
# IPython magic (not plain Python): sets TF_ENABLE_GPU_GARBAGE_COLLECTION=false
# in this kernel's environment. NOTE(review): TF presumably reads this when it
# sets up its GPU allocator — confirm it still takes effect when set after the
# device query above.
%env TF_ENABLE_GPU_GARBAGE_COLLECTION=false
import sonnet as snt
from tqdm import tqdm
from IPython.display import clear_output
import numpy as np
import pandas as pd
import time
import os
# Require a Sonnet 2.0.x release.
assert snt.__version__.startswith('2.0')
# Bare expression: discarded in a plain script; in a notebook it is displayed
# only when it is the last expression of the cell.
tf.__version__
# The `!nvidia-smi` shell escape doesn't work on macOS, so it is left disabled.
# !nvidia-smi
# Portable check: list the GPU devices TensorFlow can see.
print(tf.config.list_physical_devices('GPU'))
# There is no direct command-line equivalent to nvidia-smi for Apple Silicon. Use Activity Monitor for a graphical view, or check GPU availability in TensorFlow with Python code.
Code
import enformer
# @title `get_targets(organism)`
def get_targets(organism):
    """Download the Basenji2 (cross2020) target table for `organism`.

    Args:
      organism: Suffix of the targets file, e.g. 'human' or 'mouse'.

    Returns:
      A pandas DataFrame parsed from the tab-separated targets file.
    """
    url = ('https://raw.githubusercontent.com/calico/basenji/master/'
           f'manuscripts/cross2020/targets_{organism}.txt')
    return pd.read_csv(url, sep='\t')
# @title `get_dataset(organism, subset, num_threads=8)`
import glob
import json
import functools
def organism_path(organism):
    """Return the GCS directory that holds the Basenji2 data for `organism`."""
    data_root = 'gs://basenji_barnyard/data'
    return os.path.join(data_root, organism)
def get_dataset(organism, subset, num_threads=8):
    """Build a tf.data pipeline of deserialized examples for one data split.

    Args:
      organism: Organism directory name, e.g. 'human' or 'mouse'.
      subset: TFRecord shard prefix, e.g. 'train' or 'valid'.
      num_threads: Parallelism used both for reading records and for parsing.

    Returns:
      A `tf.data.Dataset` of dicts with 'sequence' and 'target' entries.
    """
    metadata = get_metadata(organism)
    parse_fn = functools.partial(deserialize, metadata=metadata)
    records = tf.data.TFRecordDataset(
        tfrecord_files(organism, subset),
        compression_type='ZLIB',
        num_parallel_reads=num_threads,
    )
    return records.map(parse_fn, num_parallel_calls=num_threads)
def get_metadata(organism):
    """Load `statistics.json` for `organism` as a dict.

    Keys: num_targets, train_seqs, valid_seqs, test_seqs, seq_length,
    pool_width, crop_bp, target_length.
    """
    stats_path = os.path.join(organism_path(organism), 'statistics.json')
    with tf.io.gfile.GFile(stats_path, 'r') as stats_file:
        metadata = json.load(stats_file)
    return metadata
def tfrecord_files(organism, subset):
    """List the `{subset}-*.tfr` shards of `organism`, in numeric shard order.

    Shards are ordered by the integer between the last '-' and the following
    '.', so e.g. 'train-2.tfr' sorts before 'train-10.tfr'.
    """
    pattern = os.path.join(organism_path(organism), 'tfrecords',
                           f'{subset}-*.tfr')

    def shard_index(path):
        # '.../train-12.tfr' -> 12
        return int(path.split('-')[-1].split('.')[0])

    return sorted(tf.io.gfile.glob(pattern), key=shard_index)
def deserialize(serialized_example, metadata):
    """Deserialize bytes stored in TFRecordFile.

    Args:
      serialized_example: Serialized `tf.train.Example` holding raw-bytes
        'sequence' (bool) and 'target' (float16) features.
      metadata: Dict providing 'seq_length', 'target_length' and 'num_targets'.

    Returns:
      Dict with float32 'sequence' of shape (seq_length, 4) and float32
      'target' of shape (target_length, num_targets).
    """
    feature_map = {
        'sequence': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_example(serialized_example, feature_map)

    def decode(key, raw_dtype, shape):
        # Raw bytes -> typed values -> fixed shape -> float32.
        values = tf.io.decode_raw(example[key], raw_dtype)
        return tf.cast(tf.reshape(values, shape), tf.float32)

    return {
        'sequence': decode('sequence', tf.bool, (metadata['seq_length'], 4)),
        'target': decode('target', tf.float16,
                         (metadata['target_length'], metadata['num_targets'])),
    }
Load dataset
# Target (track) metadata table for the human genome.
df_targets_human = get_targets('human')
df_targets_human.head()

# Endless training streams with batch size 1, zipped so that every element is
# a (human_batch, mouse_batch) pair.
human_dataset = get_dataset('human', 'train').batch(1).repeat()
mouse_dataset = get_dataset('mouse', 'train').batch(1).repeat()
human_mouse_dataset = tf.data.Dataset.zip((human_dataset, mouse_dataset)).prefetch(2)

# Example input: pull one (human, mouse) pair and show each feature's
# shape and dtype. (A separate mouse-only fetch whose result was immediately
# overwritten has been dropped; it only cost an extra GCS read.)
it = iter(human_mouse_dataset)
example = next(it)
for organism, organism_example in zip(['human', 'mouse'], example):
    print(organism)
    print({k: (v.shape, v.dtype) for k, v in organism_example.items()})
Model training
def create_step_function(model, optimizer):
    """Build a compiled single-step training function.

    Args:
      model: Module called as `model(sequence, is_training=True)` whose output
        is indexed by head name (e.g. an `enformer.Enformer`).
      optimizer: Sonnet-style optimizer exposing `apply(gradients, variables)`.

    Returns:
      `train_step(batch, head, optimizer_clip_norm_global=0.2)`. It runs one
      optimisation step on `batch` for output `head` and returns the mean
      Poisson loss. The gradients are clipped to a global L2 norm of
      `optimizer_clip_norm_global`; pass None to disable clipping.
    """

    @tf.function
    def train_step(batch, head, optimizer_clip_norm_global=0.2):
        with tf.GradientTape() as tape:
            outputs = model(batch['sequence'], is_training=True)[head]
            loss = tf.reduce_mean(
                tf.keras.losses.poisson(batch['target'], outputs))

        gradients = tape.gradient(loss, model.trainable_variables)
        # Honour the clipping argument (it was previously accepted but never
        # applied). Rescale all gradients jointly so their global norm is at
        # most the given value. None entries are passed through unchanged.
        if optimizer_clip_norm_global is not None:
            gradients, _ = tf.clip_by_global_norm(gradients,
                                                  optimizer_clip_norm_global)
        optimizer.apply(gradients, model.trainable_variables)

        return loss

    return train_step
# Warmup schedule hyperparameters, consumed by the training loop.
target_learning_rate = 0.0005
num_warmup_steps = 5000

# The learning rate lives in a non-trainable variable so it can be updated in
# place with `.assign` while the optimizer keeps a reference to it.
learning_rate = tf.Variable(0., trainable=False, name='learning_rate')
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

model = enformer.Enformer(
    channels=1536 // 4,  # Use 4x fewer channels to train faster.
    num_heads=8,
    num_transformer_layers=11,
    pooling_type='max',
)

train_step = create_step_function(model, optimizer)
# Train the model.
# Each iteration takes one (human, mouse) pair from the zipped stream and runs
# one optimisation step per organism head. The learning rate ramps linearly
# towards `target_learning_rate` over `num_warmup_steps` steps.
steps_per_epoch = 20
num_epochs = 5

data_it = iter(human_mouse_dataset)
global_step = 0
for epoch_i in range(num_epochs):
    for i in tqdm(range(steps_per_epoch)):
        global_step += 1
        batch_human, batch_mouse = next(data_it)

        # Step 1 keeps the variable's current value. From step 2 on,
        # lr = target * min(1, step / max(1, num_warmup_steps)).
        if global_step > 1:
            learning_rate_frac = tf.math.minimum(
                1.0,
                global_step / tf.math.maximum(1.0, num_warmup_steps),
            )
            learning_rate.assign(target_learning_rate * learning_rate_frac)

        loss_human = train_step(batch=batch_human, head='human')
        loss_mouse = train_step(batch=batch_mouse, head='mouse')

    # End of epoch: report the last step's losses and the current learning rate.
    print('')
    print('loss_human', loss_human.numpy(),
          'loss_mouse', loss_mouse.numpy(),
          'learning_rate', optimizer.learning_rate.numpy())
Evaluate
# @title `PearsonR` and `R2` metrics
def _reduced_shape(shape, axis):
    """Return `shape` with the dimensions listed in `axis` removed.

    Args:
      shape: A `tf.TensorShape` (or sequence of dims) of known rank.
      axis: None (reduce over everything, giving a scalar shape), a single int,
        or a sequence of ints. Negative axes count from the end, matching
        `tf.reduce_sum`.

    Returns:
      A `tf.TensorShape` containing the dimensions that are not reduced.
    """
    if axis is None:
        return tf.TensorShape([])
    dims = list(shape)
    rank = len(dims)
    # Accept a bare int as well as a sequence. Normalise negative axes so that,
    # e.g., axis=-1 drops the last dimension instead of silently dropping none.
    reduced = {a + rank if a < 0 else a for a in tf.nest.flatten(axis)}
    return tf.TensorShape([d for i, d in enumerate(dims) if i not in reduced])
class CorrelationStats(tf.keras.metrics.Metric):
    """Contains shared code for PearsonR and R2.

    Across `update_state` calls it accumulates the running sums needed for
    correlation-type statistics, each reduced over `reduce_axis`:
      - the element count,
      - sum(y_true * y_pred),
      - sum(y_true) and sum(y_true^2),
      - sum(y_pred) and sum(y_pred^2).

    Subclasses implement `result()` from these sums.
    """

    def __init__(self, reduce_axis=None, name='pearsonr'):
        """Set up the metric; accumulators are created lazily.

        Args:
          reduce_axis: Specifies over which axis to accumulate the statistics
            (say (0, 1)). If not specified, they are accumulated across the
            whole tensor.
          name: Metric name.
        """
        super(CorrelationStats, self).__init__(name=name)
        self._reduce_axis = reduce_axis
        self._shape = None  # Specified in _initialize.

    def _initialize(self, input_shape):
        """Create zero-initialised accumulators shaped like the reduced input."""
        # Remaining dimensions after reducing over self._reduce_axis.
        self._shape = _reduced_shape(input_shape, self._reduce_axis)

        weight_kwargs = dict(shape=self._shape, initializer='zeros')
        self._count = self.add_weight(name='count', **weight_kwargs)
        self._product_sum = self.add_weight(name='product_sum', **weight_kwargs)
        self._true_sum = self.add_weight(name='true_sum', **weight_kwargs)
        self._true_squared_sum = self.add_weight(name='true_squared_sum',
                                                 **weight_kwargs)
        self._pred_sum = self.add_weight(name='pred_sum', **weight_kwargs)
        self._pred_squared_sum = self.add_weight(name='pred_squared_sum',
                                                 **weight_kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Update the metric state.

        Args:
          y_true: Multi-dimensional float tensor [batch, ...] containing the
            ground truth values. Its shape fixes the accumulator shapes on the
            first call.
          y_pred: Float tensor with the same shape as y_true containing
            predicted values.
          sample_weight: Ignored. It is accepted only for compatibility with
            the Keras `Metric.update_state` signature, so every observation
            counts equally.

        Raises:
          ValueError: If the shapes of `y_true` and `y_pred` are incompatible.
        """
        if self._shape is None:
            # Explicit initialization check.
            self._initialize(y_true.shape)
        y_true.shape.assert_is_compatible_with(y_pred.shape)
        y_true = tf.cast(y_true, 'float32')
        y_pred = tf.cast(y_pred, 'float32')

        self._product_sum.assign_add(
            tf.reduce_sum(y_true * y_pred, axis=self._reduce_axis))

        self._true_sum.assign_add(
            tf.reduce_sum(y_true, axis=self._reduce_axis))

        self._true_squared_sum.assign_add(
            tf.reduce_sum(tf.math.square(y_true), axis=self._reduce_axis))

        self._pred_sum.assign_add(
            tf.reduce_sum(y_pred, axis=self._reduce_axis))

        self._pred_squared_sum.assign_add(
            tf.reduce_sum(tf.math.square(y_pred), axis=self._reduce_axis))

        self._count.assign_add(
            tf.reduce_sum(tf.ones_like(y_true), axis=self._reduce_axis))

    def result(self):
        raise NotImplementedError('Must be implemented in subclasses.')

    def reset_state(self):
        """Zero every accumulator; a no-op before the first `update_state`.

        Keras' public reset hook is `reset_state()`; `reset_states()` is its
        deprecated alias. Overriding this name ensures that callers of either
        get this shape-aware reset.
        """
        if self._shape is not None:
            for variable in self.variables:
                variable.assign(tf.zeros_like(variable))

    def reset_states(self):
        """Backward-compatible alias of `reset_state`."""
        self.reset_state()
class PearsonR(CorrelationStats):
    """Pearson correlation coefficient.

    Computed from the accumulated sums as

        sum((x - x_avg) * (y - y_avg))
        / (sqrt(sum((x - x_avg)^2)) * sqrt(sum((y - y_avg)^2)))
    """

    def __init__(self, reduce_axis=(0,), name='pearsonr'):
        """Pearson correlation coefficient.

        Args:
          reduce_axis: Specifies over which axis to compute the correlation.
          name: Metric name.
        """
        super().__init__(reduce_axis=reduce_axis, name=name)

    def result(self):
        n = self._count
        mean_true = self._true_sum / n
        mean_pred = self._pred_sum / n

        # sum((x - x_avg) * (y - y_avg)), expanded in terms of the raw sums.
        covariance = (self._product_sum
                      - mean_true * self._pred_sum
                      - mean_pred * self._true_sum
                      + n * mean_true * mean_pred)

        # Sums of squared deviations (unnormalised variances).
        var_true = self._true_squared_sum - n * tf.math.square(mean_true)
        var_pred = self._pred_squared_sum - n * tf.math.square(mean_pred)

        denominator = tf.math.sqrt(var_true) * tf.math.sqrt(var_pred)
        return covariance / denominator
class R2(CorrelationStats):
    """R-squared (fraction of explained variance): 1 - SS_res / SS_tot."""

    def __init__(self, reduce_axis=None, name='R2'):
        """R-squared metric.

        Args:
          reduce_axis: Specifies over which axis to compute the statistic.
          name: Metric name.
        """
        super().__init__(reduce_axis=reduce_axis, name=name)

    def result(self):
        mean_true = self._true_sum / self._count
        # SS_tot = sum((y - y_avg)^2), from the raw sums.
        total = self._true_squared_sum - self._count * tf.math.square(mean_true)
        # SS_res = sum((y_pred - y)^2), expanded in terms of the raw sums.
        residuals = (self._pred_squared_sum - 2 * self._product_sum
                     + self._true_squared_sum)
        unexplained = residuals / total
        return tf.ones_like(residuals) - unexplained
class MetricDict:
    """Forward `update_state` / `result` calls to a dict of named metrics."""

    def __init__(self, metrics):
        self._metrics = metrics

    def update_state(self, y_true, y_pred):
        """Feed the same (y_true, y_pred) batch to every wrapped metric."""
        for metric in self._metrics.values():
            metric.update_state(y_true, y_pred)

    def result(self):
        """Return {name: metric.result()} for every wrapped metric."""
        return {name: m.result() for name, m in self._metrics.items()}
def evaluate_model(model, dataset, head, max_steps=None):
    """Compute the Pearson correlation of `model`'s `head` predictions.

    Args:
      model: Module called as `model(sequence, is_training=False)` whose output
        is indexed by head name.
      dataset: Iterable of dicts with 'sequence' and 'target' entries.
      head: Output head to evaluate, e.g. 'human' or 'mouse'.
      max_steps: If not None, evaluate at most this many batches. Must be
        positive.

    Returns:
      Dict mapping 'PearsonR' to the correlation tensor, reduced over the
      first two axes of the targets.

    Raises:
      ValueError: If `max_steps` is not None and smaller than 1.
    """
    import itertools

    if max_steps is not None and max_steps < 1:
        raise ValueError(f'max_steps must be positive, got {max_steps}.')

    metric = MetricDict({'PearsonR': PearsonR(reduce_axis=(0, 1))})

    @tf.function
    def predict(x):
        return model(x, is_training=False)[head]

    # islice stops after exactly `max_steps` batches and fetches no extra one.
    # The old `i > max_steps` check ran max_steps + 1 batches.
    batches = dataset if max_steps is None else itertools.islice(dataset, max_steps)
    for batch in tqdm(batches, total=max_steps):
        metric.update_state(batch['target'], predict(batch['sequence']))

    return metric.result()
# Validation-set metrics for each head (capped via max_steps=100). Each
# metric tensor is averaged with `.mean()` for display.
metrics_human = evaluate_model(
    model,
    dataset=get_dataset('human', 'valid').batch(1).prefetch(2),
    head='human',
    max_steps=100,
)
print('')
print({name: value.numpy().mean() for name, value in metrics_human.items()})

metrics_mouse = evaluate_model(
    model,
    dataset=get_dataset('mouse', 'valid').batch(1).prefetch(2),
    head='mouse',
    max_steps=100,
)
print('')
print({name: value.numpy().mean() for name, value in metrics_mouse.items()})
Restore Checkpoint
Note: The TF-Hub Enformer model takes an input sequence of 393,216 bp, which it crops internally to the central 196,608 bp. The open-source Sonnet module does not crop internally. The code below therefore crops the central 196,608 bp of the longer sequence itself, so that the reloaded checkpoint reproduces the TF-Hub output.
42)
np.random.seed(= 393_216
EXTENDED_SEQ_LENGTH = 196_608
SEQ_LENGTH = np.array(np.random.random((1, EXTENDED_SEQ_LENGTH, 4)), dtype=np.float32)
inputs = enformer.TargetLengthCrop1D(SEQ_LENGTH)(inputs) inputs_cropped
= 'gs://dm-enformer/models/enformer/sonnet_weights/*'
checkpoint_gs_path = '/tmp/enformer_checkpoint' checkpoint_path
!mkdir /tmp/enformer_checkpoint
# Copy checkpoints from GCS to temporary directory.
# This will take a while as the checkpoint is ~ 1GB.
for file_path in tf.io.gfile.glob(checkpoint_gs_path):
print(file_path)
= os.path.basename(file_path)
file_name f'{checkpoint_path}/{file_name}', overwrite=True) tf.io.gfile.copy(file_path,
!ls -lh /tmp/enformer_checkpoint
= enformer.Enformer() enformer_model
= tf.train.Checkpoint(module=enformer_model) checkpoint
= tf.train.latest_checkpoint(checkpoint_path)
latest print(latest)
= checkpoint.restore(latest) status
# Using `is_training=False` to match TF-hub predict_on_batch function.
= enformer_model(inputs_cropped, is_training=False) restored_predictions
import tensorflow_hub as hub
= hub.load("https://tfhub.dev/deepmind/enformer/1").model enformer_tf_hub_model
= enformer_tf_hub_model.predict_on_batch(inputs) hub_predictions
'human'], restored_predictions['human'], atol=1e-5) np.allclose(hub_predictions[
# Can run with 'is_training=True' but note that this will
# change the predictions as the batch statistics will be updated
# and the outputs will likley not match the TF-hub model.
# enformer(inputs_cropped, is_training=True)