pengyizhou committed (verified)
Commit adc1dbd · Parent(s): d69b7c6

Upload decode_Khmer.py

Files changed (1): decode_Khmer.py +83 -0
decode_Khmer.py ADDED
#!/usr/bin/env python

# pip install transformers datasets torch soundfile jiwer

import torch
from datasets import load_dataset, Audio
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor, pipeline
from jiwer import wer as jiwer_wer
from jiwer import cer as jiwer_cer

# 1. Load the FLEURS Khmer test set and cast the audio column to 16 kHz
ds = load_dataset("google/fleurs", "km_kh", split="test")
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
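
# Each FLEURS example carries the decoded waveform and a reference text,
# roughly {"audio": {"array": np.ndarray, "sampling_rate": 16000},
#          "transcription": "..."}
# (a sketch of the two fields used below, not the full schema).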

# 2. Load the fine-tuned model and the matching Whisper processor
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# model_id = "openai/whisper-large-v3"
model_id = "pengyizhou/whisper-fleurs-km_kh"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

# Reuse the whisper-large-v3 tokenizer/feature extractor, configured for Khmer
whisper_model = "openai/whisper-large-v3"
processor = WhisperProcessor.from_pretrained(whisper_model, language="khmer")

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    chunk_length_s=30,
    batch_size=64,
    max_new_tokens=440,
    device=device,
)
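
# Quick sanity check on a single utterance (a hypothetical usage sketch;
# the ASR pipeline accepts a raw 16 kHz numpy waveform and returns a dict
# with a "text" field):
# print(asr(ds[0]["audio"]["array"])["text"])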

# 3. Batch-wise transcription function
def transcribe_batch(batch):
    # `batch["audio"]` is a list of {"array": np.ndarray, ...}
    inputs = [ex["array"] for ex in batch["audio"]]
    outputs = asr(inputs)  # returns a list of dicts with "text"
    # lower-case and strip to normalize for CER/WER
    preds = [out["text"].lower().strip() for out in outputs]
    return {"prediction": preds}

# 4. Map over the dataset in chunks of 64 examples at a time
result = ds.map(
    transcribe_batch,
    batched=True,
    batch_size=64,  # 64 audios per map batch; the pipeline sub-batches internally
    remove_columns=ds.column_names,
)
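
# Note: datasets caches map() results on disk; pass load_from_cache_file=False
# to ds.map(...) above if you swap the model and need fresh predictions.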

# 5. Compute corpus-level CER and WER with jiwer
refs = [t.lower().strip() for t in ds["transcription"]]
preds = list(result["prediction"])
score_cer = jiwer_cer(refs, preds)
score_wer = jiwer_wer(refs, preds)
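
# jiwer accepts lists of reference/hypothesis strings and returns a single
# corpus-level rate; as a toy check, jiwer_cer(["hello"], ["helo"]) == 0.2
# (one deleted character over five reference characters).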

print(f"Fine-tuned CER on FLEURS km_kh: {score_cer*100:.2f}%")
print(f"Fine-tuned WER on FLEURS km_kh: {score_wer*100:.2f}%")

# 6. Dump predictions and references for offline scoring
with open("./km_kh_finetune.pred", "w", encoding="utf-8") as pred_results:
    for pred in preds:
        pred_results.write(f"{pred}\n")

with open("./km_kh.ref", "w", encoding="utf-8") as ref_results:
    for ref in refs:
        ref_results.write(f"{ref}\n")
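
# The dumped files can be re-scored later without re-running inference, e.g.:
#   with open("./km_kh_finetune.pred", encoding="utf-8") as f:
#       preds = f.read().splitlines()
#   with open("./km_kh.ref", encoding="utf-8") as f:
#       refs = f.read().splitlines()
#   print(jiwer_cer(refs, preds))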