AI2Text – Bilingual ASR (Vietnamese + English)

A ~30M-parameter Transformer Seq2Seq Automatic Speech Recognition model trained on ~224k bilingual (Vietnamese + English) audio samples.

Model Description

| Attribute | Value |
|---|---|
| Architecture | Encoder-decoder Transformer |
| Parameters | ~30,325,164 |
| `d_model` | 256 |
| Encoder layers | 14 (RoPE + Flash Attention) |
| Decoder layers | 6 (causal self-attention + cross-attention) |
| Vocabulary size | 3,500 (SentencePiece BPE) |
| Language embedding | Yes (Vietnamese = 0, English = 1) |
| Normalization | RMSNorm |
| Activation | SiLU (Swish) |
| Positional encoding | Rotary (RoPE) |
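As a back-of-envelope sanity check, the reported parameter count can be roughly reproduced from the table above and the `d_ff=2048` used in the usage snippet. The tally below ignores norms, biases, the CTC head, and the audio frontend, so it is an estimate, not an exact accounting:

```python
d_model, d_ff, vocab = 256, 2048, 3500

# One encoder layer: self-attention (Q, K, V, O projections) + 2-layer feed-forward
enc_layer = 4 * d_model**2 + 2 * d_model * d_ff
# One decoder layer: self-attention + cross-attention + feed-forward
dec_layer = 8 * d_model**2 + 2 * d_model * d_ff

total = 14 * enc_layer + 6 * dec_layer + vocab * d_model  # + token embedding
print(f"{total:,}")  # 28,683,264 -- close to the reported ~30.3M once
                     # norms, biases, and the CTC head are added back
```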

Modern Components

- RMSNorm – cheaper than LayerNorm (no mean subtraction, no bias)
- SiLU (Swish) activation
- Rotary Positional Embedding (RoPE) – better length generalization than absolute positions
- Flash Attention (via PyTorch SDPA) – memory-efficient attention
- Hybrid CTC / attention loss – the auxiliary CTC objective helps the encoder learn audio-text alignment
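For reference, RMSNorm only rescales each feature vector by its root mean square, with no mean subtraction or bias term. A minimal NumPy sketch (illustrative, not the model's actual implementation):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Scale by the root mean square of the last axis; skipping the mean
    # subtraction and bias of LayerNorm is what makes RMSNorm cheaper
    rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

x = np.array([3.0, 4.0])
print(rms_norm(x, np.ones(2)))  # scales x by 1/sqrt(12.5)
```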

Training Data

Trained on Cong123779/AI2Text-Bilingual-ASR-Dataset:

- Train: ~194,167 samples (77% Vietnamese, 23% English)
- Validation: ~30,123 samples

Audio format: 16 kHz mono WAV, 80-dim Mel-spectrogram features.
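The exact frontend lives in `preprocessing/audio_processing.py`; assuming the common 25 ms window / 10 ms hop at 16 kHz (400 / 160 samples — an assumption, not confirmed by the repo), the number of 80-dim mel frames per clip is easy to predict:

```python
def num_mel_frames(num_samples, hop_length=160, win_length=400, center=True):
    # With centered STFT framing (the PyTorch/librosa default), a clip yields
    # 1 + num_samples // hop_length frames; without center padding, frames
    # only start where a full window fits
    if center:
        return 1 + num_samples // hop_length
    return 1 + max(0, num_samples - win_length) // hop_length

print(num_mel_frames(5 * 16000))  # a 5-second clip -> 501 frames
```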

Training Configuration

| Hyperparameter | Value |
|---|---|
| Batch size | 32 (effective 128 with 4-step gradient accumulation) |
| Learning rate | 3e-4 |
| Epochs | 50 |
| Warmup | 3% of training steps |
| Mixed precision | bfloat16 (AMP) |
| Gradient clipping | 0.5 |
| CTC weight | 0.2 |
| Scheduled sampling | 1.0 → 0.5 (linear) |
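The last two rows can be made concrete. A sketch of how a 0.2-weighted hybrid CTC/attention loss and a linear scheduled-sampling decay are typically combined (function names are illustrative, not the repo's API):

```python
def hybrid_loss(ctc_loss, attention_loss, ctc_weight=0.2):
    # Weighted sum of the auxiliary CTC loss and the decoder cross-entropy
    return ctc_weight * ctc_loss + (1.0 - ctc_weight) * attention_loss

def teacher_forcing_prob(step, total_steps, start=1.0, end=0.5):
    # Scheduled sampling: probability of feeding the ground-truth token,
    # decayed linearly from 1.0 to 0.5 over training, then held at 0.5
    frac = min(step / total_steps, 1.0)
    return start + (end - start) * frac

print(round(hybrid_loss(2.0, 1.0), 3))   # 1.2
print(teacher_forcing_prob(50, 100))     # 0.75
```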

Usage

```python
import sys
import torch

# Clone the repo and add it to the Python path
sys.path.insert(0, "AI2Text")

from models.asr_base import ASRModel
from preprocessing.sentencepiece_tokenizer import SentencePieceTokenizer
from preprocessing.audio_processing import AudioProcessor

# Load tokenizer
tokenizer = SentencePieceTokenizer("models/tokenizer_vi_en_3500.model")

# Load model weights
checkpoint = torch.load("best_model.pt", map_location="cpu")

model = ASRModel(
    input_dim=80,
    vocab_size=3500,
    d_model=256,
    num_encoder_layers=14,
    num_decoder_layers=6,
    num_heads=8,
    d_ff=2048,
    num_languages=2,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Transcribe
audio_processor = AudioProcessor(sample_rate=16000, n_mels=80)
features = audio_processor.process("audio.wav")  # (time, 80)
features = features.unsqueeze(0)                 # (1, time, 80)
lengths = torch.tensor([features.size(1)])

with torch.no_grad():
    tokens = model.generate(
        features,
        lengths=lengths,
        language_ids=torch.tensor([0]),   # 0 = Vietnamese, 1 = English
        max_len=128,
        sos_token_id=tokenizer.sos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    text = tokenizer.decode(tokens[0].tolist())
    print(text)
```
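Conceptually, `generate` runs an autoregressive greedy loop over the decoder. A framework-free sketch of that loop, where the hypothetical `step_fn` stands in for one decoder forward pass plus argmax:

```python
def greedy_decode(step_fn, sos_id, eos_id, max_len):
    # step_fn(prefix) -> most likely next-token id given the token prefix;
    # decoding stops at EOS or after max_len steps
    tokens = [sos_id]
    for _ in range(max_len):
        next_id = step_fn(tokens)
        if next_id == eos_id:
            break
        tokens.append(next_id)
    return tokens[1:]  # drop SOS

# Toy example: a "model" that always emits 5, 7, then EOS (id 2)
script = iter([5, 7, 2])
print(greedy_decode(lambda prefix: next(script), sos_id=1, eos_id=2, max_len=10))
# -> [5, 7]
```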

Framework

Built with PyTorch. Trained and optimized on an RTX 5060 Ti (16 GB) with a Ryzen 9 9990X and 64 GB RAM.

License

Apache 2.0
