Merge branch 'kernel/beartype' into 'dev'

Kernel/beartype

See merge request cse2000-software-project/2023-2024/cluster-n/11c/atypical-speech-project!105
This commit is contained in:
Thijs Houben 2024-06-10 14:05:17 +02:00
commit fe72db9176
15 changed files with 182 additions and 122 deletions

View file

@ -25,7 +25,7 @@ repos:
hooks:
- id: pyright
types_or: [python, pyi, jupyter]
additional_dependencies: [numpy, pytest, fastapi, praat-parselmouth, orjson, pydantic, scipy, psycopg, deepgram-sdk, pydub, ffmpeg-python, jiwer]
additional_dependencies: [numpy, pytest, fastapi, praat-parselmouth, orjson, pydantic, scipy, psycopg, deepgram-sdk, pydub, ffmpeg-python, jiwer, beartype]
stages: [pre-commit]
- repo: https://github.com/crate-ci/typos
rev: v1.21.0

22
kernel/poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "aiofiles"
@ -189,6 +189,24 @@ tests = ["attrs[tests-no-zope]", "zope-interface"]
tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
[[package]]
name = "beartype"
version = "0.18.5"
description = "Unbearably fast runtime type checking in pure Python."
optional = false
python-versions = ">=3.8.0"
files = [
{file = "beartype-0.18.5-py3-none-any.whl", hash = "sha256:5301a14f2a9a5540fe47ec6d34d758e9cd8331d36c4760fc7a5499ab86310089"},
{file = "beartype-0.18.5.tar.gz", hash = "sha256:264ddc2f1da9ec94ff639141fbe33d22e12a9f75aa863b83b7046ffff1381927"},
]
[package.extras]
all = ["typing-extensions (>=3.10.0.0)"]
dev = ["autoapi (>=0.9.0)", "coverage (>=5.5)", "equinox", "mypy (>=0.800)", "numpy", "pandera", "pydata-sphinx-theme (<=0.7.2)", "pytest (>=4.0.0)", "sphinx", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)", "tox (>=3.20.1)", "typing-extensions (>=3.10.0.0)"]
doc-rtd = ["autoapi (>=0.9.0)", "pydata-sphinx-theme (<=0.7.2)", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)"]
test-tox = ["equinox", "mypy (>=0.800)", "numpy", "pandera", "pytest (>=4.0.0)", "sphinx", "typing-extensions (>=3.10.0.0)"]
test-tox-coverage = ["coverage (>=5.5)"]
[[package]]
name = "certifi"
version = "2024.2.2"
@ -3062,4 +3080,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "1e35a62167b27370ababf328d0b9ca3c05472d6ef752fc5b7303fb7d7284db3f"
content-hash = "f2ccf666bf19504a96354945c1adbd3beeed52c8c0800a86ead0a384610317b3"

View file

@ -50,6 +50,7 @@ mutmut = "^2.5.0"
pytest-xdist = "^3.6.1"
pytest-testmon = "^2.1.1"
pytest-mock = "^3.14.0"
beartype = "^0.18.5"
[tool.mutmut]

View file

@ -0,0 +1,7 @@
# Install beartype's import hook so every module of this package gets its
# type annotations checked at runtime (dev/test installs only).
try:
    from beartype.claw import beartype_this_package

    beartype_this_package()
except ImportError:
    # in case beartype is not installed: running in production, where the
    # dev-only dependency (and its runtime-checking overhead) is absent.
    pass

View file

