Get whisperfile working very well
1. GPU support is now working reliably, thanks to the sync. On my system
   it goes noticeably faster than CPU for the medium / large models, but
   pretty much the same speed for the tiny model. Since CPU mode outputs
   are more accurate & better looking, CPU will remain the default mode.

2. The recently-introduced tanhf() optimization for GeLU is now removed.
   It was only one ULP less accurate, but that was enough to cause nasty
   confusing artifacts to be introduced into whisperfile output, such as
   lines being repeated, omitted, and even teleported to other locations.

3. `llamafile --trap` identified NaNs in whisperfile's logits subroutine,
   which has now been corrected (a sketch of the fix follows below). CPU
   mode output for the tiny q5_1 model has become nearly identical to
   whisper.cpp upstream. Some divergences are better, some are worse, but
   none of the errors are confusing, and our implementation goes 2.4x
   faster than upstream. Our output for the medium model, on the other
   hand, is bit-for-bit identical to the upstream project on CPU, and our
   implementation goes 5.7x faster.
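
The logits fix in point 3 is the guard added to whisper_process_logits() in
the whisper.cpp hunk at the bottom of this diff: when every timestamp logprob
is -INFINITY, the old code's exp(x - max) step evaluated expf(-INF - -INF),
which is NaN. A minimal standalone sketch of the same idea, with illustrative
names rather than the project's actual ggml helpers:

```c
#include <math.h>
#include <stddef.h>

// Stable logsumexp with the same guard shape as the whisper.cpp hunk below:
// subtract the max before exponentiating, and bail out when the row is
// entirely -INFINITY, since expf(-INF - -INF) would evaluate expf(NaN).
static float logsumexp_guarded(const float *logprobs, size_t n) {
    float max = -INFINITY;
    for (size_t i = 0; i < n; i++)
        if (logprobs[i] > max)
            max = logprobs[i];
    if (max == -INFINITY)
        return -INFINITY;               // fully-masked row: -inf, not NaN
    float sum = 0.0f;
    for (size_t i = 0; i < n; i++)
        sum += expf(logprobs[i] - max); // exponent <= 0, so no overflow
    return logf(sum) + max;
}
```

With the guard in place a fully-masked row simply leaves the result at
-INFINITY, matching how timestamp_logprob is initialized in that hunk.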
jart committed Aug 2, 2024
1 parent e9ee3f9 commit 5da8d62
Showing 16 changed files with 46 additions and 61 deletions.
2 changes: 1 addition & 1 deletion build/config.mk
@@ -14,7 +14,7 @@ INSTALL = install

ARFLAGS = rcsD
CXXFLAGS = -frtti -std=gnu++23
CCFLAGS = -O -fexceptions -fsignaling-nans -ffunction-sections -fdata-sections
CCFLAGS = -O2 -fexceptions -fsignaling-nans -ffunction-sections -fdata-sections
CPPFLAGS_ = -iquote. -mcosmo -DGGML_MULTIPLATFORM -Wno-attributes -DLLAMAFILE_DEBUG
TARGET_ARCH = -Xx86_64-mtune=znver4

1 change: 1 addition & 0 deletions llama.cpp/common.cpp
@@ -165,6 +165,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.kv_overrides.back().key[0] = 0;
}

FLAGS_READY = true;
params.n_gpu_layers = llamafile_gpu_layers(params.n_gpu_layers);

return true;
33 changes: 2 additions & 31 deletions llama.cpp/ggml-vector.inc
@@ -1316,43 +1316,14 @@ Expm1f(float x)
return fmaf(p, t, t - 1);
}

/* Single-precision tanh(x) approximation.
The maximum error is 2.58 ULP.
Designed by Arm Limited. */
static inline float
Tanhf(float x)
{
union
{
float f;
unsigned i;
} u = { x };
unsigned iax = u.i & 0x7fffffff;
unsigned sign = u.i & ~0x7fffffff;

/* Above 0x1.205966p+3 tanhf rounds to 1 (or -1 for negative). */
if (iax > 0x41102cb3) {
if (iax > 0x7f800000)
return (x - x) / (x - x);
u.i = 0x3f800000 | sign;
return u.f;
}
if (iax < 0x34000000)
return x;

/* tanh(x) = (e^2x - 1) / (e^2x + 1). */
float q = Expm1f(2 * x);
return q / (q + 2);
}

void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = Tanhf(x[i]); }
void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
@@ -1400,7 +1371,7 @@ void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (i

static inline float ggml_gelu_f32(float x) {
// GeLU approximation that goes slower and we seem to be stuck with.
return .5f * x * (1.f + Tanhf(sqrtf(M_2_PI) * (x + .044715f * x * x * x)));
return .5f * x * (1.f + tanhf(sqrtf(M_2_PI) * (x + .044715f * x * x * x)));
}

void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
14 changes: 1 addition & 13 deletions llama.cpp/ggml.c
@@ -1799,24 +1799,14 @@ inline static void ggml_critical_section_start(void) {
}
}

#ifdef GGML_USE_OPENMP
void ggml_barrier(struct ggml_compute_state_shared * shared) {
if (shared->n_threads == 1) {
return;
}

#pragma omp barrier
}
#else
void ggml_barrier(struct ggml_compute_state_shared * shared) {
if (shared->n_threads == 1)
return;
int n = shared->n_threads;
atomic_int * count = &shared->n_barrier;
atomic_uint * phase = &shared->n_barrier_passed;
int n = shared->n_threads;
unsigned i = atomic_load_explicit(phase, memory_order_relaxed);
if (atomic_fetch_add_explicit(count, 1, memory_order_acq_rel) == n - 1) {
// last thread
atomic_store_explicit(count, 0, memory_order_relaxed);
atomic_store_explicit(phase, i + 1, memory_order_release);
} else {
@@ -1825,7 +1815,6 @@ void ggml_barrier(struct ggml_compute_state_shared * shared) {
pthread_pause_np();
}
}
#endif

// TODO: make this somehow automatically executed
// some sort of "sentry" mechanism
@@ -12808,7 +12797,6 @@ GGML_CALL void ggml_rope_yarn_corr_dims(
dims[1] = MIN(n_dims - 1, end);
}

__target_clones("avx2") // [jart]
static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
2 changes: 1 addition & 1 deletion llama.cpp/llama.cpp
@@ -14356,7 +14356,7 @@ static int llama_decode_internal(
// never helps. This number appears to be optimal for all
// models ranging from TinyLLaMA 1.1B to mighty Mixtral 8x22B.
if (n_tokens <= 2) {
n_threads = std::min(20, n_threads);
n_threads = std::min(32, n_threads);
}

llama_graph_compute(lctx, gf, n_threads);
1 change: 1 addition & 0 deletions llama.cpp/server/server.cpp
@@ -2718,6 +2718,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
}

FLAGS_READY = true;
params.embedding = true; // [jart] #243 always enable embedding mode
params.n_gpu_layers = llamafile_gpu_layers(params.n_gpu_layers);

6 changes: 4 additions & 2 deletions llamafile/cuda.c
@@ -67,8 +67,8 @@ __static_yoink("llama.cpp/ggml-backend-impl.h");

#define NVCC_FLAGS \
(!IsWindows() ? "-std=c++11" : "-DIGNORE123"), "-O3", "--shared", "--use_fast_math", \
"-Xcudafe", "--diag_suppress=177", "--forward-unknown-to-host-compiler", \
"--compiler-options", \
"-Xcudafe", "--diag_suppress=177", "-Xcudafe", "--diag_suppress=940", "-Xcudafe", \
"--diag_suppress=1305", "--forward-unknown-to-host-compiler", "--compiler-options", \
(!IsWindows() \
? (!IsAarch64() \
? "-fPIC -O3 -march=native -mtune=native -std=c++11 -Wno-unused-function " \
@@ -747,6 +747,8 @@ static bool link_cuda_dso(const char *dso, const char *dir) {

static bool import_cuda_impl(void) {

npassert(FLAGS_READY);

// No dynamic linking support on OpenBSD yet.
if (IsOpenbsd())
return false;
2 changes: 2 additions & 0 deletions llamafile/flags.cpp
@@ -26,6 +26,7 @@

#include "llama.cpp/llama.h"

bool FLAGS_READY = false;
bool FLAG_log_disable = false;
bool FLAG_mlock = false;
bool FLAG_mmap = true;
@@ -324,5 +325,6 @@ void llamafile_get_flags(int argc, char **argv) {
if (!FLAG_model)
required("--model");

FLAGS_READY = true;
FLAG_n_gpu_layers = llamafile_gpu_layers(FLAG_n_gpu_layers);
}
1 change: 1 addition & 0 deletions llamafile/llamafile.h
@@ -6,6 +6,7 @@
extern "C" {
#endif

extern bool FLAGS_READY;
extern bool FLAG_log_disable;
extern bool FLAG_mlock;
extern bool FLAG_mmap;
2 changes: 2 additions & 0 deletions llamafile/metal.c
@@ -231,6 +231,8 @@ static bool LinkMetal(const char *dso) {

static bool ImportMetalImpl(void) {

npassert(FLAGS_READY);

// Ensure this is MacOS ARM64.
if (!IsXnuSilicon()) {
return false;
5 changes: 3 additions & 2 deletions llamafile/pool_test.cpp
@@ -20,6 +20,7 @@
#include <assert.h>
#include <cosmo.h>
#include <pthread.h>
#include <stdio.h>
#include <time.h>

#define BENCHMARK(ITERATIONS, WORK_PER_RUN, CODE) \
@@ -33,9 +34,9 @@
double nanos = \
(timespec_tonanos(timespec_sub(timespec_real(), start)) + work - 1) / (double)work; \
if (nanos < 1000) { \
kprintf("%10g ns %2dx %s\n", nanos, (ITERATIONS), #CODE); \
printf("%10g ns %2dx %s\n", nanos, (ITERATIONS), #CODE); \
} else { \
kprintf("%10lld ns %2dx %s\n", (long long)nanos, (ITERATIONS), #CODE); \
printf("%10lld ns %2dx %s\n", (long long)nanos, (ITERATIONS), #CODE); \
} \
} while (0)

2 changes: 2 additions & 0 deletions stable-diffusion.cpp/main.cpp
@@ -585,6 +585,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.output_path = "output.gguf";
}
}

FLAGS_READY = true;
}

static std::string sd_basename(const std::string& path) {
18 changes: 13 additions & 5 deletions whisper.cpp/README.md
@@ -21,15 +21,23 @@ wget https://archive.org/download/raven/raven_poe_64kb.mp3
sox raven_poe_64kb.mp3 -r 16k raven_poe_64kb.wav
```

Then you can use the large model, which is the best.
The tiny model may get some words wrong. For example, it might think
"quoth" is "quof". You can solve that using the medium model, which
enables whisperfile to decode The Raven perfectly. However, it's slower.

```
wget https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en.bin
o//whisper.cpp/main -m ggml-medium.en.bin -f raven_poe_64kb.wav --no-prints
```

Lastly, there's the large model, which is the best, but also slowest.

```
wget -O whisper-large-v3.bin https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3.bin
o//whisper.cpp/main -m whisper-large-v3.bin -f raven_poe_64kb.wav --no-prints
```

## GPU Support
### GPU Mode

Pass the `--gpu auto` flag to use GPU mode. This is currently
experimental. It appears to be working with `whisper-tiny.en-q5_1.bin`
but isn't reliable yet for the F16 models.
Pass the `--gpu auto` flag to use GPU mode. This can be particularly
helpful in speeding up the medium and large models.
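
A usage sketch, assuming the medium model downloaded in the steps above:

```
o//whisper.cpp/main -m ggml-medium.en.bin -f raven_poe_64kb.wav --no-prints --gpu auto
```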
3 changes: 2 additions & 1 deletion whisper.cpp/main.cpp
@@ -150,7 +150,6 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
fprintf(stderr, "error: invalid --gpu flag value: %s\n", argv[i]);
exit(1);
}
return true;
} else

if (arg == "-h" || arg == "--help") {
@@ -214,6 +213,8 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
}
}

FLAGS_READY = true;

return true;
}

2 changes: 2 additions & 0 deletions whisper.cpp/server.cpp
@@ -216,6 +216,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
}
}

FLAGS_READY = true;

return true;
}

13 changes: 8 additions & 5 deletions whisper.cpp/whisper.cpp
@@ -5143,12 +5143,15 @@ static void whisper_process_logits(
// logsumexp over timestamps
float timestamp_logprob = -INFINITY;
{
float logsumexp = 0.0f;
const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
float logsumexp = ggml_vec_soft_max_f32(n_logits - vocab.token_beg, 0,
&logprobs[vocab.token_beg],
logprob_max); // [jart]
if (logsumexp > 0.0f) {
timestamp_logprob = logf(logsumexp) + logprob_max;
if (logprob_max > -INFINITY) {
float logsumexp = ggml_vec_soft_max_f32(n_logits - vocab.token_beg, 0,
&logprobs[vocab.token_beg],
logprob_max); // [jart]
if (logsumexp > 0.0f) {
timestamp_logprob = logf(logsumexp) + logprob_max;
}
}
}
