"""Visualize wav2vec 2.0 features extracted from a single audio file.

Usage:
    python <this_script> <hf_model_path_or_id> <audio_file>

The script loads a Hugging Face ``Wav2Vec2Model`` together with its feature
extractor, runs the audio through the model, prints the shapes of the CNN
features and hidden states, and saves heat maps of the last hidden state and
of the CNN features to ``hf_wav2vec2_features.png``.
"""
import sys

import numpy as np  # noqa: F401  (import kept from the original script; unused here)
import torch

# File the heat maps are written to (relative to the working directory).
OUTPUT_PATH = "hf_wav2vec2_features.png"


def scale_for_display(features, gain=100.0):
    """Return ``log1p(gain * x)`` where ``x`` is ``features`` min-max scaled to [0, 1].

    The log compression keeps a few large activations from washing out the
    rest of the heat map.

    Args:
        features: Non-empty floating-point tensor of any shape.
        gain: Multiplier applied to the normalized values before ``log1p``.

    Returns:
        A tensor with the same shape as ``features``. A constant input yields
        all zeros rather than the NaNs a plain ``0 / 0`` division would give.
    """
    lowest = features.min()
    span = features.max() - lowest
    if span == 0:
        # Every value is identical: there is no contrast to show.
        return torch.zeros_like(features)
    return torch.log1p((features - lowest) / span * gain)


def load_mono_waveform(audio_file, target_rate):
    """Load ``audio_file`` as a 1-D mono waveform sampled at ``target_rate`` Hz.

    Args:
        audio_file: Path of the audio file to read with ``torchaudio.load``.
        target_rate: Sampling rate (Hz) the returned waveform must have.

    Returns:
        A 1-D tensor holding the channel-averaged, resampled audio.
    """
    import torchaudio  # imported where it is used, as in the original script

    waveform, sample_rate = torchaudio.load(audio_file)
    # Averaging over the channel axis down-mixes multi-channel audio and
    # leaves a single channel unchanged; unlike a bare ``squeeze()`` it never
    # collapses the time axis of a one-sample clip.
    mono = waveform.mean(dim=0)
    if sample_rate != target_rate:
        # The feature extractor rejects audio whose rate differs from the one
        # it was configured with, so convert instead of failing.
        mono = torchaudio.functional.resample(mono, sample_rate, target_rate)
    return mono


def save_feature_plot(last_hidden, cnn_features, path=OUTPUT_PATH):
    """Save two stacked heat maps (last hidden state, CNN features) to ``path``.

    Both arguments are 2-D tensors; each is transposed before plotting so that
    its first axis runs along the x-axis of the image.
    """
    import matplotlib.pyplot as plt

    panels = (
        ("Last Hidden States", last_hidden),
        ("CNN Features", cnn_features),
    )
    plt.figure(figsize=(12, 6))
    for position, (title, features) in enumerate(panels, start=1):
        plt.subplot(2, 1, position)
        plt.title(title)
        plt.imshow(features.cpu().numpy().T, aspect="auto", origin="lower")
        plt.colorbar()
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def main(argv=None):
    """Run the feature extraction and plotting pipeline.

    Args:
        argv: Argument vector laid out like ``sys.argv`` (program name, model
            path, audio file). Defaults to ``sys.argv``; extra trailing
            arguments are ignored, as before.

    Raises:
        SystemExit: With a usage message when fewer than two arguments follow
            the program name.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        program = argv[0] if argv else "script"
        sys.exit(f"usage: {program} <hf_model_path_or_id> <audio_file>")
    hf_path, audio_file = argv[1], argv[2]

    # Heavy third-party imports are deferred until the arguments are known to
    # be usable, so a usage error is reported immediately.
    import torchinfo
    from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

    extract = Wav2Vec2FeatureExtractor.from_pretrained(hf_path)
    hf_model = Wav2Vec2Model.from_pretrained(hf_path)
    torchinfo.summary(hf_model)
    hf_model.eval()

    target_rate = extract.sampling_rate
    waveform = load_mono_waveform(audio_file, target_rate)

    with torch.no_grad():
        inputs = extract(
            waveform.numpy(),
            sampling_rate=target_rate,
            return_tensors="pt",
            padding=True,
        )
        out = hf_model(inputs.input_values, output_hidden_states=True)

    last_hidden_states = out.last_hidden_state
    cnn_out = out.extract_features
    hidden_states = out.hidden_states

    print("CNN features shape:", cnn_out.shape)
    print("Hidden states length:", len(hidden_states))
    print("Last hidden states shape:", last_hidden_states.shape)

    # A single waveform is fed to the model, so the batch holds exactly one
    # item; index it directly and scale each feature map on its own.
    save_feature_plot(
        scale_for_display(last_hidden_states[0]),
        scale_for_display(cnn_out[0]),
    )


if __name__ == "__main__":
    main()