@ -18,7 +18,13 @@ class Database:
Closes the database connection and cursor.
"""
def __init__(self, user, password, host, port, dbname):
user: str
password: str
host: str
port: int
dbname: str
def __init__(self, user: str, password: str, host: str, port: int, dbname: str):
"""
Initializes the Database object and opens a connection to the specified PostgreSQL database.
@ -35,7 +41,7 @@ class Database:
self.port = port
self.dbname = dbname
def connection(self):
def connection(self) -> None:
self.conn = psycopg.connect(
dbname=self.dbname,
user=self.user,
@ -45,12 +51,12 @@ class Database:
)
self.cursor = self.conn.cursor()
def fetch_file(self, id):
def fetch_file(self, id: int) -> dict:
"""
Fetches a file record from the database by its ID.
Args:
id (string): The ID of the file to fetch.
id (int): The ID of the file to fetch.
Returns:
dict: A dictionary containing the file record's details.
@ -73,7 +79,7 @@ class Database:
result[self.snake_to_camel(column[0])] = db_res[column[1] - 1]
return result
def snake_to_camel(self, snake_case_str):
def snake_to_camel(self, snake_case_str: str) -> str:
"""
Converts a snake_case string to camelCase.
@ -91,12 +97,12 @@ class Database:
components = snake_case_str.split("_")
return components[0] + "".join(x.title() for x in components[1:])
def get_transcriptions(self, file_id):
def get_transcriptions(self, file_id: int) -> list[list]:
"""
Fetches transcriptions associated with a file from the database.
Args:
file_id (string): The ID of the file to fetch transcriptions for.
file_id (int): The ID of the file to fetch transcriptions for.
Returns:
list: A list of lists containing transcription entries, where each inner list represents a file transcription and contains dictionaries with "start", "end", and "value" keys.
@ -131,7 +137,7 @@ class Database:
res.append(parsed_file_transcriptions)
return res
def close(self):
def close(self) -> None:
"""
Closes the database connection and cursor.
"""

View file

@ -1,9 +1,14 @@
import parselmouth
import numpy as np
from fastapi import HTTPException
from array import array
from .types import FileStateType
def simple_frame_info(frame, fs, frame_info):
def simple_frame_info(
frame: array, fs: int | float, frame_info: dict[str, int] | None
) -> dict[str, float] | None:
"""
Extracts and returns basic information from a given audio frame.
@ -42,7 +47,7 @@ def simple_frame_info(frame, fs, frame_info):
return res
def calculate_frame_duration(frame, fs):
def calculate_frame_duration(frame: array, fs: int | float) -> float:
"""
This method calculates the duration of a frame based on the frame and the sample frequency.
@ -61,7 +66,7 @@ def calculate_frame_duration(frame, fs):
return len(frame) / fs
def calculate_frame_pitch(frame, fs):
def calculate_frame_pitch(frame: array, fs: float | int) -> float:
"""
This method calculates the pitch of a frame.
@ -84,11 +89,11 @@ def calculate_frame_pitch(frame, fs):
time_step=calculate_frame_duration(frame, fs) + 1
) # the + 1 ensures that the complete frame is considered as 1 frame
return pitch.get_value_at_time(0)
except Exception as _:
except Exception:
return float("nan")
def calculate_frame_f1_f2(frame, fs):
def calculate_frame_f1_f2(frame: array, fs: int | float) -> list[float]:
"""
This method calculates the first and second formant of a frame.
@ -114,11 +119,12 @@ def calculate_frame_f1_f2(frame, fs):
formants.get_value_at_time(formant_number=1, time=0),
formants.get_value_at_time(formant_number=2, time=0),
]
except Exception as _:
except Exception as e:
print(e)
return [float("nan"), float("nan")]
def validate_frame_index(data, file_state):
def validate_frame_index(data: array, file_state: FileStateType):
"""
Validate the frame index specified in the file_state.
@ -162,11 +168,7 @@ def validate_frame_index(data, file_state):
status_code=400, detail="startIndex should be strictly lower than endIndex"
)
if start_index < 0:
raise HTTPException(
status_code=400, detail="startIndex should be larger or equal to 0"
)
raise HTTPException(status_code=400, detail="startIndex should be larger or equal to 0")
if end_index > len(data):
raise HTTPException(
status_code=400, detail="endIndex should be lower than the file length"
)
raise HTTPException(status_code=400, detail="endIndex should be lower than the file length")
return {"startIndex": start_index, "endIndex": end_index}

View file

@ -25,18 +25,20 @@ from .database import Database
import orjson
import json
import os
from typing import Any
from collections.abc import Iterator
def get_db(): # pragma: no cover
def get_db() -> Iterator[Database]: # pragma: no cover
db = None
try:
db = Database(
os.getenv("POSTGRES_USER"),
os.getenv("POSTGRES_PASSWORD"),
os.getenv("POSTGRES_HOST"),
os.getenv("POSTGRES_PORT"),
os.getenv("POSTGRES_DB"),
)
user = os.getenv("POSTGRES_USER", "user")
password = os.getenv("POSTGRES_PASSWORD", "password")
host = os.getenv("POSTGRES_HOST", "localhost")
port = int(os.getenv("POSTGRES_PORT", "5432"))
dbname = os.getenv("POSTGRES_DB", "postgres")
db = Database(user, password, host, port, dbname)
db.connection()
yield db
finally:
@ -51,7 +53,7 @@ class ORJSONResponse(JSONResponse):
media_type = "application/json"
def render(self, content) -> bytes:
def render(self, content: Any) -> bytes:
return orjson.dumps(content)

