cse2000-software-project/kernel/spectral/transcription.py
2024-06-10 13:50:10 +02:00

327 lines
10 KiB
Python

from deepgram import DeepgramClient, PrerecordedOptions, FileSource
from fastapi import HTTPException
from jiwer import process_words, process_characters
import jiwer
from .signal_analysis import get_audio, calculate_signal_duration
from .types import FileStateType
from allosaurus.app import read_recognizer # type: ignore
import tempfile
import os
from typing import Any
def calculate_error_rates(
    reference_annotations: list[dict], hypothesis_annotations: list[dict]
) -> dict | None:
    """
    Calculate word- and character-level error rates between two annotation lists.

    Each annotation list is first flattened into a sentence by joining its
    non-empty "value" fields with single spaces (see annotation_to_sentence).
    Parameters:
    - reference_annotations (list of dict): Reference annotations, each with a "value" key.
    - hypothesis_annotations (list of dict): Hypothesis annotations, each with a "value" key.
    Returns:
    - dict | None: {"wordLevel": ..., "characterLevel": ...} holding the results of
      word_level_processing and character_level_processing, or None when the
      reference annotations contain no text.
    """
    reference = annotation_to_sentence(reference_annotations)
    # Error rates are relative to the reference, so there is nothing to compute
    # when the reference is empty.
    if not reference:
        return None
    hypothesis = annotation_to_sentence(hypothesis_annotations)
    word_level = word_level_processing(reference, hypothesis)
    character_level = character_level_processing(reference, hypothesis)
    return {"wordLevel": word_level, "characterLevel": character_level}
def annotation_to_sentence(annotations: list) -> str:
    """
    Join the non-empty annotation values into a single space-separated sentence.

    Used for both reference and hypothesis annotations. Annotations whose
    "value" is the empty string are skipped entirely, so they do not add
    extra spaces.
    Parameters:
    - annotations (list of dict): Annotations, each a dictionary with a "value" key.
    Returns:
    - str: The non-empty values joined by single spaces ("" if there are none).
    """
    # Gap segments (e.g. those inserted by fill_gaps) carry an empty "value".
    return " ".join(
        annotation["value"] for annotation in annotations if annotation["value"] != ""
    )
def word_level_processing(reference: str, hypothesis: str) -> dict[str, Any]:
    """
    Compare two sentences word by word using jiwer.

    Parameters:
    - reference (str): The ground-truth sentence.
    - hypothesis (str): The sentence to score against the reference.
    Returns:
    - dict: jiwer's word-level scores ("wer", "mer", "wil", "wip"), the edit
      counts ("hits", "substitutions", "insertions", "deletions"), the tokenised
      "reference" and "hypothesis", and the "alignments" in the format produced
      by get_alignments.
    """
    output = process_words(reference=reference, hypothesis=hypothesis)
    metric_names = (
        "wer",
        "mer",
        "wil",
        "wip",
        "hits",
        "substitutions",
        "insertions",
        "deletions",
    )
    summary = {name: getattr(output, name) for name in metric_names}
    # Only a single sentence pair is processed, so take the first entry of each.
    summary["reference"] = output.references[0]
    summary["hypothesis"] = output.hypotheses[0]
    summary["alignments"] = get_alignments(output.alignments[0])
    return summary
def character_level_processing(reference: str, hypothesis: str) -> dict[str, Any]:
    """
    Compare two sentences character by character using jiwer.

    Parameters:
    - reference (str): The ground-truth sentence.
    - hypothesis (str): The sentence to score against the reference.
    Returns:
    - dict: jiwer's "cer", the edit counts ("hits", "substitutions",
      "insertions", "deletions"), the processed "reference" and "hypothesis",
      and the "alignments" in the format produced by get_alignments.
    """
    output = process_characters(reference=reference, hypothesis=hypothesis)
    summary: dict[str, Any] = {"cer": output.cer}
    for counter in ("hits", "substitutions", "insertions", "deletions"):
        summary[counter] = getattr(output, counter)
    # Only a single sentence pair is processed, so take the first entry of each.
    summary["reference"] = output.references[0]
    summary["hypothesis"] = output.hypotheses[0]
    summary["alignments"] = get_alignments(output.alignments[0])
    return summary
