cse2000-software-project/kernel/spectral/error_rates.py
2024-06-10 16:15:15 +02:00

148 lines
4.6 KiB
Python

from jiwer import process_words, process_characters
import jiwer
from typing import Any
def calculate_error_rates(
    reference_annotations: list[dict], hypothesis_annotations: list[dict]
) -> dict | None:
    """
    Compute word-level and character-level error rates for two annotation lists.
    Both annotation lists are first flattened into sentences with
    annotation_to_sentence; the resulting strings are then compared at word
    level and at character level.
    Parameters:
    - reference_annotations (list of dict): The annotations forming the reference transcription; each annotation is a dictionary with a "value" key.
    - hypothesis_annotations (list of dict): The annotations forming the hypothesis transcription, in the same format.
    Returns:
    - dict | None: A dictionary with the keys "wordLevel" and "characterLevel", or None when the reference annotations produce an empty sentence.
    """
    reference_text = annotation_to_sentence(reference_annotations)
    # Bail out before touching the hypothesis: with an empty reference no
    # metrics are computed at all.
    if not reference_text:
        return None

    hypothesis_text = annotation_to_sentence(hypothesis_annotations)
    return {
        "wordLevel": word_level_processing(reference_text, hypothesis_text),
        "characterLevel": character_level_processing(reference_text, hypothesis_text),
    }
def word_level_processing(reference: str, hypothesis: str) -> dict[str, Any]:
    """
    Collect word-level error metrics for a reference/hypothesis pair.
    The two strings are handed to jiwer's process_words and the fields of its
    result are copied into a plain dictionary.
    Parameters:
    - reference (str): The reference transcription.
    - hypothesis (str): The hypothesis transcription.
    Returns:
    - dict: The keys "wer", "mer", "wil", "wip", "hits", "substitutions", "insertions", "deletions", "reference", "hypothesis" and "alignments".
    """
    output = process_words(reference=reference, hypothesis=hypothesis)

    # Fields copied verbatim from the jiwer result, in output order.
    copied_fields = (
        "wer",
        "mer",
        "wil",
        "wip",
        "hits",
        "substitutions",
        "insertions",
        "deletions",
    )
    metrics: dict[str, Any] = {name: getattr(output, name) for name in copied_fields}

    # Only the first entry of each per-sentence list is read
    # (a single reference/hypothesis string is passed in above).
    metrics["reference"] = output.references[0]
    metrics["hypothesis"] = output.hypotheses[0]
    metrics["alignments"] = get_alignments(output.alignments[0])
    return metrics
def character_level_processing(reference: str, hypothesis: str) -> dict[str, Any]:
    """
    Collect character-level error metrics for a reference/hypothesis pair.
    The two strings are handed to jiwer's process_characters and the fields of
    its result are copied into a plain dictionary.
    Parameters:
    - reference (str): The reference transcription.
    - hypothesis (str): The hypothesis transcription.
    Returns:
    - dict: The keys "cer", "hits", "substitutions", "insertions", "deletions", "reference", "hypothesis" and "alignments".
    """
    output = process_characters(reference=reference, hypothesis=hypothesis)

    # Fields copied verbatim from the jiwer result, in output order.
    copied_fields = ("cer", "hits", "substitutions", "insertions", "deletions")
    metrics: dict[str, Any] = {name: getattr(output, name) for name in copied_fields}

    # Only the first entry of each per-sentence list is read
    # (a single reference/hypothesis string is passed in above).
    metrics["reference"] = output.references[0]
    metrics["hypothesis"] = output.hypotheses[0]
    metrics["alignments"] = get_alignments(output.alignments[0])
    return metrics
def annotation_to_sentence(annotations: list) -> str:
    """
    Join the values of a list of annotations into a single sentence.
    The "value" of every annotation is taken in order and the values are
    separated by a single space. Annotations whose value is empty ("" or
    None) are skipped, so they never introduce extra separators.
    Parameters:
    - annotations (list of dict): The list of annotations where each annotation is a dictionary with a "value" key.
    Returns:
    - str: The space-separated sentence, or "" when there are no annotations or all values are empty.
    Raises:
    - KeyError: If an annotation has no "value" key.
    """
    # str.join never emits a trailing separator, so no trimming is needed and
    # the string is built in one pass instead of by repeated concatenation.
    values = (annotation["value"] for annotation in annotations)
    # The truthiness filter drops "" and also None (which would otherwise
    # fail on string concatenation).
    return " ".join(value for value in values if value)
def get_alignments(
    unparsed_alignments: list[jiwer.process.AlignmentChunk],
) -> list[dict]:
    """
    Turn jiwer alignment chunks into plain dictionaries.
    Each chunk's attributes are copied into a dictionary with camelCase keys;
    the order of the input list is preserved.
    Parameters:
    - unparsed_alignments (list): The alignment chunks to convert.
    Returns:
    - list of dict: One dictionary per chunk with the keys "type", "referenceStartIndex", "referenceEndIndex", "hypothesisStartIndex" and "hypothesisEndIndex".
    """
    return [
        {
            "type": chunk.type,
            "referenceStartIndex": chunk.ref_start_idx,
            "referenceEndIndex": chunk.ref_end_idx,
            "hypothesisStartIndex": chunk.hyp_start_idx,
            "hypothesisEndIndex": chunk.hyp_end_idx,
        }
        for chunk in unparsed_alignments
    ]