Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
feat: add a resampling feature
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-zablit committed Oct 20, 2022
1 parent 2dd2c6c commit bc486b6
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 1 deletion.
48 changes: 47 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ websockets = "^10.3"
tqdm = "*"
torchaudio = "^0.12.1"
PySoundFile = {version = "^0.9.0.post1", platform = "windows"}
scipy = "^1.9.3"


[tool.poetry.group.dev.dependencies]
Expand Down
24 changes: 24 additions & 0 deletions whispering/resample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
import numpy as np
from scipy.interpolate import interp1d
from whisper.audio import SAMPLE_RATE


def resample(x: np.ndarray, source_sr: int) -> np.ndarray:
    """Resample an audio array to Whisper's ``SAMPLE_RATE``.

    The signal is resampled by plain linear interpolation between the
    existing samples, with the first and last samples of the input mapped
    onto the first and last samples of the output. No low-pass
    (anti-aliasing) filter is applied before downsampling.

    Args:
        x (np.ndarray): Source samples.
        source_sr (int): Sample rate of ``x`` in Hz. Must be positive.

    Returns:
        np.ndarray: ``x`` itself when it is empty or already at
        ``SAMPLE_RATE``; otherwise a new array holding
        ``ceil(x.size / (source_sr / SAMPLE_RATE))`` samples.
        Floating-point input keeps its dtype.

    Raises:
        ValueError: If ``source_sr`` is not positive.
    """
    if source_sr <= 0:
        # A zero rate would divide by zero below and a negative one would
        # yield a negative sample count; fail early with a clear message.
        raise ValueError(f"source_sr must be positive, got {source_sr}")

    # An empty chunk has nothing to interpolate, whatever the rates are.
    if x.size == 0:
        return x

    if source_sr == SAMPLE_RATE:
        return x

    factor = source_sr / SAMPLE_RATE
    n = int(np.ceil(x.size / factor))

    if x.size == 1:
        # A lone sample cannot be interpolated between two points;
        # hold its value for the whole output instead.
        return np.repeat(x.reshape(-1), n)

    f = interp1d(np.linspace(0, 1, x.size), x, "linear")
    resampled = f(np.linspace(0, 1, n))

    # np.linspace yields float64 coordinates, so the interpolated values may
    # come back wider than the input. Restore the caller's floating dtype
    # (e.g. float32 audio) so downstream code sees the type it passed in.
    if np.issubdtype(x.dtype, np.floating):
        resampled = resampled.astype(x.dtype, copy=False)
    return resampled
1 change: 1 addition & 0 deletions whispering/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Context(BaseModel, arbitrary_types_allowed=True):
max_nospeech_skip: int

data_type: str = "float32"
source_sample_rate: int = 16000


class ParsedChunk(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions whispering/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from whispering.schema import CURRENT_PROTOCOL_VERSION, Context
from whispering.transcriber import WhisperStreamingTranscriber
from whispering.resample import resample

logger = getLogger(__name__)

Expand Down Expand Up @@ -77,6 +78,7 @@ async def serve_with_websocket_main(websocket):
)
return
audio = np.frombuffer(message, dtype=np.dtype(ctx.data_type)).astype(np.float32)
audio = resample(audio, ctx.source_sample_rate)
for chunk in g_wsp.transcribe(
audio=audio, # type: ignore
ctx=ctx,
Expand Down

0 comments on commit bc486b6

Please sign in to comment.