254 lines
11 KiB
TypeScript
254 lines
11 KiB
TypeScript
/**
|
||
* openWakeWord pipeline в браузере.
|
||
*
|
||
* Цепочка: 1280-семпловый audio chunk @ 16kHz
|
||
* → melspectrogram.onnx → ~8 новых mel-фреймов
|
||
* → embedding_model.onnx (sliding 76-frame window, stride 8) → 96-D embedding
|
||
* → cosmo.onnx (классификатор по последним 16 embedding'ам) → score 0..1
|
||
* → score > threshold ⇒ onWake()
|
||
*
|
||
* Audio capture через AudioWorklet (`/wake/wake-capture-worklet.js`).
|
||
* ONNX inference на main thread через onnxruntime-web (WASM, single-thread).
|
||
*/
|
||
// onnxruntime-web .mjs builds используют top-level import.meta.url, что
|
||
// не парсится next-swc. Загружаем CJS-сборку через <script> тег → window.ort.
|
||
// Это обходит webpack полностью.
|
||
declare global { interface Window { ort?: any } }
|
||
|
||
const ORT_SCRIPT_URL = '/vad/ort.wasm.min.js'
|
||
let _ortLoadPromise: Promise<any> | null = null
|
||
|
||
async function getOrt(): Promise<any> {
|
||
if (typeof window === 'undefined') throw new Error('wake-word is client-only')
|
||
if (window.ort) return window.ort
|
||
if (_ortLoadPromise) return _ortLoadPromise
|
||
_ortLoadPromise = new Promise((resolve, reject) => {
|
||
const existing = document.querySelector(`script[src="${ORT_SCRIPT_URL}"]`)
|
||
if (existing) {
|
||
existing.addEventListener('load', () => resolve(window.ort))
|
||
existing.addEventListener('error', () => reject(new Error('ort script failed')))
|
||
return
|
||
}
|
||
const s = document.createElement('script')
|
||
s.src = ORT_SCRIPT_URL
|
||
s.async = true
|
||
s.onload = () => {
|
||
const ort = window.ort
|
||
if (!ort) return reject(new Error('ort not defined after load'))
|
||
ort.env.wasm.numThreads = 1
|
||
ort.env.wasm.simd = true
|
||
ort.env.wasm.wasmPaths = '/vad/'
|
||
resolve(ort)
|
||
}
|
||
s.onerror = () => reject(new Error('ort script load error'))
|
||
document.head.appendChild(s)
|
||
})
|
||
return _ortLoadPromise
|
||
}
|
||
|
||
// Параметры pipeline (как в openwakeword)
|
||
const AUDIO_CHUNK = 1280 // 80мс @ 16kHz
|
||
const MEL_BINS = 32
|
||
const MEL_WINDOW = 76 // фреймов на embedding
|
||
const MEL_STRIDE = 8 // шаг в фреймах
|
||
const EMB_DIM = 96
|
||
const EMB_WINDOW = 16 // последние 16 embedding'ов идут в classifier
|
||
|
||
export interface WakeWordOptions {
|
||
modelPath: string // путь к classifier (cosmo.onnx)
|
||
melPath?: string // /wake/melspectrogram.onnx
|
||
embPath?: string // /wake/embedding_model.onnx
|
||
workletPath?: string // /wake/wake-capture-worklet.js
|
||
threshold?: number // 0..1, по умолчанию 0.5
|
||
cooldownMs?: number // dead-time после успешного wake, 2000ms по умолчанию
|
||
onWake: (score: number) => void
|
||
onScore?: (score: number) => void // опц. для отладки
|
||
onError?: (e: Error) => void
|
||
}
|
||
|
||
export class WakeWordDetector {
|
||
private opts: Required<Omit<WakeWordOptions, 'onScore' | 'onError'>> & Pick<WakeWordOptions, 'onScore' | 'onError'>
|
||
private ctx: AudioContext | null = null
|
||
private stream: MediaStream | null = null
|
||
private source: MediaStreamAudioSourceNode | null = null
|
||
private worklet: AudioWorkletNode | null = null
|
||
private mel: any = null
|
||
private emb: any = null
|
||
private cls: any = null
|
||
// I/O имена тензоров — тащим из session.input/outputNames.
|
||
private melInName = ''
|
||
private melOutName = ''
|
||
private embInName = ''
|
||
private embOutName = ''
|
||
private clsInName = ''
|
||
private clsOutName = ''
|
||
// Кольцевые буферы
|
||
private melBuf: Float32Array = new Float32Array(0) // flatten [T*32]
|
||
private melFrames = 0
|
||
private embBuf: Float32Array[] = [] // массив 96-D векторов
|
||
private cooldownChunks = 0
|
||
private running = false
|
||
|
||
constructor(options: WakeWordOptions) {
|
||
this.opts = {
|
||
melPath: '/wake/melspectrogram.onnx',
|
||
embPath: '/wake/embedding_model.onnx',
|
||
workletPath: '/wake/wake-capture-worklet.js',
|
||
threshold: 0.5,
|
||
cooldownMs: 2000,
|
||
...options,
|
||
}
|
||
}
|
||
|
||
async start(externalStream?: MediaStream): Promise<void> {
|
||
if (this.running) return
|
||
console.log('[wake] start: loading ort + models')
|
||
const t0 = performance.now()
|
||
const ort = await getOrt()
|
||
console.log(`[wake] ort ready in ${(performance.now() - t0).toFixed(0)}ms`)
|
||
|
||
// 1. Загружаем модели параллельно (до user gesture, чтобы AudioContext не висел)
|
||
const [mel, emb, cls] = await Promise.all([
|
||
ort.InferenceSession.create(this.opts.melPath, { executionProviders: ['wasm'] }),
|
||
ort.InferenceSession.create(this.opts.embPath, { executionProviders: ['wasm'] }),
|
||
ort.InferenceSession.create(this.opts.modelPath, { executionProviders: ['wasm'] }),
|
||
])
|
||
this.mel = mel
|
||
this.emb = emb
|
||
this.cls = cls
|
||
this.melInName = mel.inputNames[0]; this.melOutName = mel.outputNames[0]
|
||
this.embInName = emb.inputNames[0]; this.embOutName = emb.outputNames[0]
|
||
this.clsInName = cls.inputNames[0]; this.clsOutName = cls.outputNames[0]
|
||
console.log(`[wake] models loaded in ${(performance.now() - t0).toFixed(0)}ms`,
|
||
{ mel: { in: this.melInName, out: this.melOutName },
|
||
emb: { in: this.embInName, out: this.embOutName },
|
||
cls: { in: this.clsInName, out: this.clsOutName } })
|
||
|
||
// 2. Audio context @ 16kHz (если браузер не уважит — обработаем на стороне)
|
||
this.ctx = new AudioContext({ sampleRate: 16000 })
|
||
if (this.ctx.state === 'suspended') await this.ctx.resume()
|
||
console.log(`[wake] AudioContext sampleRate=${this.ctx.sampleRate} state=${this.ctx.state}`)
|
||
if (this.ctx.sampleRate !== 16000) {
|
||
console.warn(`[wake] AudioContext sampleRate=${this.ctx.sampleRate}, ожидается 16000 — wake-word скорее всего не сработает`)
|
||
this.opts.onError?.(new Error(`AudioContext sampleRate=${this.ctx.sampleRate}`))
|
||
}
|
||
|
||
// 3. Mic stream
|
||
this.stream = externalStream ?? await navigator.mediaDevices.getUserMedia({
|
||
audio: { echoCancellation: true, noiseSuppression: true, autoGainControl: false },
|
||
})
|
||
|
||
// 4. AudioWorklet
|
||
await this.ctx.audioWorklet.addModule(this.opts.workletPath)
|
||
this.source = this.ctx.createMediaStreamSource(this.stream)
|
||
this.worklet = new AudioWorkletNode(this.ctx, 'wake-capture')
|
||
let chunkCount = 0
|
||
this.worklet.port.onmessage = (e) => {
|
||
if (chunkCount === 0) console.log('[wake] first audio chunk received')
|
||
chunkCount++
|
||
this.onChunk(e.data as Float32Array)
|
||
}
|
||
this.source.connect(this.worklet)
|
||
// Worklet не подключается к destination → не звучит в колонках.
|
||
|
||
this.running = true
|
||
console.log('[wake] running')
|
||
}
|
||
|
||
async stop(): Promise<void> {
|
||
this.running = false
|
||
try { this.worklet?.disconnect() } catch {}
|
||
try { this.source?.disconnect() } catch {}
|
||
try { this.stream?.getTracks().forEach((t) => t.stop()) } catch {}
|
||
try { await this.ctx?.close() } catch {}
|
||
this.worklet = null; this.source = null; this.stream = null; this.ctx = null
|
||
this.melBuf = new Float32Array(0); this.melFrames = 0; this.embBuf = []
|
||
this.cooldownChunks = 0
|
||
}
|
||
|
||
/** На время записи команды — отключаем wake-обработку, не освобождая ресурсы. */
|
||
pause() { this.running = false }
|
||
resume() {
|
||
if (this.mel && this.emb && this.cls && this.ctx) {
|
||
this.running = true
|
||
// Сбрасываем буферы — иначе хвост старого аудио вызовет ложный wake.
|
||
this.melBuf = new Float32Array(0); this.melFrames = 0; this.embBuf = []
|
||
this.cooldownChunks = 0
|
||
}
|
||
}
|
||
|
||
private async onChunk(chunk: Float32Array) {
|
||
if (!this.running || !this.mel || !this.emb || !this.cls) return
|
||
if (this.cooldownChunks > 0) { this.cooldownChunks--; return }
|
||
|
||
const ort = await getOrt()
|
||
|
||
// openWakeWord ожидает float32 в range int16 (≈ ×32768)
|
||
const audio = new Float32Array(AUDIO_CHUNK)
|
||
for (let i = 0; i < AUDIO_CHUNK; i++) audio[i] = chunk[i] * 32768
|
||
|
||
try {
|
||
// 1. Mel-spectrogram
|
||
const melTensor = new ort.Tensor('float32', audio, [1, AUDIO_CHUNK])
|
||
const melOut = await this.mel.run({ [this.melInName]: melTensor })
|
||
const melData = melOut[this.melOutName].data as Float32Array
|
||
const melDims = melOut[this.melOutName].dims as readonly number[]
|
||
// Ожидается [1, T, 32] (или [1, 1, T, 32] / [T, 32]). Извлекаем T фреймов по 32 бина.
|
||
const newFrames = melDims.length === 4 ? melDims[2] : melDims.length === 3 ? melDims[1] : melDims[0]
|
||
const expected = newFrames * MEL_BINS
|
||
if (melData.length < expected) return
|
||
// Скейлинг как в openwakeword: x/10 + 2
|
||
const scaled = new Float32Array(expected)
|
||
for (let i = 0; i < expected; i++) scaled[i] = melData[i] / 10 + 2
|
||
|
||
// Append к mel-буферу (flatten)
|
||
const merged = new Float32Array(this.melBuf.length + scaled.length)
|
||
merged.set(this.melBuf); merged.set(scaled, this.melBuf.length)
|
||
this.melBuf = merged
|
||
this.melFrames += newFrames
|
||
|
||
// 2. Sliding embedding — пока хватает на одно окно, считаем и сдвигаем
|
||
while (this.melFrames >= MEL_WINDOW) {
|
||
const window = this.melBuf.subarray(0, MEL_WINDOW * MEL_BINS)
|
||
// Embedding model: вход [1, 76, 32, 1]
|
||
const embInput = new Float32Array(MEL_WINDOW * MEL_BINS)
|
||
embInput.set(window)
|
||
const embTensor = new (ort as any).Tensor('float32', embInput, [1, MEL_WINDOW, MEL_BINS, 1])
|
||
const embOut = await this.emb.run({ [this.embInName]: embTensor })
|
||
const embData = embOut[this.embOutName].data as Float32Array
|
||
// embedding ожидается длины 96 (последние EMB_DIM)
|
||
const e = new Float32Array(EMB_DIM)
|
||
e.set(embData.slice(-EMB_DIM))
|
||
this.embBuf.push(e)
|
||
if (this.embBuf.length > EMB_WINDOW + 4) this.embBuf.shift()
|
||
|
||
// Сдвигаем mel-буфер на MEL_STRIDE фреймов
|
||
this.melBuf = this.melBuf.slice(MEL_STRIDE * MEL_BINS)
|
||
this.melFrames -= MEL_STRIDE
|
||
|
||
// 3. Classifier
|
||
if (this.embBuf.length >= EMB_WINDOW) {
|
||
const last = this.embBuf.slice(-EMB_WINDOW)
|
||
const flat = new Float32Array(EMB_WINDOW * EMB_DIM)
|
||
for (let i = 0; i < EMB_WINDOW; i++) flat.set(last[i], i * EMB_DIM)
|
||
const clsTensor = new (ort as any).Tensor('float32', flat, [1, EMB_WINDOW, EMB_DIM])
|
||
const clsOut = await this.cls.run({ [this.clsInName]: clsTensor })
|
||
const score = (clsOut[this.clsOutName].data as Float32Array)[0]
|
||
this.opts.onScore?.(score)
|
||
if (score >= this.opts.threshold) {
|
||
const cooldownChunks = Math.ceil(this.opts.cooldownMs / 80)
|
||
this.cooldownChunks = cooldownChunks
|
||
this.embBuf = [] // сброс — не зацикливаем wake
|
||
this.melBuf = new Float32Array(0); this.melFrames = 0
|
||
this.opts.onWake(score)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
} catch (e) {
|
||
console.error('[wake-word] chunk error:', e)
|
||
this.opts.onError?.(e as Error)
|
||
}
|
||
}
|
||
}
|