# AI2Text – Bilingual ASR (Vietnamese + English)
A ~30M-parameter Transformer Seq2Seq Automatic Speech Recognition model trained on ~224k bilingual (Vietnamese + English) audio samples.
## Model Description
| Attribute | Value |
|---|---|
| Architecture | Encoder-Decoder Transformer |
| Parameters | ~30,325,164 |
| d_model | 256 |
| Encoder layers | 14 (RoPE + Flash Attention) |
| Decoder layers | 6 (causal, cross-attention) |
| Vocabulary size | 3,500 (SentencePiece BPE) |
| Language embedding | Yes (Vietnamese=0, English=1) |
| Normalization | RMSNorm |
| Activation | SiLU (Swish) |
| Positional encoding | Rotary (RoPE) |
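The language embedding from the table above can be sketched as follows. This is an illustrative snippet, not the repository's actual code: one learned vector per language (Vietnamese = 0, English = 1) is added to the projected audio features so the model knows which language to transcribe.

```python
import torch
import torch.nn as nn

# Illustrative sketch (not the repo's actual module): a 2-entry embedding
# table, one learned d_model-sized vector per language.
lang_emb = nn.Embedding(num_embeddings=2, embedding_dim=256)

feats = torch.randn(1, 100, 256)   # (batch, time, d_model) encoder input
lang_ids = torch.tensor([0])       # 0 = Vietnamese, 1 = English
feats = feats + lang_emb(lang_ids).unsqueeze(1)  # broadcast over time
print(feats.shape)  # torch.Size([1, 100, 256])
```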
## Modern Components

- RMSNorm – more efficient than LayerNorm
- SiLU (Swish) activation
- Rotary Positional Embedding (RoPE) – better generalization
- Flash Attention (SDPA) – memory-efficient attention
- Hybrid CTC / Attention loss – helps the encoder learn alignment
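To illustrate why RMSNorm is cheaper than LayerNorm, here is a minimal generic implementation (a sketch, not the repository's own module): it rescales each feature vector by its root-mean-square, computing one statistic per vector instead of LayerNorm's mean and variance, and has no bias term.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm sketch: rescale by the root-mean-square of the
    last dimension. No mean subtraction and no bias, so it is cheaper
    than LayerNorm."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt of the mean square -> multiply instead of divide
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

x = torch.randn(2, 10, 256)
out = RMSNorm(256)(x)
print(out.shape)  # torch.Size([2, 10, 256])
```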
## Training Data
Trained on Cong123779/AI2Text-Bilingual-ASR-Dataset:
- Train: ~194,167 samples (77% Vietnamese, 23% English)
- Validation: ~30,123 samples
Audio format: 16 kHz mono WAV, 80-dim Mel-spectrogram features.
## Training Configuration
| Hyperparameter | Value |
|---|---|
| Batch size | 32 (effective 128 w/ grad-accum × 4) |
| Learning rate | 3e-4 |
| Epochs | 50 |
| Warmup | 3% of training steps |
| Mixed precision | bfloat16 (AMP) |
| Gradient clipping | 0.5 |
| CTC weight | 0.2 |
| Scheduled sampling | 1.0 → 0.5 (linear) |
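Two rows of the table in plain Python, with hypothetical helper names (the actual training code may differ): the hybrid objective mixes the CTC and attention (cross-entropy) losses with CTC weight 0.2, and the scheduled-sampling teacher-forcing probability decays linearly from 1.0 to 0.5 over training.

```python
def hybrid_loss(ctc_loss: float, ce_loss: float, ctc_weight: float = 0.2) -> float:
    """Hybrid CTC/attention objective: weighted sum of the two losses."""
    return ctc_weight * ctc_loss + (1.0 - ctc_weight) * ce_loss

def teacher_forcing_prob(step: int, total_steps: int,
                         start: float = 1.0, end: float = 0.5) -> float:
    """Linear scheduled-sampling decay from `start` to `end`."""
    frac = min(step / total_steps, 1.0)
    return start + (end - start) * frac

print(hybrid_loss(1.0, 2.0))             # 0.2 * 1.0 + 0.8 * 2.0
print(teacher_forcing_prob(0, 1000))     # start of training
print(teacher_forcing_prob(1000, 1000))  # end of training
```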
## Usage

```python
import sys
import torch

# Clone the repo and add it to the path
sys.path.insert(0, "AI2Text")

from models.asr_base import ASRModel
from preprocessing.sentencepiece_tokenizer import SentencePieceTokenizer
from preprocessing.audio_processing import AudioProcessor

# Load tokenizer
tokenizer = SentencePieceTokenizer("models/tokenizer_vi_en_3500.model")

# Load model
checkpoint = torch.load("best_model.pt", map_location="cpu")
config = checkpoint.get("config", {})
model = ASRModel(
    input_dim=80,
    vocab_size=3500,
    d_model=256,
    num_encoder_layers=14,
    num_decoder_layers=6,
    num_heads=8,
    d_ff=2048,
    num_languages=2,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Transcribe
audio_processor = AudioProcessor(sample_rate=16000, n_mels=80)
features = audio_processor.process("audio.wav")  # (time, 80)
features = features.unsqueeze(0)                 # (1, time, 80)
lengths = torch.tensor([features.size(1)])

with torch.no_grad():
    tokens = model.generate(
        features,
        lengths=lengths,
        language_ids=torch.tensor([0]),  # 0 = Vietnamese, 1 = English
        max_len=128,
        sos_token_id=tokenizer.sos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

text = tokenizer.decode(tokens[0].tolist())
print(text)
```
## Framework
Built with PyTorch. Optimized for RTX 5060TI 16GB / Ryzen 9 9990X / 64GB RAM.
## License
Apache 2.0