flowchart LR
A[Load Batch] --> B[Forward Pass] --> C[Compute Loss] --> D[Backpropagation] --> E[Update Weights]
E --> F{More Batches?}
F -->|Yes| A
F -->|No| G[Validate & Save Checkpoint]
G --> H{More Epochs?}
H -->|Yes| A
H -->|No| I[Done]
style A fill:#d6efe8,stroke:#1f6b4f
style G fill:#d8e4f0,stroke:#2a5278
style H fill:#e4dced,stroke:#5a3d75
style I fill:#f0e2d6,stroke:#994a2a
Training a DeepTaxa Model
This tutorial demonstrates how to train DeepTaxa from scratch on 16S rRNA gene sequences, inspect the resulting model checkpoints, and visualize learning progress. By the end, you will have a trained classifier capable of assigning taxonomic labels from Domain to Species.
Training a deep neural network involves repeatedly adjusting the model’s parameters to reduce classification errors on a labeled dataset. Each pass through the full training set is called an epoch. After each epoch, the model is evaluated on a held-out validation set to monitor generalization and detect overfitting (the point at which the model begins memorizing training examples rather than learning transferable patterns).
Prerequisites
- A CUDA-capable GPU (training on CPU is possible but roughly 10x slower)
- Approximately 1 GB of disk space for the training data
- About 2.5 hours for the full 10-epoch training run described here (measured on an NVIDIA A40; older GPUs take longer)
1 Setup
The following cells configure the Python environment and verify GPU availability.
The PATH modification and GPU verification below are specific to containerized environments (e.g., Vast.ai, Docker) where the system Python differs from the virtual environment. If you are running on a local machine with CUDA already configured, skip ahead to the installation step.
# Configure PATH so bash cells find the GPU-enabled Python
import os
os.environ['PATH'] = '/venv/main/bin:' + os.environ['PATH']
Verify GPU access.
import torch
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU')
True
1
NVIDIA A40
Install DeepTaxa and optional libraries for evaluation and plotting.
%%bash
mkdir -p ~/deeptaxa-workspace && cd ~/deeptaxa-workspace
test -d deeptaxa || git clone https://github.com/systems-genomics-lab/deeptaxa.git
cd deeptaxa && pip install -q .
pip install -q matplotlib scikit-learn
Confirm the installation.
%%bash
deeptaxa --version
DeepTaxa 0.1.0.dev1
2 Training Data
DeepTaxa trains on the Greengenes 2 reference database (release 2024.09), a curated collection of full-length 16S rRNA gene sequences with standardized seven-rank taxonomy annotations. The training subset contains approximately 277,000 sequences spanning over 16,000 species.
The class distribution is highly imbalanced: some species are represented by thousands of sequences, while others have fewer than five. This imbalance is inherent to reference databases and reflects the uneven sampling of microbial diversity in culture collections and environmental surveys.
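To see this imbalance directly, you can count sequences per species label in the taxonomy file. The sketch below assumes (not verified here) that each line holds a sequence ID, a tab, and a semicolon-separated Greengenes-style lineage starting with `d__` and ending with the species field; the helper names `species_counts` and `read_taxonomy` and the toy records are hypothetical.

```python
import gzip
from collections import Counter
from typing import Iterable


def species_counts(lines: Iterable[str]) -> Counter:
    """Count sequences per species label.

    Assumes each line looks like 'seq_id<TAB>d__...; p__...; ...; s__Species'.
    Lines without a tab, or whose lineage does not start with 'd__'
    (e.g. a header row), are skipped.
    """
    counts = Counter()
    for line in lines:
        line = line.rstrip("\n")
        if "\t" not in line:
            continue
        _, lineage = line.split("\t", 1)
        if not lineage.startswith("d__"):
            continue
        species = lineage.split(";")[-1].strip()  # last field = species rank
        counts[species] += 1
    return counts


def read_taxonomy(path: str) -> Counter:
    """Open a plain or gzipped taxonomy TSV and count species labels."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt") as fh:
        return species_counts(fh)


# Toy records (hypothetical IDs, not real Greengenes entries)
sample = [
    "Feature ID\tTaxon",
    "seq1\td__Bacteria; p__Pseudomonadota; s__Escherichia coli",
    "seq2\td__Bacteria; p__Pseudomonadota; s__Escherichia coli",
    "seq3\td__Bacteria; p__Pseudomonadota; s__Escherichia coli",
    "seq4\td__Bacteria; p__Bacillota; s__Bacillus subtilis",
]
counts = species_counts(sample)
print(counts.most_common())
```

If the format assumption holds, `read_taxonomy(os.path.expanduser('~/deeptaxa-workspace/data/greengenes/gg_2024_09_training.tsv.gz')).most_common(5)` would list the most abundant species labels after the download step below.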
Download the training sequences and taxonomy labels from Hugging Face.
%%bash
cd ~/deeptaxa-workspace
mkdir -p data/greengenes && cd data/greengenes
test -f gg_2024_09_training.fna.gz || curl -L -O https://huggingface.co/datasets/systems-genomics-lab/greengenes/resolve/main/gg_2024_09_training.fna.gz
test -f gg_2024_09_training.tsv.gz || curl -L -O https://huggingface.co/datasets/systems-genomics-lab/greengenes/resolve/main/gg_2024_09_training.tsv.gz
List the downloaded files.
%%bash
ls -lh ~/deeptaxa-workspace/data/greengenes/
total 124M
-rw-rw-r--. 1 ahmed ahmed 25M Apr 29 01:51 gg_2024_09_testing.fna.gz
-rw-rw-r--. 1 ahmed ahmed 809K Apr 29 01:51 gg_2024_09_testing.tsv.gz
-rw-rw-r--. 1 ahmed ahmed 97M Apr 29 02:08 gg_2024_09_training.fna.gz
-rw-rw-r--. 1 ahmed ahmed 2.7M Apr 29 01:56 gg_2024_09_training.tsv.gz
3 Training
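At each epoch, the training loop repeats the same sequence for every batch (load the batch, run the forward pass, compute the loss, backpropagate, and update the weights), then validates the model and saves a checkpoint before deciding whether to start another epoch, as shown in the flowchart at the top of this tutorial.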
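To make the sequence concrete, here is a toy NumPy version of the same loop. It is not DeepTaxa's code: a single softmax layer on synthetic 2-D data stands in for the network, gradients are derived by hand, the non-finite-loss check is a simple stand-in for batch skipping, and "saving a checkpoint" just means keeping a copy of the weights from the epoch with the lowest validation loss.

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic 3-class data: Gaussian blobs around three 2-D centers
centers = np.array([[0.0, 3.0], [3.0, 0.0], [-3.0, -3.0]])
labels = rng.integers(0, 3, size=600)
features = centers[labels] + rng.normal(size=(600, 2))
x_train, y_train = features[:480], labels[:480]
x_val, y_val = features[480:], labels[480:]


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def cross_entropy(probs, y):
    return -np.mean(np.log(probs[np.arange(len(y)), y] + 1e-12))


W, b = np.zeros((2, 3)), np.zeros(3)
lr, batch_size, n_epochs = 0.1, 32, 5
best = {"val_loss": np.inf, "epoch": None}
history = []

for epoch in range(1, n_epochs + 1):
    order = rng.permutation(len(x_train))              # reshuffle every epoch
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]          # load batch
        xb, yb = x_train[idx], y_train[idx]
        probs = softmax(xb @ W + b)                    # forward pass
        loss = cross_entropy(probs, yb)                # compute loss
        if not np.isfinite(loss):                      # skip a bad batch
            continue
        grad = probs.copy()                            # backpropagation:
        grad[np.arange(len(yb)), yb] -= 1.0            # dLoss/dlogits
        grad /= len(yb)
        W -= lr * (xb.T @ grad)                        # update weights
        b -= lr * grad.sum(axis=0)
    val_loss = cross_entropy(softmax(x_val @ W + b), y_val)  # validate
    history.append(val_loss)
    if val_loss < best["val_loss"]:                    # keep best "checkpoint"
        best = {"val_loss": val_loss, "epoch": epoch, "W": W.copy(), "b": b.copy()}
    print(f"epoch {epoch}: validation loss {val_loss:.4f}")
```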
The deeptaxa train command orchestrates this entire pipeline. Only --fasta-file, --taxonomy-file, and --output-dir are required; all other flags default to the canonical SMALL configuration used to train the published checkpoint.
| Flag | Purpose | Default |
|---|---|---|
| --fasta-file | Input sequences in FASTA format | (required) |
| --taxonomy-file | Tab-separated file mapping sequence IDs to taxonomy | (required) |
| --output-dir | Directory for checkpoints, metrics, and logs | (required) |
| --model-type | Architecture to use | hybridcnnbert |
| --epochs | Number of passes through the training set | 10 |
| --batch-size | Number of sequences per gradient update | 64 |
| --learning-rate | Learning rate | 5e-4 |
| --seed | Random seed for reproducibility | 42 |
| --hidden-size | Transformer hidden dimension | 896 |
| --num-hidden-layers | Number of Transformer layers | 4 |
| --num-attention-heads | Number of attention heads | 7 |
| --intermediate-size | Transformer feed-forward dimension | 3584 |
| --embed-dim | CNN embedding dimension | 896 |
| --num-filters | CNN filters per kernel size | 256 |
| --kernel-sizes | CNN kernel widths | 3 5 7 |
| --num-conv-layers | Number of CNN layers | 1 |
| --hidden-dropout-prob | Dropout probability | 0.20 |
The architecture flags and seed above match the configuration used to produce the published DeepTaxa checkpoint. For details on what each architecture parameter controls, see the architecture tutorial.
Training time depends on GPU model, batch size, and data loading speed. On an NVIDIA A40, each epoch takes approximately 14 minutes (total: roughly 2.5 hours for 10 epochs). On older hardware or with smaller batch sizes, expect proportionally longer runs.
During training, DeepTaxa logs per-batch progress, loss values, and validation metrics to the console. NaN gradient warnings (if they appear) are handled automatically by skipping the affected batch.
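In mixed-precision training, one common mechanism for skipping such batches is dynamic loss scaling, as performed by PyTorch's GradScaler (whose state DeepTaxa checkpoints, see Section 5). The pure-Python sketch below models only the scale-update policy, with settings assumed to match GradScaler's defaults (initial scale 2^16, growth factor 2.0, backoff factor 0.5, growth interval 2000); it does not actually scale losses or gradients. The class name `DynamicLossScaler` is hypothetical, and whether DeepTaxa's NaN handling goes through this mechanism is an assumption.

```python
class DynamicLossScaler:
    """Toy model of a dynamic loss-scale policy (GradScaler-style defaults assumed)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, grads_finite: bool) -> bool:
        """Return True if the optimizer step should run for this batch."""
        if not grads_finite:
            self.scale *= self.backoff_factor   # inf/NaN: shrink scale, skip batch
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= self.growth_factor    # stable long enough: grow scale
            self._good_steps = 0
        return True


scaler = DynamicLossScaler(growth_interval=3)   # short interval so growth is visible
taken = []
for finite in [True, False, True, True, True]:   # False = inf/NaN gradients
    taken.append(scaler.update(finite))
    print(f"finite={finite!s:<5} step taken={taken[-1]!s:<5} scale={scaler.scale:,.0f}")
```

The design choice being illustrated: a skipped batch costs one step, but shrinking the scale makes the next FP16 backward pass less likely to overflow, and the scale is grown again only after a run of clean steps.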
Launch the training run.
%%bash
cd ~/deeptaxa-workspace
deeptaxa train \
  --fasta-file data/greengenes/gg_2024_09_training.fna.gz \
  --taxonomy-file data/greengenes/gg_2024_09_training.tsv.gz \
  --model-type hybridcnnbert \
  --output-dir outputs/training
4 Understanding Training Metrics
4.1 The loss function
DeepTaxa minimizes a weighted sum of cross-entropy losses, one per taxonomic rank:
\[\mathcal{L} = \sum_{r=1}^{R} w_r \cdot \text{CE}_r\]
where \(R = 7\) is the number of taxonomic ranks (Domain through Species) and \(w_r\) is the weight assigned to rank \(r\) (1.0 by default, giving equal importance to each level). Separate per-rank losses allow the model to train all seven classifiers simultaneously from a single forward pass.
For each rank \(r\), let \(\hat{y}_{i,c}^{(r)}\) denote the predicted probability for class \(c\) at rank \(r\) (the softmax output), and let \(y_i^{(r)}\) be the index of the true class of sequence \(i\) at rank \(r\). The cross-entropy at rank \(r\) is:
\[\text{CE}_r = -\frac{1}{N} \sum_{i=1}^{N} \log \hat{y}_{i,\, y_i^{(r)}}^{(r)}\]
where \(N\) is the number of sequences in the batch. The subscript \(y_i^{(r)}\) selects the predicted probability assigned to the correct class: the loss is large when the model assigns low probability to the correct answer and approaches zero as that probability approaches 1.0.
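As a small numerical check of \(\text{CE}_r\) and \(\mathcal{L}\), the sketch below computes the per-rank cross-entropy for a toy batch of two sequences at two ranks and combines them with equal weights \(w_r = 1.0\). The probabilities and the helper names are made up for illustration; DeepTaxa would produce one probability matrix per rank, seven in total.

```python
import numpy as np


def rank_cross_entropy(probs: np.ndarray, true_idx: np.ndarray) -> float:
    """CE_r = -(1/N) * sum_i log(probability of the true class of sequence i)."""
    n = len(true_idx)
    return float(-np.mean(np.log(probs[np.arange(n), true_idx])))


def hierarchical_loss(probs_per_rank, true_per_rank, weights) -> float:
    """L = sum over ranks r of w_r * CE_r."""
    return sum(w * rank_cross_entropy(p, y)
               for p, y, w in zip(probs_per_rank, true_per_rank, weights))


# Toy batch of N = 2 sequences at two ranks (hypothetical numbers)
domain_probs = np.array([[0.9, 0.1],
                         [0.2, 0.8]])          # 2 classes
genus_probs = np.array([[0.5, 0.3, 0.2],
                        [0.1, 0.1, 0.8]])      # 3 classes
domain_true = np.array([0, 1])
genus_true = np.array([0, 2])

ce_domain = rank_cross_entropy(domain_probs, domain_true)  # -(ln 0.9 + ln 0.8) / 2
ce_genus = rank_cross_entropy(genus_probs, genus_true)     # -(ln 0.5 + ln 0.8) / 2
total = hierarchical_loss([domain_probs, genus_probs],
                          [domain_true, genus_true],
                          weights=[1.0, 1.0])
print(f"CE_domain={ce_domain:.4f}  CE_genus={ce_genus:.4f}  L={total:.4f}")
```

The less confident genus predictions contribute the larger share of \(\mathcal{L}\), which is the behavior the formula implies.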
DeepTaxa also supports focal loss (--loss-type focal), which down-weights well-classified examples and concentrates gradient updates on harder cases. This is particularly useful when the training data are highly imbalanced across classes. See the architecture tutorial for details.
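Focal loss multiplies each example's cross-entropy term by \((1 - p_t)^\gamma\), where \(p_t\) is the probability assigned to the true class, so confident predictions contribute little. The sketch below uses the standard form without the optional per-class \(\alpha\) weight; DeepTaxa's implementation may differ in details. \(\gamma = 2.0\) matches the focal-gamma value reported by deeptaxa describe in Section 5.

```python
import numpy as np


def focal_loss(probs: np.ndarray, true_idx: np.ndarray, gamma: float = 2.0) -> float:
    """Mean focal loss: -(1/N) * sum_i (1 - p_t)^gamma * log(p_t)."""
    p_t = probs[np.arange(len(true_idx)), true_idx]
    return float(np.mean(-((1.0 - p_t) ** gamma) * np.log(p_t)))


# One easy example (p_t = 0.95) and one hard example (p_t = 0.30)
probs = np.array([[0.95, 0.05],
                  [0.30, 0.70]])
true_idx = np.array([0, 0])

p_t = probs[np.arange(2), true_idx]
ce_terms = -np.log(p_t)
focal_terms = (1.0 - p_t) ** 2.0 * ce_terms
print("cross-entropy terms:", ce_terms.round(4))
print("focal terms (gamma=2):", focal_terms.round(4))
print("fraction of CE kept:", (focal_terms / ce_terms).round(4))
```

With \(\gamma = 2\) the easy example keeps only 0.25% of its cross-entropy while the hard one keeps 49%; setting \(\gamma = 0\) recovers plain cross-entropy.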
4.2 Learning rate schedule
Training uses a linear warmup-and-decay schedule: the learning rate increases linearly from zero to the base rate over the first 10% of steps (warmup), then decreases linearly back to zero over the remaining 90% (decay). Warmup prevents early instability: the classification heads start from random weights and initially produce large, noisy gradients that could destabilize training if the full learning rate were applied immediately.
# --- Learning rate schedule ---
import numpy as np
import matplotlib.pyplot as plt
n_epochs = 10
batch_size = 64
lr_base = 5e-4
warmup_ratio = 0.1
n_train = 221_868  # training-set size reported by deeptaxa describe (80/20 train/validation split)
steps_per_epoch = int(np.ceil(n_train / batch_size))
total_steps = n_epochs * steps_per_epoch
warmup_steps = int(total_steps * warmup_ratio)
steps = np.arange(total_steps + 1)
lr = np.where(
    steps <= warmup_steps,
    lr_base * (steps / warmup_steps),                                        # linear warmup
    lr_base * (1 - (steps - warmup_steps) / (total_steps - warmup_steps)),  # linear decay
)
plt.figure(figsize=(8, 4))
plt.plot(steps, lr, color='#2a5278', linewidth=1.5)
for e in range(1, n_epochs + 1):
    # Dotted line at each epoch boundary; only the first gets a legend entry
    plt.axvline(steps_per_epoch * e, color='#994a2a', linestyle=':', alpha=0.35,
                linewidth=0.8, label='Epoch boundary' if e == 1 else '_nolegend_')
plt.axvline(warmup_steps, color='gray', linestyle='--', alpha=0.8,
label=f'Warmup end (step {warmup_steps:,})')
plt.xlabel('Training step')
plt.ylabel('Learning rate')
plt.legend()
plt.tight_layout()
plt.show()
4.3 Interpreting learning curves
Plotting training loss and validation loss over epochs reveals the model’s learning trajectory:
- Both decreasing: the model is learning and generalizing to unseen data.
- Training loss decreasing, validation loss increasing: the model is overfitting to the training data.
- Both plateauing: the model has converged. Additional epochs are unlikely to improve performance.
Load the per-epoch metrics from the JSON files saved during training.
# --- Load training metrics ---
import os
os.chdir(os.path.expanduser('~/deeptaxa-workspace'))
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
files = sorted(glob.glob('outputs/training/metrics/*epoch*.json'))
epochs, train_loss, val_loss = [], [], []
f1_data = {r: [] for r in range(7)}
for path in files:
    with open(path) as fh:
        d = json.load(fh)
    epochs.append(d['model_details']['current_epoch'])
    train_loss.append(d['performance_metrics']['training_loss'])
    val_loss.append(d['performance_metrics']['validation_loss'])
    for r in range(7):
        f1_data[r].append(d['performance_metrics']['validation_metrics'][str(r)]['f1_score'])
# Sort by epoch number
order = np.argsort(epochs)
epochs = [epochs[i] for i in order]
train_loss = [train_loss[i] for i in order]
val_loss = [val_loss[i] for i in order]
for r in range(7):
    f1_data[r] = [f1_data[r][i] for i in order]
print('Loaded metrics for', len(epochs), 'epochs')
Loaded metrics for 10 epochs
4.4 Training and validation loss
The logarithmic y-axis makes the rate of improvement easier to see. Training loss decreases steadily as the optimizer adjusts the model weights, and validation loss should track it closely for as long as the model is generalizing.
Plot the loss curves.
# --- Loss curves ---
plt.figure()
plt.plot(epochs, train_loss, 'o-', label='Training')
plt.plot(epochs, val_loss, 's--', label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.legend()
plt.tight_layout()
plt.show()
The training loss drops rapidly in the first few epochs and continues to decrease throughout. The validation loss follows a similar trajectory but flattens earlier as the model begins to overfit. The widening gap between the two curves at later epochs is the classic signature of overfitting, and it is precisely this signal that helps identify the best checkpoint (the epoch with the lowest validation loss).
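Picking the epoch with the lowest validation loss takes only a few lines. The helpers `best_epoch` and `generalization_gap` below are hypothetical names, not part of DeepTaxa; they run here on made-up loss values so the block is self-contained, and in the notebook you would pass the `epochs`, `train_loss`, and `val_loss` lists loaded in Section 4.3.

```python
from typing import List, Sequence


def best_epoch(epochs: Sequence[int], val_loss: Sequence[float]) -> int:
    """Return the epoch with the lowest validation loss (earliest on ties)."""
    best_idx = min(range(len(val_loss)), key=lambda i: val_loss[i])
    return epochs[best_idx]


def generalization_gap(train_loss: Sequence[float],
                       val_loss: Sequence[float]) -> List[float]:
    """Validation minus training loss per epoch; a growing gap suggests overfitting."""
    return [v - t for t, v in zip(train_loss, val_loss)]


# Illustrative values only (not from the training run above)
demo_epochs = [1, 2, 3, 4, 5]
demo_train = [2.0, 1.0, 0.6, 0.4, 0.3]
demo_val = [2.1, 1.2, 0.9, 0.95, 1.0]

gaps = generalization_gap(demo_train, demo_val)
print("best epoch:", best_epoch(demo_epochs, demo_val))
print("gap per epoch:", [round(g, 2) for g in gaps])
```

In this toy run, validation loss bottoms out at epoch 3 while the gap keeps widening, which is the pattern described above.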
4.5 Validation F1-score by taxonomic rank
The F1-score balances precision (fraction of predicted positives that are correct) and recall (fraction of true positives that are detected):
\[F_1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}\]
\(F_1\) is the harmonic mean of precision and recall rather than their arithmetic mean, so a low value in either one strongly depresses the score. A classifier that achieves 90% precision but only 10% recall gets an \(F_1\) of 0.18, not 0.50.
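The 90%/10% example is easy to check directly. The sketch below compares the harmonic mean (\(F_1\)) with the arithmetic mean for a few precision/recall pairs; the helper name `f1` is illustrative.

```python
def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (defined as 0 when both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


for p, r in [(0.9, 0.1), (0.5, 0.5), (0.99, 0.9)]:
    print(f"P={p:.2f} R={r:.2f}  F1={f1(p, r):.3f}  arithmetic mean={(p + r) / 2:.3f}")
```

When precision and recall are close, the two means nearly agree; the harmonic mean only diverges sharply when one of them is poor.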
Higher taxonomic ranks (Domain, Phylum) converge within the first epoch because the model needs to learn only a small number of coarse-grained distinctions. Lower ranks (Genus, Species) require more epochs to reach their peak performance because the number of classes is orders of magnitude larger and the distinguishing sequence features are more subtle. For a comparison of macro versus weighted F1 and its implications for class-imbalanced data, see the analysis tutorial.
Plot the F1 trajectory for each rank.
# --- F1 by rank ---
RANK_LABELS = ['Domain', 'Phylum', 'Class', 'Order', 'Family', 'Genus', 'Species']
plt.figure()
for r in range(7):
    plt.plot(epochs, f1_data[r], 'o-', label=RANK_LABELS[r])
plt.xlabel('Epoch')
plt.ylabel('Validation F1')
plt.ylim(0.5, 1.02)
plt.legend(ncol=2)
plt.tight_layout()
plt.show()
The plot shows rapid convergence at Domain through Order (above 0.98 by epoch 2) and a more gradual climb at Genus and Species. The Species F1 curve is the one to watch: it reflects the model’s ability to resolve the finest-grained taxonomic distinctions, which is the most challenging and practically relevant task.
5 Checkpoint Inspection
After each evaluation epoch, DeepTaxa saves a checkpoint file (.pt) containing everything needed to resume training or run inference:
- Model weights: all learnable parameters (convolutional filters, attention matrices, classification heads), required for inference and for resuming training.
- Optimizer state: the AdamW (Loshchilov & Hutter, 2019) moment buffers, which hold exponential moving averages of each parameter's recent gradients (first moment) and squared gradients (second moment); without these, resumption restarts the moment estimates from zero, slowing convergence.
- Scheduler state: the current position in the learning rate warmup and decay schedule; without it, the learning rate resets to its initial value rather than continuing from where training left off.
- GradScaler state: the mixed-precision loss scaling factor, which preserves numerical stability for FP16 gradients across resumed runs.
- Label encoders: the mapping between taxonomy strings (e.g., “Pseudomonadota”) and integer class indices, required to decode model output back to human-readable labels at inference time.
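To see why the optimizer state matters, the sketch below applies AdamW updates to a single scalar parameter in NumPy, following the update rule given in the PyTorch documentation for torch.optim.AdamW (decoupled weight decay, no amsgrad) and using the betas, eps, and weight_decay values that deeptaxa describe reports below. The gradient values and the helper names `adamw_step` and `fresh_state` are made up for illustration. Given the same new gradient, the step taken with the saved moment buffers differs from the step taken after resetting them, even in sign.

```python
import numpy as np


def adamw_step(param, grad, state, lr=5e-4, betas=(0.9, 0.999), eps=1e-8,
               weight_decay=0.01):
    """One AdamW update; `state` holds the moment buffers m, v and step count t."""
    beta1, beta2 = betas
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad        # first moment
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2   # second moment
    m_hat = state["m"] / (1 - beta1 ** state["t"])              # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    param = param * (1 - lr * weight_decay)                     # decoupled weight decay
    return param - lr * m_hat / (np.sqrt(v_hat) + eps)


def fresh_state():
    return {"m": 0.0, "v": 0.0, "t": 0}


grads = [0.5, 0.4, 0.6, 0.5]      # made-up gradient history
state = fresh_state()
w = 1.0
for g in grads:
    w = adamw_step(w, g, state)

# Next step with a new gradient: resumed (saved buffers) vs. reset (empty buffers)
g_next = -0.5
resumed = adamw_step(w, g_next, dict(state))   # copy so `state` stays unchanged
reset = adamw_step(w, g_next, fresh_state())
print(f"resumed update: {resumed - w:+.2e}")
print(f"reset update:   {reset - w:+.2e}")
```

With the saved buffers, accumulated momentum still carries the parameter downward; with reset buffers, the single new gradient dominates and the update reverses direction at full step size.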
Inspect the epoch 10 checkpoint using the deeptaxa describe command, which summarizes the checkpoint contents without loading the full model onto a GPU.
%%bash
cd ~/deeptaxa-workspace
DEEPTAXA_UUID=$(cat outputs/training/deeptaxa_uuid.txt)
echo "Run UUID: ${DEEPTAXA_UUID}"
deeptaxa describe --checkpoint outputs/training/checkpoints/deeptaxa_${DEEPTAXA_UUID}_epoch10.pt
Run UUID: 2026_04_29T02_08_53_a9efd989_bf0d_4ea8_ae66_88a027005ab3
2026-04-29 04:39:31,634 - INFO -
======================================================================
DeepTaxa Model Description (v0.1.0.dev1)
--------------------------------------------------
Checkpoint: outputs/training/checkpoints/deeptaxa_2026_04_29T02_08_53_a9efd989_bf0d_4ea8_ae66_88a027005ab3_epoch10.pt
Timestamp: 2026-04-29T04:39:31.634207
======================================================================
Model Details:
--------------------------------------------------
run-uuid: 2026_04_29T02_08_53_a9efd989_bf0d_4ea8_ae66_88a027005ab3
model-type: hybridcnnbert
tokenizer: zhihan1996/DNABERT-2-117M
epoch: 10
total-parameters: 76,365,205
max-length: 512
embed-dim: 896
num-filters: 256
kernel-sizes: [3, 5, 7]
num-conv-layers: 1
hidden-size: 896
num-hidden-layers: 4
num-attention-heads: 7
intermediate-size: 3,584
output-attentions: 0
hidden-dropout-prob: 0.2
Training Hyperparameters:
--------------------------------------------------
learning-rate: 0.0005
batch-size: 64
target-epochs: 10
focal-gamma: 2.0
level-weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
optimizer: {'lr': 0.0005, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0.01}
scheduler-steps: 34,670
Dataset Info:
--------------------------------------------------
total-sequences: 277,336
training: 221,868
validation: 55,468
fasta-file: data/greengenes/gg_2024_09_training.fna.gz
taxonomy-file: data/greengenes/gg_2024_09_training.tsv.gz
Taxonomic Levels:
--------------------------------------------------
Level 0 - domain: 2 labels
Level 1 - phylum: 129 labels
Level 2 - class: 349 labels
Level 3 - order: 997 labels
Level 4 - family: 2250 labels
Level 5 - genus: 7287 labels
Level 6 - species: 16909 labels
Timing:
--------------------------------------------------
training-time: 842.550651 seconds
evaluation-time: 52.981471 seconds
System Info:
--------------------------------------------------
cuda: Available
gpu: NVIDIA A40
The output shows the model architecture, parameter counts, training epoch, and the taxonomic ranks with their class counts. This information is useful for verifying that the checkpoint matches the expected configuration before using it for prediction.
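The dataset and scheduler numbers in this output are mutually consistent, which makes a quick sanity check on any checkpoint. The arithmetic below uses only values printed above, plus the 10% warmup ratio from Section 4.2. It shows the train/validation split is roughly 80/20, and that ceil(221,868 / 64) steps per epoch over 10 epochs reproduces the reported scheduler-steps exactly. That match assumes the final partial batch of each epoch is kept rather than dropped.

```python
import math

# Values copied from the deeptaxa describe output above
total_sequences = 277_336
n_train, n_val = 221_868, 55_468
batch_size, target_epochs = 64, 10
scheduler_steps = 34_670
warmup_ratio = 0.1  # from the schedule description in Section 4.2

steps_per_epoch = math.ceil(n_train / batch_size)  # partial last batch counted
total_steps = steps_per_epoch * target_epochs
warmup_steps = int(total_steps * warmup_ratio)

print(f"train + val = {n_train + n_val:,} (total {total_sequences:,})")
print(f"validation fraction = {n_val / total_sequences:.3f}")
print(f"steps/epoch = {steps_per_epoch:,}, total = {total_steps:,}, warmup = {warmup_steps:,}")
```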
6 Prediction on the Test Set
Download the test data and run prediction using the epoch 10 checkpoint.
%%bash
cd ~/deeptaxa-workspace
DEEPTAXA_UUID=$(cat outputs/training/deeptaxa_uuid.txt)
CHECKPOINT=outputs/training/checkpoints/deeptaxa_${DEEPTAXA_UUID}_epoch10.pt
test -f data/greengenes/gg_2024_09_testing.fna.gz || curl -L -o data/greengenes/gg_2024_09_testing.fna.gz https://huggingface.co/datasets/systems-genomics-lab/greengenes/resolve/main/gg_2024_09_testing.fna.gz
test -f data/greengenes/gg_2024_09_testing.tsv.gz || curl -L -o data/greengenes/gg_2024_09_testing.tsv.gz https://huggingface.co/datasets/systems-genomics-lab/greengenes/resolve/main/gg_2024_09_testing.tsv.gz
deeptaxa predict \
  --fasta-file data/greengenes/gg_2024_09_testing.fna.gz \
  --taxonomy-file data/greengenes/gg_2024_09_testing.tsv.gz \
  --checkpoint "$CHECKPOINT" \
  --tabular \
  --output-dir outputs/workflow_predictions
Print the per-rank accuracy on the test set to see how the trained model performs.
# --- Test set accuracy ---
import os
os.chdir(os.path.expanduser('~/deeptaxa-workspace'))
import glob
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score
RANKS = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']
pred_files = sorted(glob.glob('outputs/workflow_predictions/*_predictions.tsv'))
test_df = pd.read_csv(pred_files[0], sep='\t')
print(f'Test sequences: {len(test_df):,}')
print()
print(f'{"Rank":8s} {"Accuracy":>8s} {"Weighted F1":>11s}')
print('-' * 31)
for r in RANKS:
    acc = accuracy_score(test_df[f'{r}_true'], test_df[f'{r}_predicted'])
    f1 = f1_score(test_df[f'{r}_true'], test_df[f'{r}_predicted'], average='weighted', zero_division=0)
    print(f' {r.capitalize():8s} {acc:8.4f} {f1:11.4f}')
Test sequences: 69,335
Rank Accuracy Weighted F1
-------------------------------
Domain 0.9998 0.9998
Phylum 0.9969 0.9967
Class 0.9963 0.9959
Order 0.9904 0.9893
Family 0.9861 0.9842
Genus 0.9686 0.9641
Species 0.9298 0.9213
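The table reports weighted F1 (average='weighted' in f1_score), which averages per-class F1 scores in proportion to each class's number of true examples, so abundant species dominate. Macro F1 gives every class equal weight. The pure-Python sketch below computes both on a made-up, imbalanced toy example to show how far they can diverge. For this example, the results should match scikit-learn's f1_score with average='macro' and average='weighted' (zero_division=0). The helper names are illustrative.

```python
from collections import Counter


def per_class_f1(y_true, y_pred):
    """F1 for every label seen in y_true or y_pred (0 when undefined)."""
    labels = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_counts, true_counts = Counter(y_pred), Counter(y_true)
    scores = {}
    for c in labels:
        precision = tp[c] / pred_counts[c] if pred_counts[c] else 0.0
        recall = tp[c] / true_counts[c] if true_counts[c] else 0.0
        denom = precision + recall
        scores[c] = 2 * precision * recall / denom if denom else 0.0
    return scores, true_counts


def macro_and_weighted_f1(y_true, y_pred):
    scores, support = per_class_f1(y_true, y_pred)
    macro = sum(scores.values()) / len(scores)                          # equal weight per class
    weighted = sum(scores[c] * support[c] for c in scores) / len(y_true)  # weight by support
    return macro, weighted


# Made-up imbalanced example: 8 sequences of species A, 1 each of B and C
y_true = ["A"] * 8 + ["B", "C"]
y_pred = ["A"] * 10                  # the classifier only ever predicts A
macro, weighted = macro_and_weighted_f1(y_true, y_pred)
print(f"macro F1 = {macro:.3f}, weighted F1 = {weighted:.3f}")
```

A classifier that ignores rare species can still post a high weighted F1, which is one reason the species-level numbers above are worth reading alongside a macro average.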
7 Summary
This tutorial covered the full training lifecycle: preparing data, configuring and launching a training run, inspecting checkpoints, interpreting loss curves and per-rank F1 trajectories, and evaluating the trained model on held-out test data.
The training run demonstrated here follows the same procedure, data, seed, and architecture used to produce the published DeepTaxa checkpoint (10 epochs on the full Greengenes 2 training split), which achieved 93.0% species-level accuracy and 92.0% weighted F1-score.
For prediction with the published pre-trained model, see the prediction tutorial. For in-depth analysis of model behavior, see the analysis tutorial.
8 Appendix: Key Concepts
Epoch. One complete pass through the entire training dataset. During each epoch, the model processes every sequence once (in random order, via mini-batches) and updates its parameters after each batch.
Overfitting. When a model performs well on the training data but poorly on unseen data. This happens when the model memorizes training examples rather than learning generalizable patterns. Monitoring validation loss is the standard way to detect overfitting.
Cross-entropy loss. A loss function for classification tasks that measures the discrepancy between the model’s predicted probability distribution and the true label. It equals \(-\log p_t\), where \(p_t\) is the predicted probability of the correct class. Lower cross-entropy indicates better predictions.
Learning rate schedule. The learning rate controls the step size of each parameter update. Starting large allows rapid initial progress; gradually reducing it allows fine-grained convergence. DeepTaxa uses linear warmup (increasing from zero to the base rate over the first 10% of steps) followed by linear decay (decreasing back to zero over the remaining 90%).