Skip to content

Commit

Permalink
"Eager" streaming mode via word timestamps (argmaxinc#95)
Browse files Browse the repository at this point in the history
* Add streaming with word timestamps algorithm and test

* Support distil models

* Lint

* Improve word level streaming

- Also adds streaming simulation to CLI

* Restore missing cli flag

* Test 3 agreement words

* Minor cleanup

* Fix merge

* Remove extra visionos check

* Test cleanup

* Update prefill token logic, add prompt and prefix to cli

* Add experimental eager mode to example app

* Refinement for streaming accuracy

* Use brew for huggingface-cli instead of python

* Add option to hide decoder preview, and various UI cleanup

* Add check for word timestamp support

* Fix punctuation merging

* Cleanup and fix test

* Remove build warning
  • Loading branch information
ZachNagengast authored Mar 30, 2024
1 parent 7ca089e commit 97e655c
Show file tree
Hide file tree
Showing 15 changed files with 896 additions and 317 deletions.
4 changes: 2 additions & 2 deletions Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.0;
MARKETING_VERSION = 0.1.2;
MARKETING_VERSION = 0.2.0;
PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
Expand Down Expand Up @@ -934,7 +934,7 @@
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.0;
MARKETING_VERSION = 0.1.2;
MARKETING_VERSION = 0.2.0;
PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "4f915610451d29a05948802a140880ff37494dad",
"version" : "0.1.6"
"revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
"version" : "0.1.7"
}
}
],
Expand Down
400 changes: 312 additions & 88 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ setup:
@echo "Setting up environment..."
@which $(PIP_COMMAND)
@which $(PYTHON_COMMAND)
@$(PIP_COMMAND) install -U huggingface_hub
@echo "Checking for Homebrew..."
@which brew > /dev/null || (echo "Error: Homebrew is not installed. Install it form here https://brew.sh and try again" && exit 1)
@echo "Homebrew is installed."
@echo "Checking for huggingface-cli..."
@which huggingface-cli > /dev/null || (echo "Installing huggingface-cli..." && brew install huggingface-cli)
@echo "huggingface-cli is installed."
@echo "Checking for git-lfs..."
@which git-lfs > /dev/null || (echo "Installing git-lfs..." && brew install git-lfs)
@echo "git-lfs is installed."
Expand Down
4 changes: 2 additions & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "4f915610451d29a05948802a140880ff37494dad",
"version" : "0.1.6"
"revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
"version" : "0.1.7"
}
}
],
Expand Down
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ let package = Package(
),
],
dependencies: [
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.6"),
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"),
.package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
],
targets: [
Expand Down
47 changes: 34 additions & 13 deletions Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ public struct ModelComputeOptions {
self.prefillCompute = prefillCompute
self.textDecoderCompute = textDecoderCompute

if #available(macOS 14.0, iOS 17.0, watchOS 10, visionOS 1, *) {
if #available(macOS 14.0, iOS 17.0, watchOS 10, *) {
self.audioEncoderCompute = audioEncoderCompute ?? .cpuAndNeuralEngine
} else {
self.audioEncoderCompute = audioEncoderCompute ?? .cpuAndGPU
Expand Down Expand Up @@ -213,7 +213,11 @@ public struct DecodingCache {
/// - usePrefillCache: If true, the kv cache will be prefilled based on the prefill data mlmodel.
/// - skipSpecialTokens: Whether to skip special tokens in the output.
/// - withoutTimestamps: Whether to include timestamps in the transcription result.
/// - wordTimestamps: Whether to include word-level timestamps in the transcription result.
/// - maxInitialTimestamp: Maximal initial timestamp.
/// - clipTimestamps: Array of timestamps (in seconds) to split the audio into segments for transcription.
/// - promptTokens: Array of token IDs to use as the conditioning prompt for the decoder. These are prepended to the prefill tokens.
/// - prefixTokens: Array of token IDs to use as the initial prefix for the decoder. These are appended to the prefill tokens.
/// - suppressBlank: If true, blank tokens will be suppressed during decoding.
/// - supressTokens: List of token IDs to suppress during decoding.
/// - compressionRatioThreshold: If the compression ratio of the transcription text is above this value, it is too repetitive and treated as failed.
Expand All @@ -238,6 +242,8 @@ public struct DecodingOptions {
public var wordTimestamps: Bool
public var maxInitialTimestamp: Float?
public var clipTimestamps: [Float]
public var promptTokens: [Int]?
public var prefixTokens: [Int]?
public var suppressBlank: Bool
public var supressTokens: [Int]
public var compressionRatioThreshold: Float?
Expand All @@ -260,6 +266,8 @@ public struct DecodingOptions {
wordTimestamps: Bool = false,
maxInitialTimestamp: Float? = nil,
clipTimestamps: [Float] = [],
promptTokens: [Int]? = nil,
prefixTokens: [Int]? = nil,
suppressBlank: Bool = false,
supressTokens: [Int]? = nil,
compressionRatioThreshold: Float? = 2.4,
Expand All @@ -282,6 +290,8 @@ public struct DecodingOptions {
self.wordTimestamps = wordTimestamps
self.maxInitialTimestamp = maxInitialTimestamp
self.clipTimestamps = clipTimestamps
self.promptTokens = promptTokens
self.prefixTokens = prefixTokens
self.suppressBlank = suppressBlank
self.supressTokens = supressTokens ?? [] // nonSpeechTokens() // TODO: implement these as default
self.compressionRatioThreshold = compressionRatioThreshold
Expand Down Expand Up @@ -400,19 +410,25 @@ public struct TranscriptionResult: Codable {
public var timings: TranscriptionTimings?
}

public extension TranscriptionResult {
var allWords: [WordTiming] {
return segments.compactMap { $0.words }.flatMap { $0 }
}
}

public struct TranscriptionSegment: Hashable, Codable {
public var id: Int
public var seek: Int
public var start: Float
public var end: Float
public var text: String
public var tokens: [Int]
public var tokenLogProbs: [[Int: Float]]
public var temperature: Float
public var avgLogprob: Float
public var compressionRatio: Float
public var noSpeechProb: Float
public var words: [WordTiming]?
public var id: Int = 0
public var seek: Int = 0
public var start: Float = 0.0
public var end: Float = 0.0
public var text: String = ""
public var tokens: [Int] = []
public var tokenLogProbs: [[Int: Float]] = [[:]]
public var temperature: Float = 1.0
public var avgLogprob: Float = 0.0
public var compressionRatio: Float = 1.0
public var noSpeechProb: Float = 0.0
public var words: [WordTiming]? = nil
}

public struct WordTiming: Hashable, Codable {
Expand Down Expand Up @@ -889,6 +905,7 @@ public struct SpecialTokens {
public let noSpeechToken: Int
public let noTimestampsToken: Int
public let specialTokenBegin: Int
public let startOfPreviousToken: Int
public let startOfTranscriptToken: Int
public let timeTokenBegin: Int
public let transcribeToken: Int
Expand All @@ -901,6 +918,7 @@ public struct SpecialTokens {
noSpeechToken: Int,
noTimestampsToken: Int,
specialTokenBegin: Int,
startOfPreviousToken: Int,
startOfTranscriptToken: Int,
timeTokenBegin: Int,
transcribeToken: Int,
Expand All @@ -912,6 +930,7 @@ public struct SpecialTokens {
self.noSpeechToken = noSpeechToken
self.noTimestampsToken = noTimestampsToken
self.specialTokenBegin = specialTokenBegin
self.startOfPreviousToken = startOfPreviousToken
self.startOfTranscriptToken = startOfTranscriptToken
self.timeTokenBegin = timeTokenBegin
self.transcribeToken = transcribeToken
Expand All @@ -938,6 +957,7 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
noSpeechToken: tokenizer.convertTokenToId("<|nospeech|>") ?? Self.defaultNoSpeechToken,
noTimestampsToken: tokenizer.convertTokenToId("<|notimestamps|>") ?? Self.defaultNoTimestampsToken,
specialTokenBegin: tokenizer.convertTokenToId("<|endoftext|>") ?? Self.defaultSpecialTokenBegin,
startOfPreviousToken: tokenizer.convertTokenToId("<|startofprev|>") ?? Self.defaultStartOfPreviousToken,
startOfTranscriptToken: tokenizer.convertTokenToId("<|startoftranscript|>") ?? Self.defaultStartOfTranscriptToken,
timeTokenBegin: tokenizer.convertTokenToId("<|0.00|>") ?? Self.defaultTimeTokenBegin,
transcribeToken: tokenizer.convertTokenToId("<|transcribe|>") ?? Self.defaultTranscribeToken,
Expand Down Expand Up @@ -1200,6 +1220,7 @@ extension WhisperTokenizerWrapper {
static var defaultWhitespaceToken: Int { 220 }
static var defaultSpecialTokenBegin: Int { 50257 }
static var defaultEndToken: Int { 50257 }
static var defaultStartOfPreviousToken: Int { 50361 }
static var defaultStartOfTranscriptToken: Int { 50258 }
static var defaultEnglishToken: Int { 50259 }
static var defaultTranscribeToken: Int { 50359 }
Expand Down
69 changes: 43 additions & 26 deletions Sources/WhisperKit/Core/SegmentSeeker.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ open class SegmentSeeker: SegmentSeeking {
// check if single or double timestamp ending
let lastThreeTokens = isTimestampToken.suffix(3)
let singleTimestampEnding = lastThreeTokens == [false, true, false]
let noTimestampEnding = lastThreeTokens == [false, false, false]


// find all end indexes of time token pairs
var sliceIndexes = [Int]()
Expand All @@ -104,6 +106,8 @@ open class SegmentSeeker: SegmentSeeking {
if singleTimestampEnding {
let singleTimestampEndingIndex = isTimestampToken.lastIndex(where: { $0 })!
sliceIndexes.append(singleTimestampEndingIndex + 1)
} else if noTimestampEnding {
sliceIndexes.append(currentTokens.count)
}

var lastSliceStart = 0
Expand Down Expand Up @@ -138,10 +142,14 @@ open class SegmentSeeker: SegmentSeeking {
}

// Seek to the last timestamp in the segment
let lastTimestampToken = currentTokens[lastSliceStart] - timeToken
let lastTimestampSeconds = Float(lastTimestampToken) * secondsPerTimeToken
let lastTimestampSamples = Int(lastTimestampSeconds * Float(sampleRate))
seek += lastTimestampSamples
if !noTimestampEnding {
let lastTimestampToken = currentTokens[lastSliceStart] - timeToken
let lastTimestampSeconds = Float(lastTimestampToken) * secondsPerTimeToken
let lastTimestampSamples = Int(lastTimestampSeconds * Float(sampleRate))
seek += lastTimestampSamples
} else {
seek += segmentSize
}
} else {
// Model is not giving any consecutive timestamps, so lump all the current tokens together
var durationSeconds = Float(segmentSize) / Float(sampleRate)
Expand Down Expand Up @@ -274,41 +282,50 @@ open class SegmentSeeker: SegmentSeeking {
}

func mergePunctuations(alignment: [WordTiming], prepended: String, appended: String) -> [WordTiming] {
var mergedAlignment = [WordTiming]()
var prependedAlignment = [WordTiming]()
var appendedAlignment = [WordTiming]()

// Include the first word if it's not a prepended punctuation
if !alignment.isEmpty && !prepended.contains(alignment[0].word.trimmingCharacters(in: .whitespaces)) {
mergedAlignment.append(alignment[0])
prependedAlignment.append(alignment[0])
}

// Merge prepended punctuations
for i in 1..<alignment.count {
let currentWord = alignment[i]
if i > 1, currentWord.word.starts(with: " "), prepended.contains(currentWord.word.trimmingCharacters(in: .whitespaces)) {
mergedAlignment[mergedAlignment.count - 1].word += currentWord.word
mergedAlignment[mergedAlignment.count - 1].tokens += currentWord.tokens
mergedAlignment[mergedAlignment.count - 1].end = currentWord.end
var currentWord = alignment[i]
let previousWord = alignment[i - 1]
if previousWord.word.starts(with: " "), prepended.contains(previousWord.word.trimmingCharacters(in: .whitespaces)) {
currentWord.word = previousWord.word + currentWord.word
currentWord.tokens = previousWord.tokens + currentWord.tokens
currentWord.start = min(previousWord.start, currentWord.start)
prependedAlignment[prependedAlignment.count - 1] = currentWord
} else {
mergedAlignment.append(currentWord)
prependedAlignment.append(currentWord)
}
}

// Include the first word always for append checks
if !prependedAlignment.isEmpty {
appendedAlignment.append(prependedAlignment[0])
}

// Merge appended punctuations
var i = 0
while i < mergedAlignment.count {
var shouldSkipNextWord = false
if i < mergedAlignment.count - 1, appended.contains(mergedAlignment[i + 1].word) {
mergedAlignment[i].word += mergedAlignment[i + 1].word
mergedAlignment[i].tokens += mergedAlignment[i + 1].tokens
mergedAlignment[i].end = mergedAlignment[i + 1].end
shouldSkipNextWord = true
for i in 1..<prependedAlignment.count {
let currentWord = prependedAlignment[i]
var previousWord = prependedAlignment[i - 1]
if !previousWord.word.hasSuffix(" "), appended.contains(currentWord.word.trimmingCharacters(in: .whitespaces)) {
previousWord.word = previousWord.word + currentWord.word
previousWord.tokens = previousWord.tokens + currentWord.tokens
previousWord.end = max(previousWord.end, currentWord.end)
appendedAlignment[appendedAlignment.count - 1] = previousWord
} else {
appendedAlignment.append(currentWord)
}

i += shouldSkipNextWord ? 2 : 1
}

// Filter out the empty word timings and punctuation words that have been merged
return mergedAlignment.filter { !$0.word.isEmpty && !appended.contains($0.word) && !prepended.contains($0.word) }
let mergedAlignment = appendedAlignment.filter { !$0.word.isEmpty && !appended.contains($0.word) && !prepended.contains($0.word) }
return mergedAlignment
}

func findAlignment(
Expand Down Expand Up @@ -489,9 +506,9 @@ open class SegmentSeeker: SegmentSeeking {
continue
}

let start = round((timeOffset + timing.start) * 100) / 100.0
let end = round((timeOffset + timing.end) * 100) / 100.0
let probability = round(timing.probability * 100) / 100.0
let start = (timeOffset + timing.start).rounded(2)
let end = (timeOffset + timing.end).rounded(2)
let probability = timing.probability.rounded(2)
let wordTiming = WordTiming(word: timing.word,
tokens: timingTokens,
start: start,
Expand Down
Loading

0 comments on commit 97e655c

Please sign in to comment.