Merge branch 'kernel/beartype' into 'dev'
Kernel/beartype

See merge request cse2000-software-project/2023-2024/cluster-n/11c/atypical-speech-project!105

Commit fe72db9176

15 changed files with 182 additions and 122 deletions

@@ -25,7 +25,7 @@ repos:
     hooks:
       - id: pyright
         types_or: [python, pyi, jupyter]
-        additional_dependencies: [numpy, pytest, fastapi, praat-parselmouth, orjson, pydantic, scipy, psycopg, deepgram-sdk, pydub, ffmpeg-python, jiwer]
+        additional_dependencies: [numpy, pytest, fastapi, praat-parselmouth, orjson, pydantic, scipy, psycopg, deepgram-sdk, pydub, ffmpeg-python, jiwer, beartype]
         stages: [pre-commit]
   - repo: https://github.com/crate-ci/typos
     rev: v1.21.0

kernel/poetry.lock (generated): 22 changes

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "aiofiles"
@@ -189,6 +189,24 @@ tests = ["attrs[tests-no-zope]", "zope-interface"]
 tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
 tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
 
+[[package]]
+name = "beartype"
+version = "0.18.5"
+description = "Unbearably fast runtime type checking in pure Python."
+optional = false
+python-versions = ">=3.8.0"
+files = [
+    {file = "beartype-0.18.5-py3-none-any.whl", hash = "sha256:5301a14f2a9a5540fe47ec6d34d758e9cd8331d36c4760fc7a5499ab86310089"},
+    {file = "beartype-0.18.5.tar.gz", hash = "sha256:264ddc2f1da9ec94ff639141fbe33d22e12a9f75aa863b83b7046ffff1381927"},
+]
+
+[package.extras]
+all = ["typing-extensions (>=3.10.0.0)"]
+dev = ["autoapi (>=0.9.0)", "coverage (>=5.5)", "equinox", "mypy (>=0.800)", "numpy", "pandera", "pydata-sphinx-theme (<=0.7.2)", "pytest (>=4.0.0)", "sphinx", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)", "tox (>=3.20.1)", "typing-extensions (>=3.10.0.0)"]
+doc-rtd = ["autoapi (>=0.9.0)", "pydata-sphinx-theme (<=0.7.2)", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)"]
+test-tox = ["equinox", "mypy (>=0.800)", "numpy", "pandera", "pytest (>=4.0.0)", "sphinx", "typing-extensions (>=3.10.0.0)"]
+test-tox-coverage = ["coverage (>=5.5)"]
+
 [[package]]
 name = "certifi"
 version = "2024.2.2"
@@ -3062,4 +3080,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "1e35a62167b27370ababf328d0b9ca3c05472d6ef752fc5b7303fb7d7284db3f"
+content-hash = "f2ccf666bf19504a96354945c1adbd3beeed52c8c0800a86ead0a384610317b3"

@@ -50,6 +50,7 @@ mutmut = "^2.5.0"
 pytest-xdist = "^3.6.1"
 pytest-testmon = "^2.1.1"
 pytest-mock = "^3.14.0"
+beartype = "^0.18.5"
 
 
 [tool.mutmut]

@@ -0,0 +1,7 @@
+try:
+    from beartype.claw import beartype_this_package
+
+    beartype_this_package()
+except ImportError:
+    # in case beartype is not installed: running in production
+    pass

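The new package `__init__` above is the whole integration point: `beartype_this_package()` installs an import hook that wraps every annotated callable in the package, and the `except ImportError` branch keeps production images working without beartype installed. As a rough sketch of the effect (the standalone function below is illustrative, not part of the diff), a call that violates an annotation now fails at the call boundary:

```python
# Sketch only: assumes beartype 0.18 is installed, as pinned above.
from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation


@beartype  # the claw hook applies equivalent wrapping to the whole package
def snake_to_camel(snake_case_str: str) -> str:
    components = snake_case_str.split("_")
    return components[0] + "".join(x.title() for x in components[1:])


print(snake_to_camel("start_index"))  # "startIndex"

try:
    snake_to_camel(42)  # deliberately wrong argument type
except BeartypeCallHintParamViolation as err:
    print(f"rejected at call time: {err}")
```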
@@ -18,7 +18,13 @@ class Database:
     Closes the database connection and cursor.
     """
 
-    def __init__(self, user, password, host, port, dbname):
+    user: str
+    password: str
+    host: str
+    port: int
+    dbname: str
+
+    def __init__(self, user: str, password: str, host: str, port: int, dbname: str):
         """
         Initializes the Database object and opens a connection to the specified PostgreSQL database.
 
@@ -35,7 +41,7 @@ class Database:
         self.port = port
         self.dbname = dbname
 
-    def connection(self):
+    def connection(self) -> None:
         self.conn = psycopg.connect(
             dbname=self.dbname,
             user=self.user,
@@ -45,12 +51,12 @@ class Database:
         )
         self.cursor = self.conn.cursor()
 
-    def fetch_file(self, id):
+    def fetch_file(self, id: int) -> dict:
         """
         Fetches a file record from the database by its ID.
 
         Args:
-            id (string): The ID of the file to fetch.
+            id (int): The ID of the file to fetch.
 
         Returns:
             dict: A dictionary containing the file record's details.
@@ -73,7 +79,7 @@ class Database:
             result[self.snake_to_camel(column[0])] = db_res[column[1] - 1]
         return result
 
-    def snake_to_camel(self, snake_case_str):
+    def snake_to_camel(self, snake_case_str: str) -> str:
         """
         Converts a snake_case string to camelCase.
 
@@ -91,12 +97,12 @@ class Database:
         components = snake_case_str.split("_")
         return components[0] + "".join(x.title() for x in components[1:])
 
-    def get_transcriptions(self, file_id):
+    def get_transcriptions(self, file_id: int) -> list[list]:
        """
         Fetches transcriptions associated with a file from the database.
 
         Args:
-            file_id (string): The ID of the file to fetch transcriptions for.
+            file_id (int): The ID of the file to fetch transcriptions for.
 
         Returns:
             list: A list of lists containing transcription entries, where each inner list represents a file transcription and contains dictionaries with "start", "end", and "value" keys.
@@ -131,7 +137,7 @@ class Database:
             res.append(parsed_file_transcriptions)
         return res
 
-    def close(self):
+    def close(self) -> None:
         """
         Closes the database connection and cursor.
         """

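Because the claw hook also wraps methods, the `Database` annotations above are enforced at runtime, not just by pyright. A minimal sketch under that assumption, where `db` stands for an already connected `Database` instance:

```python
# Sketch only: `db` is assumed to be a connected spectral.database.Database instance.
from beartype.roar import BeartypeCallHintParamViolation

file_record = db.fetch_file(3)  # int id: accepted, returns a dict

try:
    db.fetch_file("3")  # str id violates the new `id: int` annotation
except BeartypeCallHintParamViolation:
    print("rejected before any query ran")
```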
@@ -1,9 +1,14 @@
 import parselmouth
 import numpy as np
 from fastapi import HTTPException
+from array import array
+
+from .types import FileStateType
 
 
-def simple_frame_info(frame, fs, frame_info):
+def simple_frame_info(
+    frame: array, fs: int | float, frame_info: dict[str, int] | None
+) -> dict[str, float] | None:
     """
     Extracts and returns basic information from a given audio frame.
 
@@ -42,7 +47,7 @@ def simple_frame_info(frame, fs, frame_info):
     return res
 
 
-def calculate_frame_duration(frame, fs):
+def calculate_frame_duration(frame: array, fs: int | float) -> float:
     """
     This method calculates the duration of a frame based on the frame and the sample frequency.
 
@@ -61,7 +66,7 @@ def calculate_frame_duration(frame, fs):
     return len(frame) / fs
 
 
-def calculate_frame_pitch(frame, fs):
+def calculate_frame_pitch(frame: array, fs: float | int) -> float:
     """
     This method calculates the pitch of a frame.
 
@@ -84,11 +89,11 @@ def calculate_frame_pitch(frame, fs):
             time_step=calculate_frame_duration(frame, fs) + 1
         )  # the + 1 ensures that the complete frame is considered as 1 frame
         return pitch.get_value_at_time(0)
-    except Exception as _:
+    except Exception:
         return float("nan")
 
 
-def calculate_frame_f1_f2(frame, fs):
+def calculate_frame_f1_f2(frame: array, fs: int | float) -> list[float]:
     """
     This method calculates the first and second fromant of a frame.
 
@@ -114,11 +119,12 @@ def calculate_frame_f1_f2(frame, fs):
             formants.get_value_at_time(formant_number=1, time=0),
             formants.get_value_at_time(formant_number=2, time=0),
         ]
-    except Exception as _:
+    except Exception as e:
+        print(e)
         return [float("nan"), float("nan")]
 
 
-def validate_frame_index(data, file_state):
+def validate_frame_index(data: array, file_state: FileStateType):
     """
     Validate the frame index specified in the file_state.
 
@@ -162,11 +168,7 @@ def validate_frame_index(data, file_state):
             status_code=400, detail="startIndex should be strictly lower than endIndex"
         )
     if start_index < 0:
-        raise HTTPException(
-            status_code=400, detail="startIndex should be larger or equal to 0"
-        )
+        raise HTTPException(status_code=400, detail="startIndex should be larger or equal to 0")
     if end_index > len(data):
-        raise HTTPException(
-            status_code=400, detail="endIndex should be lower than the file length"
-        )
+        raise HTTPException(status_code=400, detail="endIndex should be lower than the file length")
     return {"startIndex": start_index, "endIndex": end_index}

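The `frame: array` hints above are also what forces the test changes further down: a plain Python list no longer satisfies the annotation once beartype is active, so sample data gets wrapped in `array("h", ...)`. A small sketch of the expected call shape (the sample values are invented):

```python
# Sketch only: sample values are made up; the module path matches the tests below.
from array import array

from spectral.frame_analysis import calculate_frame_duration, calculate_frame_pitch

frame = array("h", [0, 120, -340, 512, 96])  # five 16-bit samples
fs = 48000

print(calculate_frame_duration(frame, fs))  # 5 / 48000 seconds
print(calculate_frame_pitch(frame, fs))     # expected NaN: far too short for pitch tracking
```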
@@ -25,18 +25,20 @@ from .database import Database
 import orjson
 import json
 import os
+from typing import Any
+from collections.abc import Iterator
 
 
-def get_db():  # pragma: no cover
+def get_db() -> Iterator[Database]:  # pragma: no cover
     db = None
     try:
-        db = Database(
-            os.getenv("POSTGRES_USER"),
-            os.getenv("POSTGRES_PASSWORD"),
-            os.getenv("POSTGRES_HOST"),
-            os.getenv("POSTGRES_PORT"),
-            os.getenv("POSTGRES_DB"),
-        )
+        user = os.getenv("POSTGRES_USER", "user")
+        password = os.getenv("POSTGRES_PASSWORD", "password")
+        host = os.getenv("POSTGRES_HOST", "localhost")
+        port = int(os.getenv("POSTGRES_PORT", "5432"))
+        dbname = os.getenv("POSTGRES_DB", "postgres")
+
+        db = Database(user, password, host, port, dbname)
         db.connection()
         yield db
     finally:
@@ -51,7 +53,7 @@ class ORJSONResponse(JSONResponse):
 
     media_type = "application/json"
 
-    def render(self, content) -> bytes:
+    def render(self, content: Any) -> bytes:
         return orjson.dumps(content)
 
 

@@ -8,11 +8,14 @@ from .frame_analysis import (
     validate_frame_index,
 )
 from .transcription import calculate_error_rates
+from .types import FileStateType
 import tempfile
 import subprocess
+from .database import Database
+from typing import Any
 
 
-def simple_info_mode(database, file_state):
+def simple_info_mode(database: Database, file_state: FileStateType) -> dict[str, Any]:
     """
     Extracts and returns basic information about a signal and its corresponding frame.
 
@@ -43,28 +46,26 @@ def simple_info_mode(database, file_state):
 
     frame_index = validate_frame_index(audio.get_array_of_samples(), file_state)
 
-    result["frame"] = simple_frame_info(
-        audio.get_array_of_samples(), audio.frame_rate, frame_index
-    )
+    result["frame"] = simple_frame_info(audio.get_array_of_samples(), audio.frame_rate, frame_index)
 
     return result
 
 
-def spectrogram_mode(database, file_state):
+def spectrogram_mode(database: Database, file_state: FileStateType) -> Any:
     """
     TBD
     """
     return None
 
 
-def waveform_mode(database, file_state):
+def waveform_mode(database: Database, file_state: FileStateType) -> Any:
     """
     TBD
     """
     return None
 
 
-def vowel_space_mode(database, file_state):
+def vowel_space_mode(database: Database, file_state: FileStateType) -> dict[str, float] | None:
     """
     Extracts and returns the first and second formants of a specified frame.
 
@@ -97,14 +98,14 @@ def vowel_space_mode(database, file_state):
     return {"f1": formants[0], "f2": formants[1]}
 
 
-def transcription_mode(database, file_state):
+def transcription_mode(database: Database, file_state: FileStateType) -> Any:
     """
     TBD
     """
     return None
 
 
-def error_rate_mode(database, file_state):
+def error_rate_mode(database: Database, file_state: FileStateType) -> dict[str, Any] | None:
     """
     Calculate the error rates of transcriptions against the ground truth.
 
@@ -140,7 +141,7 @@ def error_rate_mode(database, file_state):
     return errorRate
 
 
-def get_file(database, file_state):
+def get_file(database: Database, file_state: FileStateType) -> FileStateType:
     """
     Fetch a file from the database using the file_state information.
 
@@ -171,7 +172,7 @@ def get_file(database, file_state):
     return file
 
 
-def convert_to_wav(data):
+def convert_to_wav(data: bytes) -> bytes:
     with tempfile.NamedTemporaryFile(delete=False) as temp_input:
         temp_input.write(data)
         temp_input.flush()  # Ensure data is written to disk

@@ -1,6 +1,6 @@
-from typing import Any, Dict
+from typing import Any
 
-signal_modes_response_examples: Dict[int | str, Dict[str, Any]] = {
+signal_modes_response_examples: dict[int | str, dict[str, Any]] = {
     200: {
         "content": {
             "application/json": {
@@ -208,7 +208,7 @@ signal_modes_response_examples: Dict[int | str, Dict[str, Any]] = {
     404: {"content": {"application/json": {"example": {"detail": "error message"}}}},
 }
 
-transcription_response_examples: Dict[int | str, Dict[str, Any]] = {
+transcription_response_examples: dict[int | str, dict[str, Any]] = {
     200: {
         "content": {
             "application/json": {

@@ -1,10 +1,13 @@
 import parselmouth
 import numpy as np
 from pydub import AudioSegment
+from .types import AudioType, SoundType
 import io
+from typing import Any
+from array import array
 
 
-def get_audio(file):
+def get_audio(file: dict[str, Any]) -> AudioType:
     """
     Extract audio data and sampling rate from the given file.
 
@@ -12,7 +15,7 @@ def get_audio(file):
     - file: A dictionary containing the file data, including audio bytes.
 
     Returns:
-    - A tuple (fs, data) where fs is the sampling rate and data is the array of audio samples.
+    - A list, that contains audio data
 
     Example:
     ```python
@@ -24,7 +27,7 @@ def get_audio(file):
     return audio
 
 
-def simple_signal_info(audio):
+def simple_signal_info(audio: AudioType) -> dict[str, Any]:
     """
     Extracts and returns basic information from a given audio signal.
 
@@ -32,7 +35,6 @@ def simple_signal_info(audio):
 
     Parameters:
     - signal (list of int): The audio signal data.
-    - fs (float): The sample frequency of the audio signal.
 
     Returns:
     - dict: A dictionary containing the duration and average pitch of the signal.
@@ -42,8 +44,8 @@ def simple_signal_info(audio):
     result = simple_signal_info(signal, fs)
     ```
     """
-    duration = calculate_signal_duration(audio)
-    avg_pitch = np.mean(
+    duration: float = calculate_signal_duration(audio)
+    avg_pitch: float = np.mean(
         calculate_sound_pitch(
             signal_to_sound(signal=audio.get_array_of_samples(), fs=audio.frame_rate)
         )["data"]  # type: ignore
@@ -51,7 +53,7 @@ def simple_signal_info(audio):
     return {"duration": duration, "averagePitch": avg_pitch}
 
 
-def signal_to_sound(signal, fs):
+def signal_to_sound(signal: array, fs: float | int) -> SoundType:
     """
     This method converts a signal to a parselmouth sound object.
 
@@ -67,12 +69,10 @@ def signal_to_sound(signal, fs):
     result = signal_to_sound(signal, fs)
     ```
     """
-    return parselmouth.Sound(
-        values=np.array(signal).astype("float64"), sampling_frequency=fs
-    )
+    return parselmouth.Sound(values=np.array(signal).astype("float64"), sampling_frequency=fs)
 
 
-def calculate_signal_duration(audio):
+def calculate_signal_duration(audio: AudioType) -> float:
     """
     This method calculates the duration of a signal based on the signal and the sample frequency.
 
@@ -91,7 +91,9 @@ def calculate_signal_duration(audio):
     return audio.duration_seconds
 
 
-def calculate_sound_pitch(sound, time_step=None):  # pragma: no cover
+def calculate_sound_pitch(
+    sound: SoundType, time_step: float | None = None
+) -> dict[str, Any] | None:  # pragma: no cover
     """
     This method calculates the pitches present in a sound object.
 
@@ -121,8 +123,11 @@ def calculate_sound_pitch(sound, time_step=None):  # pragma: no cover
 
 
 def calculate_sound_spectrogram(
-    sound, time_step=0.002, window_length=0.005, frequency_step=20.0
-):  # pragma: no cover
+    sound: SoundType,
+    time_step: float = 0.002,
+    window_length: float = 0.005,
+    frequency_step: float = 20.0,
+) -> dict[str, Any] | None:  # pragma: no cover
     """
     This method calculates the spectrogram of a sound fragment.
 
@@ -162,7 +167,7 @@ def calculate_sound_spectrogram(
 
 
 def calculate_sound_f1_f2(
-    sound, time_step=None, window_length=0.025
+    sound: SoundType, time_step: float | None = None, window_length: float = 0.025
 ):  # pragma: no cover
     """
     This method calculates the first and second formant of a sound fragment.
 
@@ -184,10 +189,8 @@ def calculate_sound_f1_f2(
     ```
     """
     try:
-        formants = sound.to_formant_burg(
-            time_step=time_step, window_length=window_length
-        )
-        data = []
+        formants = sound.to_formant_burg(time_step=time_step, window_length=window_length)
+        data: list = []
         for frame in np.arange(1, len(formants) + 1):
             data.append(
                 [

@@ -1,13 +1,18 @@
 from deepgram import DeepgramClient, PrerecordedOptions, FileSource
 from fastapi import HTTPException
 from jiwer import process_words, process_characters
+import jiwer
 from .signal_analysis import get_audio, calculate_signal_duration
+from .types import FileStateType
 from allosaurus.app import read_recognizer  # type: ignore
 import tempfile
 import os
+from typing import Any
 
 
-def calculate_error_rates(reference_annotations, hypothesis_annotations):
+def calculate_error_rates(
+    reference_annotations: list[dict], hypothesis_annotations: list[dict]
+) -> dict | None:
     """
     Calculate error rates between the reference transcription and annotations.
 
@@ -32,7 +37,7 @@ def calculate_error_rates(reference_annotations, hypothesis_annotations):
     return {"wordLevel": word_level, "characterLevel": character_level}
 
 
-def annotation_to_sentence(annotations):
+def annotation_to_sentence(annotations: list) -> str:
     """
     Convert annotations to a single hypothesis string.
 
@@ -57,7 +62,7 @@ def annotation_to_sentence(annotations):
     return res[: len(res) - 1]
 
 
-def word_level_processing(reference, hypothesis):
+def word_level_processing(reference: str, hypothesis: str) -> dict[str, Any]:
     """
     Process word-level error metrics between the reference and hypothesis.
 
@@ -89,7 +94,7 @@ def word_level_processing(reference, hypothesis):
     return result
 
 
-def character_level_processing(reference, hypothesis):
+def character_level_processing(reference: str, hypothesis: str) -> dict[str, Any]:
     """
     Process character-level error metrics between the reference and hypothesis.
 
@@ -119,7 +124,7 @@ def character_level_processing(reference, hypothesis):
     return result
 
 
-def get_alignments(unparsed_alignments):
+def get_alignments(unparsed_alignments: list[jiwer.process.AlignmentChunk]) -> list[dict]:
     """
     Convert unparsed alignments into a structured format.
 
@@ -148,7 +153,7 @@ def get_alignments(unparsed_alignments):
     return alignments
 
 
-def get_transcription(model, file):
+def get_transcription(model: str, file: FileStateType):
     """
     Get transcription of an audio file using the specified model.
 
@@ -171,7 +176,7 @@ def get_transcription(model, file):
     raise HTTPException(status_code=404, detail="Model was not found")
 
 
-def fill_gaps(transcriptions, file):
+def fill_gaps(transcriptions: list[dict], file: FileStateType) -> list[dict]:
     res = []
 
     audio = get_audio(file)
@@ -194,7 +199,7 @@ def fill_gaps(transcriptions, file):
     return res
 
 
-def deepgram_transcription(data):
+def deepgram_transcription(data: bytes) -> list[dict]:
     """
     Transcribe audio data using Deepgram API.
 
@@ -234,16 +239,15 @@ def deepgram_transcription(data):
 
         res = []
         for word in response["results"]["channels"][0]["alternatives"][0]["words"]:
-            res.append(
-                {"value": word["word"], "start": word["start"], "end": word["end"]}
-            )
+            res.append({"value": word["word"], "start": word["start"], "end": word["end"]})
         return res
 
     except Exception as e:
        print(f"Exception: {e}")
         return []
 
 
-def allosaurs_transcription(file):
+def allosaurs_transcription(file: FileStateType) -> Any:
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
         temp_wav.write(file["data"])
         temp_wav_filename = temp_wav.name
@@ -251,9 +255,7 @@ def allosaurs_transcription(file):
     word_level_transcription = fill_gaps(deepgram_transcription(file["data"]), file)
 
     model = read_recognizer()
-    phoneme_level_transcription = model.recognize(
-        temp_wav_filename, timestamp=True, emit=1.2
-    )
+    phoneme_level_transcription = model.recognize(temp_wav_filename, timestamp=True, emit=1.2)
 
     phoneme_level_parsed = []
 
@@ -262,13 +264,13 @@ def allosaurs_transcription(file):
             [float(phoneme_string.split(" ")[0]), phoneme_string.split(" ")[2]]
         )
 
-    phoneme_word_splits = get_phoneme_word_splits(
-        word_level_transcription, phoneme_level_parsed
-    )
+    phoneme_word_splits = get_phoneme_word_splits(word_level_transcription, phoneme_level_parsed)
     return get_phoneme_transcriptions(phoneme_word_splits)
 
 
-def get_phoneme_word_splits(word_level_transcription, phoneme_level_parsed):
+def get_phoneme_word_splits(
+    word_level_transcription: list[dict], phoneme_level_parsed: list[list]
+) -> list[dict]:
     if len(word_level_transcription) == 0:
         return []
 
@@ -282,10 +284,7 @@ def get_phoneme_word_splits(word_level_transcription, phoneme_level_parsed):
     while word_pointer < len(word_level_transcription) and phoneme_pointer < len(
         phoneme_level_parsed
     ):
-        if (
-            phoneme_level_parsed[phoneme_pointer][0]
-            > word_level_transcription[word_pointer]["end"]
-        ):
+        if phoneme_level_parsed[phoneme_pointer][0] > word_level_transcription[word_pointer]["end"]:
             current_split["word_transcription"] = word_level_transcription[word_pointer]
             phoneme_word_splits.append(current_split)
             current_split = {"phonemes": [], "word_transcription": None}
@@ -302,7 +301,7 @@ def get_phoneme_word_splits(word_level_transcription, phoneme_level_parsed):
     return phoneme_word_splits
 
 
-def get_phoneme_transcriptions(phoneme_word_splits):
+def get_phoneme_transcriptions(phoneme_word_splits: list[Any]) -> list[dict]:
     res = []
 
     for phoneme_split in phoneme_word_splits:
@@ -315,22 +314,14 @@ def get_phoneme_transcriptions(phoneme_word_splits):
                 start = phoneme_split["word_transcription"]["start"]
             else:
                 # this is an (educated) guess, it could be way off :D
-                start = (
-                    phoneme_split["phonemes"][i - 1][0]
-                    + phoneme_split["phonemes"][i][0]
-                ) / 2
+                start = (phoneme_split["phonemes"][i - 1][0] + phoneme_split["phonemes"][i][0]) / 2
 
             end = 0
             if i + 1 == len(phoneme_split["phonemes"]):
                 end = phoneme_split["word_transcription"]["end"]
             else:
-                end = (
-                    phoneme_split["phonemes"][i + 1][0]
-                    + phoneme_split["phonemes"][i][0]
-                ) / 2
+                end = (phoneme_split["phonemes"][i + 1][0] + phoneme_split["phonemes"][i][0]) / 2
 
-            res.append(
-                {"value": phoneme_split["phonemes"][i][1], "start": start, "end": end}
-            )
+            res.append({"value": phoneme_split["phonemes"][i][1], "start": start, "end": end})
 
     return res

kernel/spectral/types.py (new file): 7 lines

@@ -0,0 +1,7 @@
+import parselmouth
+from pydub import AudioSegment
+
+# type definitions
+AudioType = AudioSegment
+SoundType = parselmouth.Sound
+FileStateType = dict

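The new `types` module only aliases existing classes, so annotations stay short while beartype still checks against the real pydub and parselmouth types. A hedged usage sketch (the helper function is illustrative, not part of the diff):

```python
# Sketch only: the helper below is illustrative and not part of the kernel.
from pydub import AudioSegment

from spectral.signal_analysis import signal_to_sound
from spectral.types import AudioType, SoundType


def frame_rate_of(audio: AudioType) -> int:
    # AudioType is just AudioSegment, so the usual pydub attributes are available
    return audio.frame_rate


silence = AudioSegment.silent(duration=100)  # 100 ms of silence
print(frame_rate_of(silence))

sound: SoundType = signal_to_sound(silence.get_array_of_samples(), silence.frame_rate)
print(sound.duration)  # parselmouth.Sound duration in seconds
```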
@@ -1,4 +1,5 @@
 import pytest
+from spectral.database import Database
 from spectral.main import app, get_db
 from fastapi.testclient import TestClient
 from fastapi import HTTPException
@@ -48,6 +49,7 @@ def db_mock():
         "groundTruth": "hai test",
     }
     mock.get_transcriptions.return_value = [[{"value": "hi", "start": 0, "end": 1}]]
+    mock.__class__ = Database
 
     yield mock
 

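The `mock.__class__ = Database` line in the fixture is what keeps the mocked database compatible with the new hints: beartype validates a plain `database: Database` annotation with an `isinstance` check, and reassigning `__class__` lets a `MagicMock` pass it. The same idea in isolation (the decorated function is a stand-in, not part of the diff):

```python
# Sketch only: a stand-in for the annotated kernel functions that take a Database.
from unittest.mock import MagicMock

from beartype import beartype

from spectral.database import Database


@beartype
def needs_db(database: Database) -> bool:
    return isinstance(database, Database)


mock = MagicMock()
mock.__class__ = Database  # same trick as in the fixture above
print(needs_db(mock))  # True: the mock now satisfies the Database annotation
```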
@@ -7,6 +7,7 @@ from spectral.frame_analysis import (
 )
 import json
 import os
+from array import array
 
 # Load the JSON file
 with open(
@@ -17,33 +18,40 @@ with open(
 
 def test_speed_voiced_frame():
     assert (
-        calculate_frame_duration(frame_data["voiced-1"]["data"], frame_data["voiced-1"]["fs"])
+        calculate_frame_duration(
+            array("h", frame_data["voiced-1"]["data"]), frame_data["voiced-1"]["fs"]
+        )
         == 0.04
     ), "Expected duration for voiced frame to be 0.04 seconds"
 
 
 def test_speed_unvoiced_frame():
     assert (
-        calculate_frame_duration(frame_data["unvoiced-1"]["data"], frame_data["unvoiced-1"]["fs"])
+        calculate_frame_duration(
+            array("h", frame_data["unvoiced-1"]["data"]), frame_data["unvoiced-1"]["fs"]
+        )
         == 0.04
     ), "Expected duration for unvoiced frame to be 0.04 seconds"
 
 
 def test_speed_noise_frame():
     assert (
-        calculate_frame_duration(frame_data["noise-1"]["data"], frame_data["noise-1"]["fs"]) == 0.04
+        calculate_frame_duration(
+            array("h", frame_data["noise-1"]["data"]), frame_data["noise-1"]["fs"]
+        )
+        == 0.04
     ), "Expected duration for noise frame to be 0.04 seconds"
 
 
 def test_speed_empty_frame():
     assert (
-        calculate_frame_duration([], 48000) == 0
+        calculate_frame_duration(array("h", []), 48000) == 0
     ), "Expected duration for empty frame to be 0 seconds"
 
 
 def test_pitch_voiced():
     assert calculate_frame_pitch(
-        frame_data["voiced-1"]["data"], frame_data["voiced-1"]["fs"]
+        array("h", frame_data["voiced-1"]["data"]), frame_data["voiced-1"]["fs"]
     ) == pytest.approx(
         115.64, 0.01
     ), "Expected pitch for voiced frame to be approximately 115.64 Hz"
@@ -51,22 +59,30 @@ def test_pitch_voiced():
 
 def test_pitch_unvoiced():
     assert math.isnan(
-        calculate_frame_pitch(frame_data["unvoiced-1"]["data"], frame_data["unvoiced-1"]["fs"])
-    ), "Expected pitch for unvoiced frame to be NaN"
+        calculate_frame_pitch(
+            array("h", frame_data["unvoiced-1"]["data"]), frame_data["unvoiced-1"]["fs"]
+        )
+    )
 
 
 def test_pitch_noise():
     assert math.isnan(
-        calculate_frame_pitch(frame_data["noise-1"]["data"], frame_data["noise-1"]["fs"])
+        calculate_frame_pitch(
+            array("h", frame_data["noise-1"]["data"]), frame_data["noise-1"]["fs"]
+        )
     ), "Expected pitch for noise frame to be NaN"
 
 
 def test_pitch_empty_frame():
-    assert math.isnan(calculate_frame_pitch([], 48000)), "Expected pitch for empty frame to be NaN"
+    assert math.isnan(
+        calculate_frame_pitch(array("h", []), 48000)
+    ), "Expected pitch for empty frame to be NaN"
 
 
 def test_formants_voiced_frame():
-    formants = calculate_frame_f1_f2(frame_data["voiced-1"]["data"], frame_data["voiced-1"]["fs"])
+    formants = calculate_frame_f1_f2(
+        array("h", frame_data["voiced-1"]["data"]), frame_data["voiced-1"]["fs"]
+    )
     assert len(formants) == 2, "Expected two formants for voiced frame"
     assert formants[0] == pytest.approx(
         474.43, 0.01
@@ -78,7 +94,7 @@ def test_formants_voiced_frame():
 
 def test_formants_unvoiced_frame():
     formants = calculate_frame_f1_f2(
-        frame_data["unvoiced-1"]["data"], frame_data["unvoiced-1"]["fs"]
+        array("h", frame_data["unvoiced-1"]["data"]), frame_data["unvoiced-1"]["fs"]
     )
     assert len(formants) == 2, "Expected two formants for unvoiced frame"
     assert formants[0] == pytest.approx(
@@ -90,7 +106,10 @@ def test_formants_unvoiced_frame():
 
 
 def test_formants_noise_frame():
-    formants = calculate_frame_f1_f2(frame_data["noise-1"]["data"], frame_data["noise-1"]["fs"])
+    formants = calculate_frame_f1_f2(
+        array("h", frame_data["noise-1"]["data"]), frame_data["noise-1"]["fs"]
+    )
+
     assert len(formants) == 2, "Expected two formants for noise frame"
     assert formants[0] == pytest.approx(
         192.72, 0.01
@@ -101,7 +120,7 @@ def test_formants_noise_frame():
 
 
 def test_formants_empty_frame():
-    formants = calculate_frame_f1_f2([], 0)
+    formants = calculate_frame_f1_f2(array("h", []), 0)
     assert len(formants) == 2, "Expected two formants for empty frame"
     assert math.isnan(formants[0]), "Expected first formant (f1) for empty frame to be NaN"
     assert math.isnan(formants[1]), "Expected second formant (f2) for empty frame to be NaN"

@@ -1,12 +1,13 @@
 import pytest
 from fastapi import HTTPException
 from spectral.frame_analysis import validate_frame_index
+from array import array
 
 
 def test_validate_frame_index_valid():
     data = [0] * 100
     frame_index = validate_frame_index(
-        data, {"frame": {"startIndex": 10, "endIndex": 20}}
+        array("h", data), {"frame": {"startIndex": 10, "endIndex": 20}}
     )
     assert frame_index == {"startIndex": 10, "endIndex": 20}
 
@@ -14,7 +15,7 @@ def test_validate_frame_index_valid():
 def test_validate_frame_index_both_none_indices():
     data = [0] * 100
     assert (
-        validate_frame_index(data, {"frame": {"startIndex": None, "endIndex": None}})
+        validate_frame_index(array("h", data), {"frame": {"startIndex": None, "endIndex": None}})
         is None
     )
 
@@ -22,34 +23,34 @@ def test_validate_frame_index_both_none_indices():
 def test_validate_frame_index_missing_start_index():
     data = [0] * 100
     with pytest.raises(HTTPException):
-        validate_frame_index(data, {"frame": {"startIndex": None, "endIndex": 20}})
+        validate_frame_index(array("h", data), {"frame": {"startIndex": None, "endIndex": 20}})
 
 
 def test_validate_frame_index_missing_end_index():
     data = [0] * 100
     with pytest.raises(HTTPException):
-        validate_frame_index(data, {"frame": {"startIndex": 10, "endIndex": None}})
+        validate_frame_index(array("h", data), {"frame": {"startIndex": 10, "endIndex": None}})
 
 
 def test_validate_frame_index_start_index_greater_than_end_index():
     data = [0] * 100
     with pytest.raises(HTTPException):
-        validate_frame_index(data, {"frame": {"startIndex": 20, "endIndex": 10}})
+        validate_frame_index(array("h", data), {"frame": {"startIndex": 20, "endIndex": 10}})
 
 
 def test_validate_frame_index_negative_start_index():
     data = [0] * 100
     with pytest.raises(HTTPException):
-        validate_frame_index(data, {"frame": {"startIndex": -1, "endIndex": 20}})
+        validate_frame_index(array("h", data), {"frame": {"startIndex": -1, "endIndex": 20}})
 
 
 def test_validate_frame_index_end_index_too_large():
     data = [0] * 100
     with pytest.raises(HTTPException):
-        validate_frame_index(data, {"frame": {"startIndex": 10, "endIndex": 200}})
+        validate_frame_index(array("h", data), {"frame": {"startIndex": 10, "endIndex": 200}})
 
 
 def test_frame_none():
     data = [0] * 100
-    frame_index = validate_frame_index(data, {"frame": None})
+    frame_index = validate_frame_index(array("h", data), {"frame": None})
     assert frame_index is None