View file

@ -8,11 +8,14 @@ from .frame_analysis import (
validate_frame_index,
)
from .transcription import calculate_error_rates
from .types import FileStateType
import tempfile
import subprocess
from .database import Database
from typing import Any
def simple_info_mode(database, file_state):
def simple_info_mode(database: Database, file_state: FileStateType) -> dict[str, Any]:
"""
Extracts and returns basic information about a signal and its corresponding frame.
@ -43,28 +46,26 @@ def simple_info_mode(database, file_state):
frame_index = validate_frame_index(audio.get_array_of_samples(), file_state)
result["frame"] = simple_frame_info(
audio.get_array_of_samples(), audio.frame_rate, frame_index
)
result["frame"] = simple_frame_info(audio.get_array_of_samples(), audio.frame_rate, frame_index)
return result
def spectrogram_mode(database, file_state):
def spectrogram_mode(database: Database, file_state: FileStateType) -> Any:
"""
TBD
"""
return None
def waveform_mode(database, file_state):
def waveform_mode(database: Database, file_state: FileStateType) -> Any:
"""
TBD
"""
return None
def vowel_space_mode(database, file_state):
def vowel_space_mode(database: Database, file_state: FileStateType) -> dict[str, float] | None:
"""
Extracts and returns the first and second formants of a specified frame.
@ -97,14 +98,14 @@ def vowel_space_mode(database, file_state):
return {"f1": formants[0], "f2": formants[1]}
def transcription_mode(database, file_state):
def transcription_mode(database: Database, file_state: FileStateType) -> Any:
"""
TBD
"""
return None
def error_rate_mode(database, file_state):
def error_rate_mode(database: Database, file_state: FileStateType) -> dict[str, Any] | None:
"""
Calculate the error rates of transcriptions against the ground truth.
@ -140,7 +141,7 @@ def error_rate_mode(database, file_state):
return errorRate
def get_file(database, file_state):
def get_file(database: Database, file_state: FileStateType) -> FileStateType:
"""
Fetch a file from the database using the file_state information.
@ -171,7 +172,7 @@ def get_file(database, file_state):
return file
def convert_to_wav(data):
def convert_to_wav(data: bytes) -> bytes:
with tempfile.NamedTemporaryFile(delete=False) as temp_input:
temp_input.write(data)
temp_input.flush() # Ensure data is written to disk

View file