def get_alignments(unparsed_alignments: list[jiwer.process.AlignmentChunk]) -> list[dict]:
    """
    Convert jiwer alignment chunks into plain, camelCase-keyed dictionaries.

    Parameters:
    - unparsed_alignments (list): jiwer AlignmentChunk objects for one sentence pair.
    Returns:
    - list of dict: One dictionary per chunk, in the same order, with the
      chunk's "type" and its reference/hypothesis start and end indices.
    """
    return [
        {
            "type": chunk.type,
            "referenceStartIndex": chunk.ref_start_idx,
            "referenceEndIndex": chunk.ref_end_idx,
            "hypothesisStartIndex": chunk.hyp_start_idx,
            "hypothesisEndIndex": chunk.hyp_end_idx,
        }
        for chunk in unparsed_alignments
    ]
def get_transcription(model: str, file: FileStateType):
"""
Get transcription of an audio file using the specified model.
This function gets the transcription of an audio file using the specified model.
Parameters:
- model (str): The transcription model to use.
- file (dict): The file object containing the audio data.
Returns:
- list: A list of transcriptions containing words with their start and end times.
Raises:
- HTTPException: If the specified model is not found.
"""
if model == "deepgram":
return fill_gaps(deepgram_transcription(file["data"]), file)
if model == "allosaurus":
return fill_gaps(allosaurs_transcription(file), file)
raise HTTPException(status_code=404, detail="Model was not found")
def fill_gaps(transcriptions: list[dict], file: FileStateType) -> list[dict]:
    """
    Pad timed transcription segments so that they cover the whole audio signal.

    An empty-valued segment is inserted for every stretch of time that no
    segment covers: before the first segment, between segments, and after the
    last one up to the signal duration.
    Parameters:
    - transcriptions (list of dict): Segments with "value", "start" and "end"
      keys, presumably ordered by start time (as the transcribers produce them).
    - file (FileStateType): The file whose audio determines the total duration.
    Returns:
    - list of dict: The original segments, in order, interleaved with the
      inserted gap segments ({"value": "", "start": ..., "end": ...}).
    """
    audio = get_audio(file)
    duration = calculate_signal_duration(audio)
    if len(transcriptions) == 0:
        return [{"value": "", "start": 0, "end": duration}]
    res = []
    time = 0
    for transcription in transcriptions:
        # Only insert a gap when there is genuinely uncovered time; touching or
        # overlapping segments must not yield zero- or negative-length gaps.
        if transcription["start"] > time:
            res.append({"value": "", "start": time, "end": transcription["start"]})
        res.append(transcription)
        # Track the furthest end seen so a segment that ends earlier than its
        # predecessor cannot move the cursor backwards.
        time = max(time, transcription["end"])
    # Segments may run past the computed duration; only pad when time is left over.
    if duration > time:
        res.append({"value": "", "start": time, "end": duration})
    return res
def deepgram_transcription(data: bytes) -> list[dict]:
    """
    Transcribe audio data into timed words using the Deepgram API.

    This is best-effort: it never raises. If the "DG_KEY" environment variable
    is not set, or anything fails while calling Deepgram or reading its
    response, the exception is printed and an empty list is returned.
    Parameters:
    - data (bytes): The audio data to transcribe.
    Returns:
    - list: Words as {"value": word, "start": start, "end": end} dictionaries,
      or [] on any failure.
    """
    try:
        # STEP 1: Create a Deepgram client using the API key
        key = os.getenv("DG_KEY")
        if key is None:
            # Raised so the missing-key case goes through the same best-effort
            # handler below as every other failure.
            raise Exception("No API key for Deepgram is found")
        deepgram = DeepgramClient(key)
        payload: FileSource = {
            "buffer": data,
        }
        # STEP 2: Configure Deepgram options for audio analysis
        options = PrerecordedOptions(
            model="nova-2",
            smart_format=True,
            profanity_filter=False,
        )
        # STEP 3: Call the transcribe_file method with the audio payload and options
        response = deepgram.listen.prerecorded.v("1").transcribe_file(payload, options)
        words = response["results"]["channels"][0]["alternatives"][0]["words"]
        return [
            {"value": word["word"], "start": word["start"], "end": word["end"]}
            for word in words
        ]
    except Exception as e:
        # Callers treat a failed transcription as "no words" rather than an error.
        print(f"Exception: {e}")
        return []
def allosaurs_transcription(file: FileStateType) -> Any:
    """
    Produce a timed phoneme-level transcription of an audio file with Allosaurus.

    Deepgram word timings (with gaps filled) are used to group the Allosaurus
    phonemes into words. Each phoneme then receives an estimated start and end
    time via get_phoneme_word_splits and get_phoneme_transcriptions.
    Parameters:
    - file (FileStateType): The file object; its "data" bytes are written to a
      temporary ".wav" file for Allosaurus to read.
    Returns:
    - list of dict: Phoneme segments with "value", "start" and "end" keys.
    """
    # NOTE(review): assumes file["data"] is WAV-encoded, since it is written
    # verbatim to a ".wav" file for the recognizer.
    # delete=False lets the recognizer open the file by name after our handle is
    # closed; we therefore remove it ourselves in the finally below.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
        temp_wav.write(file["data"])
        temp_wav_filename = temp_wav.name
    try:
        word_level_transcription = fill_gaps(deepgram_transcription(file["data"]), file)
        model = read_recognizer()
        phoneme_level_transcription = model.recognize(
            temp_wav_filename, timestamp=True, emit=1.2
        )
    finally:
        os.remove(temp_wav_filename)
    # Each timestamped output line is "<start> <duration> <phone>"; only the
    # start time and the phone are kept.
    phoneme_level_parsed = []
    for phoneme_string in phoneme_level_transcription.splitlines():
        fields = phoneme_string.split(" ")
        phoneme_level_parsed.append([float(fields[0]), fields[2]])
    phoneme_word_splits = get_phoneme_word_splits(word_level_transcription, phoneme_level_parsed)
    return get_phoneme_transcriptions(phoneme_word_splits)
def get_phoneme_word_splits(
    word_level_transcription: list[dict], phoneme_level_parsed: list[list]
) -> list[dict]:
    """
    Group timestamped phonemes under the word segment during which they occur.

    A phoneme belongs to the current word while its timestamp is not after that
    word's "end". When a phoneme falls after the end, the current word is closed
    (possibly with no phonemes) and the next word becomes current.
    Parameters:
    - word_level_transcription (list of dict): Word segments with an "end" key,
      expected in time order.
    - phoneme_level_parsed (list of list): [timestamp, phoneme] pairs, expected
      in time order.
    Returns:
    - list of dict: {"phonemes": [...], "word_transcription": word} entries.
      Phonemes after the last word's end are discarded. When the phonemes run
      out, the current word is closed with what it collected, and any later
      words get no entry.
    """
    if len(word_level_transcription) == 0:
        return []
    total_words = len(word_level_transcription)
    splits = []
    pending = []
    word_index = 0
    for phoneme in phoneme_level_parsed:
        # Close every word that ends before this phoneme's timestamp.
        while phoneme[0] > word_level_transcription[word_index]["end"]:
            splits.append(
                {"phonemes": pending, "word_transcription": word_level_transcription[word_index]}
            )
            pending = []
            word_index += 1
            if word_index == total_words:
                # No word is left to hold the remaining phonemes: drop them.
                return splits
        pending.append(phoneme)
    # Phonemes exhausted: the current word keeps whatever was collected for it.
    splits.append(
        {"phonemes": pending, "word_transcription": word_level_transcription[word_index]}
    )
    return splits
def get_phoneme_transcriptions(phoneme_word_splits: list[Any]) -> list[dict]:
res = []
for phoneme_split in phoneme_word_splits:
if len(phoneme_split) == 0:
continue
for i in range(len(phoneme_split["phonemes"])):
start = 0
if i == 0:
start = phoneme_split["word_transcription"]["start"]
else:
# this is an (educated) guess, it could be way off :D
start = (phoneme_split["phonemes"][i - 1][0] + phoneme_split["phonemes"][i][0]) / 2
end = 0
if i + 1 == len(phoneme_split["phonemes"]):
end = phoneme_split["word_transcription"]["end"]
else:
end = (phoneme_split["phonemes"][i + 1][0] + phoneme_split["phonemes"][i][0]) / 2
res.append({"value": phoneme_split["phonemes"][i][1], "start": start, "end": end})
return res