@ -1,6 +1,6 @@
from typing import Any, Dict
from typing import Any
signal_modes_response_examples: Dict[int | str, Dict[str, Any]] = {
signal_modes_response_examples: dict[int | str, dict[str, Any]] = {
200: {
"content": {
"application/json": {
@ -208,7 +208,7 @@ signal_modes_response_examples: Dict[int | str, Dict[str, Any]] = {
404: {"content": {"application/json": {"example": {"detail": "error message"}}}},
}
transcription_response_examples: Dict[int | str, Dict[str, Any]] = {
transcription_response_examples: dict[int | str, dict[str, Any]] = {
200: {
"content": {
"application/json": {

View file

@ -1,10 +1,13 @@
import parselmouth
import numpy as np
from pydub import AudioSegment
from .types import AudioType, SoundType
import io
from typing import Any
from array import array
def get_audio(file):
def get_audio(file: dict[str, Any]) -> AudioType:
"""
Extract audio data and sampling rate from the given file.
@ -12,7 +15,7 @@ def get_audio(file):
- file: A dictionary containing the file data, including audio bytes.
Returns:
- A tuple (fs, data) where fs is the sampling rate and data is the array of audio samples.
- A list that contains the audio data
Example:
```python
@ -24,7 +27,7 @@ def get_audio(file):
return audio
def simple_signal_info(audio):
def simple_signal_info(audio: AudioType) -> dict[str, Any]:
"""
Extracts and returns basic information from a given audio signal.
@ -32,7 +35,6 @@ def simple_signal_info(audio):
Parameters:
- signal (list of int): The audio signal data.
- fs (float): The sample frequency of the audio signal.
Returns:
- dict: A dictionary containing the duration and average pitch of the signal.
@ -42,8 +44,8 @@ def simple_signal_info(audio):
result = simple_signal_info(signal, fs)
```
"""
duration = calculate_signal_duration(audio)
avg_pitch = np.mean(
duration: float = calculate_signal_duration(audio)
avg_pitch: float = np.mean(
calculate_sound_pitch(
signal_to_sound(signal=audio.get_array_of_samples(), fs=audio.frame_rate)
)["data"] # type: ignore
@ -51,7 +53,7 @@ def simple_signal_info(audio):
return {"duration": duration, "averagePitch": avg_pitch}
def signal_to_sound(signal, fs):
def signal_to_sound(signal: array, fs: float | int) -> SoundType:
"""
This method converts a signal to a parselmouth sound object.
@ -67,12 +69,10 @@ def signal_to_sound(signal, fs):
result = signal_to_sound(signal, fs)
```
"""
return parselmouth.Sound(
values=np.array(signal).astype("float64"), sampling_frequency=fs
)
return parselmouth.Sound(values=np.array(signal).astype("float64"), sampling_frequency=fs)
def calculate_signal_duration(audio):
def calculate_signal_duration(audio: AudioType) -> float:
"""
This method calculates the duration of a signal based on the signal and the sample frequency.
@ -91,7 +91,9 @@ def calculate_signal_duration(audio):
return audio.duration_seconds
def calculate_sound_pitch(sound, time_step=None): # pragma: no cover
def calculate_sound_pitch(
sound: SoundType, time_step: float | None = None
) -> dict[str, Any] | None: # pragma: no cover
"""
This method calculates the pitches present in a sound object.
@ -121,8 +123,11 @@ def calculate_sound_pitch(sound, time_step=None): # pragma: no cover
def calculate_sound_spectrogram(
sound, time_step=0.002, window_length=0.005, frequency_step=20.0
): # pragma: no cover
sound: SoundType,
time_step: float = 0.002,
window_length: float = 0.005,
frequency_step: float = 20.0,
) -> dict[str, Any] | None: # pragma: no cover
"""
This method calculates the spectrogram of a sound fragment.
@ -162,7 +167,7 @@ def calculate_sound_spectrogram(
def calculate_sound_f1_f2(
sound, time_step=None, window_length=0.025
sound: SoundType, time_step: float | None = None, window_length: float = 0.025
): # pragma: no cover
"""
This method calculates the first and second formant of a sound fragment.
@ -184,10 +189,8 @@ def calculate_sound_f1_f2(
```
"""
try:
formants = sound.to_formant_burg(
time_step=time_step, window_length=window_length
)
data = []
formants = sound.to_formant_burg(time_step=time_step, window_length=window_length)
data: list = []
for frame in np.arange(1, len(formants) + 1):
data.append(
[

View file

@ -1,13 +1,18 @@
from deepgram import DeepgramClient, PrerecordedOptions, FileSource
from fastapi import HTTPException
from jiwer import process_words, process_characters
import jiwer
from .signal_analysis import get_audio, calculate_signal_duration
from .types import FileStateType
from allosaurus.app import read_recognizer # type: ignore
import tempfile
import os
from typing import Any
def calculate_error_rates(reference_annotations, hypothesis_annotations):
def calculate_error_rates(
reference_annotations: list[dict], hypothesis_annotations: list[dict]
) -> dict | None:
"""
Calculate error rates between the reference transcription and annotations.
@ -32,7 +37,7 @@ def calculate_error_rates(reference_annotations, hypothesis_annotations):
return {"wordLevel": word_level, "characterLevel": character_level}
def annotation_to_sentence(annotations):
def annotation_to_sentence(annotations: list) -> str:
"""
Convert annotations to a single hypothesis string.
@ -57,7 +62,7 @@ def annotation_to_sentence(annotations):
return res[: len(res) - 1]
def word_level_processing(reference, hypothesis):
def word_level_processing(reference: str, hypothesis: str) -> dict[str, Any]:
"""
Process word-level error metrics between the reference and hypothesis.
@ -89,7 +94,7 @@ def word_level_processing(reference, hypothesis):
return result
def character_level_processing(reference, hypothesis):
def character_level_processing(reference: str, hypothesis: str) -> dict[str, Any]:
"""
Process character-level error metrics between the reference and hypothesis.
@ -119,7 +124,7 @@ def character_level_processing(reference, hypothesis):
return result
def get_alignments(unparsed_alignments):
def get_alignments(unparsed_alignments: list[jiwer.process.AlignmentChunk]) -> list[dict]:
"""
Convert unparsed alignments into a structured format.
@ -148,7 +153,7 @@ def get_alignments(unparsed_alignments):
return alignments
def get_transcription(model, file):
def get_transcription(model: str, file: FileStateType):
"""
Get transcription of an audio file using the specified model.
@ -171,7 +176,7 @@ def get_transcription(model, file):
raise HTTPException(status_code=404, detail="Model was not found")
def fill_gaps(transcriptions, file):
def fill_gaps(transcriptions: list[dict], file: FileStateType) -> list[dict]:
res = []
audio = get_audio(file)
@ -194,7 +199,7 @@ def fill_gaps(transcriptions, file):
return res
def deepgram_transcription(data):
def deepgram_transcription(data: bytes) -> list[dict]:
"""
Transcribe audio data using Deepgram API.
@ -234,16 +239,15 @@ def deepgram_transcription(data):
res = []
for word in response["results"]["channels"][0]["alternatives"][0]["words"]:
res.append(
{"value": word["word"], "start": word["start"], "end": word["end"]}
)
res.append({"value": word["word"], "start": word["start"], "end": word["end"]})
return res
except Exception as e:
print(f"Exception: {e}")
return []
def allosaurs_transcription(file):
def allosaurs_transcription(file: FileStateType) -> Any:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
temp_wav.write(file["data"])
temp_wav_filename = temp_wav.name
@ -251,9 +255,7 @@ def allosaurs_transcription(file):
word_level_transcription = fill_gaps(deepgram_transcription(file["data"]), file)
model = read_recognizer()
phoneme_level_transcription = model.recognize(
temp_wav_filename, timestamp=True, emit=1.2
)
phoneme_level_transcription = model.recognize(temp_wav_filename, timestamp=True, emit=1.2)
phoneme_level_parsed = []
@ -262,13 +264,13 @@ def allosaurs_transcription(file):
[float(phoneme_string.split(" ")[0]), phoneme_string.split(" ")[2]]
)
phoneme_word_splits = get_phoneme_word_splits(
word_level_transcription, phoneme_level_parsed
)
phoneme_word_splits = get_phoneme_word_splits(word_level_transcription, phoneme_level_parsed)
return get_phoneme_transcriptions(phoneme_word_splits)
def get_phoneme_word_splits(word_level_transcription, phoneme_level_parsed):
def get_phoneme_word_splits(
word_level_transcription: list[dict], phoneme_level_parsed: list[list]
) -> list[dict]:
if len(word_level_transcription) == 0:
return []
@ -282,10 +284,7 @@ def get_phoneme_word_splits(word_level_transcription, phoneme_level_parsed):
while word_pointer < len(word_level_transcription) and phoneme_pointer < len(
phoneme_level_parsed
):
if (
phoneme_level_parsed[phoneme_pointer][0]
> word_level_transcription[word_pointer]["end"]
):
if phoneme_level_parsed[phoneme_pointer][0] > word_level_transcription[word_pointer]["end"]:
current_split["word_transcription"] = word_level_transcription[word_pointer]
phoneme_word_splits.append(current_split)
current_split = {"phonemes": [], "word_transcription": None}
@ -302,7 +301,7 @@ def get_phoneme_word_splits(word_level_transcription, phoneme_level_parsed):
return phoneme_word_splits
def get_phoneme_transcriptions(phoneme_word_splits):
def get_phoneme_transcriptions(phoneme_word_splits: list[Any]) -> list[dict]:
res = []
for phoneme_split in phoneme_word_splits:
@ -315,22 +314,14 @@ def get_phoneme_transcriptions(phoneme_word_splits):
start = phoneme_split["word_transcription"]["start"]
else:
# this is an (educated) guess, it could be way off :D
start = (
phoneme_split["phonemes"][i - 1][0]
+ phoneme_split["phonemes"][i][0]
) / 2
start = (phoneme_split["phonemes"][i - 1][0] + phoneme_split["phonemes"][i][0]) / 2
end = 0
if i + 1 == len(phoneme_split["phonemes"]):
end = phoneme_split["word_transcription"]["end"]
else:
end = (
phoneme_split["phonemes"][i + 1][0]
+ phoneme_split["phonemes"][i][0]
) / 2
end = (phoneme_split["phonemes"][i + 1][0] + phoneme_split["phonemes"][i][0]) / 2
res.append(
{"value": phoneme_split["phonemes"][i][1], "start": start, "end": end}
)
res.append({"value": phoneme_split["phonemes"][i][1], "start": start, "end": end})
return res

7
kernel/spectral/types.py Normal file
View file

@ -0,0 +1,7 @@
import parselmouth
from pydub import AudioSegment

# type definitions — shared aliases so signatures across the kernel refer to
# one name for each concrete third-party class.

# Decoded audio fragment (pydub); callers use .get_array_of_samples(),
# .frame_rate and .duration_seconds.
AudioType = AudioSegment
# Praat sound object (parselmouth), used for pitch/formant calculations.
SoundType = parselmouth.Sound
# Per-request file state passed between endpoints; a plain dict
# (keys/values unconstrained here — tighten to a TypedDict if the schema
# stabilizes).
FileStateType = dict

View file

@ -1,4 +1,5 @@
import pytest
from spectral.database import Database
from spectral.main import app, get_db
from fastapi.testclient import TestClient
from fastapi import HTTPException
@ -48,6 +49,7 @@ def db_mock():
"groundTruth": "hai test",
}
mock.get_transcriptions.return_value = [[{"value": "hi", "start": 0, "end": 1}]]
mock.__class__ = Database
yield mock

View file

@ -7,6 +7,7 @@ from spectral.frame_analysis import (
)
import json
import os
from array import array
# Load the JSON file
with open(
@ -17,33 +18,40 @@ with open(
def test_speed_voiced_frame():
assert (
calculate_frame_duration(frame_data["voiced-1"]["data"], frame_data["voiced-1"]["fs"])
calculate_frame_duration(
array("h", frame_data["voiced-1"]["data"]), frame_data["voiced-1"]["fs"]
)
== 0.04
), "Expected duration for voiced frame to be 0.04 seconds"
def test_speed_unvoiced_frame():
assert (
calculate_frame_duration(frame_data["unvoiced-1"]["data"], frame_data["unvoiced-1"]["fs"])
calculate_frame_duration(
array("h", frame_data["unvoiced-1"]["data"]), frame_data["unvoiced-1"]["fs"]
)
== 0.04
), "Expected duration for unvoiced frame to be 0.04 seconds"
def test_speed_noise_frame():
assert (
calculate_frame_duration(frame_data["noise-1"]["data"], frame_data["noise-1"]["fs"]) == 0.04
calculate_frame_duration(
array("h", frame_data["noise-1"]["data"]), frame_data["noise-1"]["fs"]
)
== 0.04
), "Expected duration for noise frame to be 0.04 seconds"
def test_speed_empty_frame():
assert (
calculate_frame_duration([], 48000) == 0
calculate_frame_duration(array("h", []), 48000) == 0
), "Expected duration for empty frame to be 0 seconds"
def test_pitch_voiced():
assert calculate_frame_pitch(
frame_data["voiced-1"]["data"], frame_data["voiced-1"]["fs"]
array("h", frame_data["voiced-1"]["data"]), frame_data["voiced-1"]["fs"]
) == pytest.approx(
115.64, 0.01
), "Expected pitch for voiced frame to be approximately 115.64 Hz"
@ -51,22 +59,30 @@ def test_pitch_voiced():
def test_pitch_unvoiced():
assert math.isnan(
calculate_frame_pitch(frame_data["unvoiced-1"]["data"], frame_data["unvoiced-1"]["fs"])
), "Expected pitch for unvoiced frame to be NaN"
calculate_frame_pitch(
array("h", frame_data["unvoiced-1"]["data"]), frame_data["unvoiced-1"]["fs"]
)
)
def test_pitch_noise():
assert math.isnan(
calculate_frame_pitch(frame_data["noise-1"]["data"], frame_data["noise-1"]["fs"])
calculate_frame_pitch(
array("h", frame_data["noise-1"]["data"]), frame_data["noise-1"]["fs"]
)
), "Expected pitch for noise frame to be NaN"
def test_pitch_empty_frame():
assert math.isnan(calculate_frame_pitch([], 48000)), "Expected pitch for empty frame to be NaN"
assert math.isnan(
calculate_frame_pitch(array("h", []), 48000)
), "Expected pitch for empty frame to be NaN"
def test_formants_voiced_frame():
formants = calculate_frame_f1_f2(frame_data["voiced-1"]["data"], frame_data["voiced-1"]["fs"])
formants = calculate_frame_f1_f2(
array("h", frame_data["voiced-1"]["data"]), frame_data["voiced-1"]["fs"]
)
assert len(formants) == 2, "Expected two formants for voiced frame"
assert formants[0] == pytest.approx(
474.43, 0.01
@ -78,7 +94,7 @@ def test_formants_voiced_frame():
def test_formants_unvoiced_frame():
formants = calculate_frame_f1_f2(
frame_data["unvoiced-1"]["data"], frame_data["unvoiced-1"]["fs"]
array("h", frame_data["unvoiced-1"]["data"]), frame_data["unvoiced-1"]["fs"]
)
assert len(formants) == 2, "Expected two formants for unvoiced frame"
assert formants[0] == pytest.approx(
@ -90,7 +106,10 @@ def test_formants_unvoiced_frame():
def test_formants_noise_frame():
formants = calculate_frame_f1_f2(frame_data["noise-1"]["data"], frame_data["noise-1"]["fs"])
formants = calculate_frame_f1_f2(
array("h", frame_data["noise-1"]["data"]), frame_data["noise-1"]["fs"]
)
assert len(formants) == 2, "Expected two formants for noise frame"
assert formants[0] == pytest.approx(
192.72, 0.01
@ -101,7 +120,7 @@ def test_formants_noise_frame():
def test_formants_empty_frame():
formants = calculate_frame_f1_f2([], 0)
formants = calculate_frame_f1_f2(array("h", []), 0)
assert len(formants) == 2, "Expected two formants for empty frame"
assert math.isnan(formants[0]), "Expected first formant (f1) for empty frame to be NaN"
assert math.isnan(formants[1]), "Expected second formant (f2) for empty frame to be NaN"

View file

@ -1,12 +1,13 @@
import pytest
from fastapi import HTTPException
from spectral.frame_analysis import validate_frame_index
from array import array
def test_validate_frame_index_valid():
data = [0] * 100
frame_index = validate_frame_index(
data, {"frame": {"startIndex": 10, "endIndex": 20}}
array("h", data), {"frame": {"startIndex": 10, "endIndex": 20}}
)
assert frame_index == {"startIndex": 10, "endIndex": 20}
@ -14,7 +15,7 @@ def test_validate_frame_index_valid():
def test_validate_frame_index_both_none_indices():
data = [0] * 100
assert (
validate_frame_index(data, {"frame": {"startIndex": None, "endIndex": None}})
validate_frame_index(array("h", data), {"frame": {"startIndex": None, "endIndex": None}})
is None
)
@ -22,34 +23,34 @@ def test_validate_frame_index_both_none_indices():
def test_validate_frame_index_missing_start_index():
data = [0] * 100
with pytest.raises(HTTPException):
validate_frame_index(data, {"frame": {"startIndex": None, "endIndex": 20}})
validate_frame_index(array("h", data), {"frame": {"startIndex": None, "endIndex": 20}})
def test_validate_frame_index_missing_end_index():
data = [0] * 100
with pytest.raises(HTTPException):
validate_frame_index(data, {"frame": {"startIndex": 10, "endIndex": None}})
validate_frame_index(array("h", data), {"frame": {"startIndex": 10, "endIndex": None}})
def test_validate_frame_index_start_index_greater_than_end_index():
data = [0] * 100
with pytest.raises(HTTPException):
validate_frame_index(data, {"frame": {"startIndex": 20, "endIndex": 10}})
validate_frame_index(array("h", data), {"frame": {"startIndex": 20, "endIndex": 10}})
def test_validate_frame_index_negative_start_index():
data = [0] * 100
with pytest.raises(HTTPException):
validate_frame_index(data, {"frame": {"startIndex": -1, "endIndex": 20}})
validate_frame_index(array("h", data), {"frame": {"startIndex": -1, "endIndex": 20}})
def test_validate_frame_index_end_index_too_large():
data = [0] * 100
with pytest.raises(HTTPException):
validate_frame_index(data, {"frame": {"startIndex": 10, "endIndex": 200}})
validate_frame_index(array("h", data), {"frame": {"startIndex": 10, "endIndex": 200}})
def test_frame_none():
data = [0] * 100
frame_index = validate_frame_index(data, {"frame": None})
frame_index = validate_frame_index(array("h", data), {"frame": None})
assert frame_index is None