Skip to content

Commit

Permalink
port log_sum.v
Browse files Browse the repository at this point in the history
  • Loading branch information
affeldt-aist committed Oct 3, 2024
1 parent d54cb02 commit ed54195
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 94 deletions.
8 changes: 6 additions & 2 deletions lib/realType_logb.v
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ Section ln_ext.
Context {R : realType}.
Implicit Type x : R.

Lemma ln2_gt0 : 0 < ln 2 :> R.
Proof. by rewrite ln_gt0// ltr1n. Qed.
Lemma ln2_gt0 : 0 < ln 2 :> R. Proof. by rewrite ln_gt0// ltr1n. Qed.

Lemma ln2_neq0 : ln 2 != 0 :> R. Proof. by rewrite gt_eqF// ln2_gt0. Qed.

Lemma ln2_ge0 : 0 <= ln 2 :> R. Proof. by rewrite ltW// ln2_gt0. Qed.

Lemma ln_id_cmp x : 0 < x -> ln x <= x - 1.
Proof.
move=> x0.
Expand Down Expand Up @@ -96,6 +97,9 @@ Proof. by move=> x y x0 y0; rewrite /log ler_Log. Qed.
Lemma logV x : 0 < x -> log x^-1 = - log x :> R.
Proof. by move=> x0; rewrite /log LogV. Qed.

Lemma logM x y : 0 < x -> 0 < y -> log (x * y) = log x + log y.
Proof. move=> x0 y0; exact: (@LogM _ 2 _ _ x0 y0). Qed.

Lemma logDiv x y : 0 < x -> 0 < y -> log (x / y) = log x - log y.
Proof. by move=> x0 y0; exact: (@LogDiv _ _ _ _ x0 y0). Qed.

Expand Down
185 changes: 93 additions & 92 deletions probability/log_sum.v
Original file line number Diff line number Diff line change
@@ -1,72 +1,72 @@
(* infotheo: information theory and error-correcting codes in Coq *)
(* Copyright (C) 2020 infotheo authors, license: LGPL-2.1-or-later *)
From mathcomp Require Import all_ssreflect all_algebra.
Require Import Reals Lra.
From mathcomp Require Import Rstruct lra.
Require Import ssrR realType_ext Reals_ext Ranalysis_ext logb ln_facts bigop_ext.
From mathcomp Require Import Rstruct reals lra exp.
Require Import ssrR realType_ext realType_logb bigop_ext.

(******************************************************************************)
(* The log-sum Inequality *)
(******************************************************************************)

Import GRing.Theory Num.Theory Order.TTheory.
Set Implicit Arguments.
Unset Strict Implicit.
Import Prenex Implicits.

Local Open Scope reals_ext_scope.
Local Open Scope R_scope.
Local Open Scope ring_scope.

Import Order.POrderTheory GRing.Theory Num.Theory.

Local Notation "'\sum_{' C '}' f" :=
(\sum_(a | a \in C) f a) (at level 10, format "\sum_{ C } f").

Definition log_sum_stmt {A : finType} (C : {set A}) (f g : {ffun A -> R}) :=
Definition log_sum_stmt {R : realType} {A : finType} (C : {set A}) (f g : {ffun A -> R}) :=
(forall x, 0 <= f x) ->
(forall x, 0 <= g x) ->
f `<< g ->
\sum_{C} f * log (\sum_{C} f / \sum_{C} g) <=
\sum_(a | a \in C) f a * log (f a / g a).

Lemma log_sum1 {A : finType} (C : {set A}) (f g : {ffun A -> R}) :
Lemma log_sum1 {R : realType} {A : finType} (C : {set A}) (f g : {ffun A -> R}) :
(forall a, a \in C -> 0 < f a) -> log_sum_stmt C f g.
Proof.
move=> fspos f0 g0 fg.
case/boolP : (C == set0) => [ /eqP -> | Hc].
by apply/RleP; rewrite !big_set0 mul0R lexx.
by rewrite !big_set0 mul0r lexx.
have gspos : forall a, a \in C -> 0 < g a.
move=> a a_C. case (g0 a) => //.
move=> a a_C.
rewrite lt_neqAle g0 andbT; apply/eqP.
move=>/esym/(dominatesE fg) abs.
by move: (fspos _ a_C); rewrite abs => /ltRR.
by move: (fspos _ a_C); rewrite abs ltxx.
have Fnot0 : \sum_{ C } f != 0.
apply/eqP => /psumr_eq0P abs.
case/set0Pn : Hc => a aC.
move: (fspos _ aC); rewrite abs //.
by move=> /RltP; rewrite ltxx.
by move=> i iC; exact/RleP.
by rewrite ltxx.
have Gnot0 : \sum_{ C } g != 0.
apply/eqP => /psumr_eq0P abs.
case/set0Pn : Hc => a aC.
move: (gspos _ aC); rewrite abs //.
by move=> /RltP; rewrite ltxx.
by move=> i iC; exact/RleP.
by move: (gspos _ aC); rewrite abs // ltxx.
wlog : Fnot0 g g0 Gnot0 fg gspos / \sum_{ C } f = \sum_{ C } g.
move=> Hwlog.
set k := (\sum_{ C } f / \sum_{ C } g).
have Fspos : 0 < \sum_{ C } f.
suff Fpos : 0 <= \sum_{ C } f by apply/RltP; rewrite lt0r Fnot0; exact/RleP.
by apply/RleP/sumr_ge0 => ? ?; exact/RleP/ltRW/fspos.
suff Fpos : 0 <= \sum_{ C } f by rewrite lt0r Fnot0.
by apply/sumr_ge0 => ? ?; exact/ltW/fspos.
have Gspos : 0 < \sum_{ C } g.
suff Gpocs : 0 <= \sum_{ C } g by apply/RltP; rewrite lt0r Gnot0; exact/RleP.
by apply/RleP/sumr_ge0 => ? ?; exact/RleP/ltRW/gspos.
have kspos : 0 < k by exact: divR_gt0.
suff Gpocs : 0 <= \sum_{ C } g by rewrite lt0r Gnot0.
by apply/sumr_ge0 => ? ?; exact/ltW/gspos.
have kspos : 0 < k by exact: divr_gt0.
set kg := [ffun x => k * g x].
have kg_pos : forall a, 0 <= kg a.
by move=> a; rewrite /kg /= ffunE; apply mulR_ge0 => //; exact: ltRW.
by move=> a; rewrite /kg /= ffunE; apply mulr_ge0 => //; exact: ltW.
have kabs_con : f `<< kg.
apply/dominates_scale => //; exact/gtR_eqF.
by apply/dominates_scale => //; rewrite ?gt_eqF//.
have kgspos : forall a, a \in C -> 0 < kg a.
by move=> a a_C; rewrite ffunE; apply mulR_gt0 => //; exact: gspos.
by move=> a a_C; rewrite ffunE; apply mulr_gt0 => //; exact: gspos.
have Hkg : \sum_{C} kg = \sum_{C} f.
transitivity (\sum_(a in C) k * g a).
by apply eq_bigr => a aC; rewrite /= ffunE.
by rewrite -big_distrr /= /k /Rdiv -mulRA mulRC mulVR // mul1R.
by rewrite -big_distrr /= /k -mulrA mulVf ?mulr1.
have Htmp : \sum_{ C } kg != 0.
rewrite /=.
evar (h : A -> R); rewrite (eq_bigr h); last first.
Expand All @@ -75,58 +75,54 @@ wlog : Fnot0 g g0 Gnot0 fg gspos / \sum_{ C } f = \sum_{ C } g.
by apply eq_bigr => a aC /=; rewrite ffunE.
symmetry in Hkg.
move: {Hwlog}(Hwlog Fnot0 kg kg_pos Htmp kabs_con kgspos Hkg) => /= Hwlog.
rewrite Hkg {1}/Rdiv mulRV // /log Log_1 mulR0 in Hwlog.
rewrite Hkg mulfV // log1 mulr0 in Hwlog.
set rhs := \sum_(_ | _) _ in Hwlog.
rewrite (_ : rhs = \sum_(a | a \in C) (f a * log (f a / g a) - f a * log k)) in Hwlog; last first.
rewrite /rhs.
apply eq_bigr => a a_C.
rewrite /Rdiv /log LogM; last 2 first.
rewrite logM; last 2 first.
exact/fspos.
rewrite ffunE; apply/invR_gt0/mulR_gt0 => //; exact/gspos.
rewrite LogV; last first.
rewrite ffunE; apply mulR_gt0 => //; exact: gspos.
rewrite ffunE LogM //; last exact: gspos.
rewrite LogM //; last 2 first.
by rewrite ffunE invr_gt0// mulr_gt0//; exact/gspos.
rewrite logV; last first.
rewrite ffunE; apply mulr_gt0 => //; exact: gspos.
rewrite ffunE logM //; last exact: gspos.
rewrite logM //; last 2 first.
exact/fspos.
by apply invR_gt0 => //; apply gspos.
by rewrite LogV; [field | apply gspos].
rewrite big_split /= -big_morph_oppR -big_distrl /= in Hwlog.
by rewrite -subR_ge0.
by rewrite invr_gt0//; apply gspos.
by rewrite logV; [lra | apply gspos].
rewrite big_split /= -big_morph_oppr -big_distrl /= in Hwlog.
by rewrite -subr_ge0.
move=> Htmp; rewrite Htmp.
rewrite /Rdiv mulRV; last by rewrite -Htmp.
rewrite /log Log_1 mulR0.
rewrite mulfV; last by rewrite -Htmp.
rewrite log1 mulr0.
suff : 0 <= \sum_(a | a \in C) f a * ln (f a / g a).
move=> H.
rewrite /log /Rdiv.
set rhs := \sum_( _ | _ ) _.
have -> : rhs = \sum_(H | H \in C) (f H * (ln (f H / g H))) / ln 2.
rewrite /rhs.
apply eq_bigr => a a_C; by rewrite /Rdiv -mulRA.
by apply eq_bigr => a a_C; by rewrite -mulrA.
rewrite -big_distrl /=.
by apply mulR_ge0 => //; exact/invR_ge0.
apply (@leR_trans (\sum_(a | a \in C) f a * (1 - g a / f a))).
apply (@leR_trans (\sum_(a | a \in C) (f a - g a))).
rewrite big_split /= -big_morph_oppR Htmp addRN.
by apply/RleP; rewrite lexx.
apply/Req_le/eq_bigr => a a_C.
rewrite mulRDr mulR1 mulRN.
case: (Req_EM_T (g a) 0) => [->|ga_not_0].
by rewrite div0R mulR0.
by field; exact/eqP/gtR_eqF/(fspos _ a_C).
apply: leR_sumR => a C_a.
apply leR_wpmul2l; first exact/ltRW/fspos.
rewrite -[X in _ <= X]oppRK leR_oppr -ln_Rinv; last first.
apply divR_gt0; by [apply fspos | apply gspos].
rewrite invRM; last 2 first.
exact/gtR_eqF/(fspos _ C_a).
by rewrite invR_neq0' // gtR_eqF //; exact/(gspos _ C_a).
rewrite invRK mulRC; apply: leR_trans.
by apply/ln_id_cmp/divR_gt0; [apply gspos | apply fspos].
apply Req_le.
by field; exact/eqP/gtR_eqF/(fspos _ C_a).
by rewrite mulr_ge0// invr_ge0// ln2_ge0.
apply (@le_trans _ _ (\sum_(a | a \in C) f a * (1 - g a / f a))).
apply (@le_trans _ _ (\sum_(a | a \in C) (f a - g a))).
by rewrite big_split /= -big_morph_oppr Htmp subrr.
rewrite le_eqVlt; apply/orP; left; apply/eqP.
apply/eq_bigr => a a_C.
rewrite mulrDr mulr1 mulrN.
have [->|ga_not_0] := eqVneq (g a) 0.
by rewrite mul0r mulr0.
by rewrite mulrCA divff ?mulr1// gt_eqF//; exact/(fspos _ a_C).
apply: ler_sum => a C_a.
apply ler_wpmul2l; first exact/ltW/fspos.

Check warning on line 116 in probability/log_sum.v

View workflow job for this annotation

GitHub Actions / build (mathcomp/mathcomp:2.2.0-coq-8.19)

Notation ler_wpmul2l is deprecated since mathcomp 1.17.0.

Check warning on line 116 in probability/log_sum.v

View workflow job for this annotation

GitHub Actions / build (mathcomp/mathcomp:2.2.0-coq-8.19)

Notation ler_wpmul2l is deprecated since mathcomp 1.17.0.
rewrite -[X in _ <= X]opprK lerNr -lnV; last first.
by rewrite posrE divr_gt0//; [apply fspos | apply gspos].
rewrite invfM.
rewrite invrK mulrC; apply: le_trans.
by apply/ln_id_cmp; rewrite divr_gt0//; [apply gspos | apply fspos].
by rewrite opprB.
Qed.

Lemma log_sum {A : finType} (C : {set A}) (f g : {ffun A -> R}) :
Lemma log_sum {R : realType} {A : finType} (C : {set A}) (f g : {ffun A -> R}) :
log_sum_stmt C f g.
Proof.
move=> f0 g0 fg.
Expand All @@ -140,13 +136,13 @@ suff : \sum_{D} f * log (\sum_{D} f / \sum_{D} g) <=
move Hlhs : (a \in C) => lhs.
destruct lhs => //.
symmetry.
rewrite in_setU /C1 /C1 !in_set Hlhs /=.
rewrite in_setU !in_set Hlhs /=.
by destruct (f a == 0).
by rewrite in_setU in_set Hlhs /= /C1 in_set Hlhs.
by rewrite in_setU in_set Hlhs /= in_set Hlhs.
have DID' : [disjoint D & D'].
rewrite -setI_eq0.
apply/eqP/setP => a.
rewrite in_set0 /C1 /C1 in_setI !in_set.
rewrite in_set0 in_setI !in_set.
by destruct (a \in C) => //=; rewrite andNb.
have H1 : \sum_{C} f = \sum_{D} f.
rewrite setUC in DUD'.
Expand All @@ -155,45 +151,50 @@ suff : \sum_{D} f * log (\sum_{D} f / \sum_{D} g) <=
apply eq_bigr => a.
rewrite /D' in_set.
by case/andP => _ /eqP.
by rewrite big_const iter_addR mulR0 add0R.
by rewrite big_const iter_addr addr0 mul0rn add0r.
rewrite -H1 in H.
have pos_F : 0 <= \sum_{C} f by apply/RleP/sumr_ge0 => ? ?; exact/RleP.
apply (@leR_trans (\sum_{C} f * log (\sum_{C} f / \sum_{D} g))).
case/Rle_lt_or_eq_dec : pos_F => pos_F; last first.
by rewrite -pos_F !mul0R.
have H2 : 0 <= \sum_(a | a \in D) g a by apply/RleP/sumr_ge0 => ? _; exact/RleP.
case/Rle_lt_or_eq_dec : H2 => H2; last first.
have pos_F : 0 <= \sum_{C} f by apply/sumr_ge0 => ? ?.
apply (@le_trans _ _ (\sum_{C} f * log (\sum_{C} f / \sum_{D} g))).
move: pos_F; rewrite le_eqVlt => /predU1P[pos_F|pos_F].
by rewrite -pos_F !mul0r.
have H2 : 0 <= \sum_(a | a \in D) g a by apply/sumr_ge0.
move: H2; rewrite le_eqVlt => /predU1P[g0'|gt0'].
have : 0 = \sum_{D} f.
transitivity (\sum_(a | a \in D) 0).
by rewrite big_const iter_addR mulR0.
transitivity (\sum_(a | a \in D) (0:R))%R.
by rewrite big1.
apply: eq_bigr => a a_C1.
rewrite (dominatesE fg) //.
apply/(@psumr_eq0P _ _ (mem D) g) => // i _.
exact/RleP.
move=> abs; rewrite -abs in H1; rewrite H1 in pos_F.
by move/ltRR : pos_F.
by apply/(@psumr_eq0P _ _ (mem D)) => //.
by move=> abs; rewrite -abs in H1; rewrite H1 ltxx in pos_F.
have H3 : 0 < \sum_(a | a \in C) g a.
rewrite setUC in DUD'.
rewrite DUD' (big_union _ g DID') /=.
by apply: addR_gt0wr => //; apply/RleP/sumr_ge0=> ? _; exact/RleP.
apply/(leR_wpmul2l (ltRW pos_F))/Log_increasing_le => //.
by apply divR_gt0 => //; rewrite -HG.
apply/(leR_wpmul2l (ltRW pos_F))/leR_inv => //.
rewrite ltr_pwDr//.
by apply/sumr_ge0 => //.
apply/ler_wpM2l => //.
exact/ltW.
rewrite ler_log// ?posrE//; last 2 first.
by apply divr_gt0 => //; rewrite -HG.
by apply divr_gt0 => //; rewrite -HG.
apply/ler_wpM2l => //.
exact/ltW.
rewrite lef_pV2//.
rewrite setUC in DUD'.
rewrite DUD' (big_union _ g DID') /= -[X in X <= _]add0R; apply leR_add2r.
by apply/RleP/sumr_ge0 => ? ?; exact/RleP.
apply: (leR_trans H).
rewrite DUD' (big_union _ g DID') /=.
rewrite lerDr//.
by apply/sumr_ge0.
apply: (le_trans H).
rewrite setUC in DUD'.
rewrite DUD' (big_union _ (fun a => f a * log (f a / g a)) DID') /=.
rewrite (_ : \sum_(_ | _ \in D') _ = 0); last first.
transitivity (\sum_(a | a \in D') 0).
transitivity (\sum_(a | a \in D') (0:R)).
apply eq_bigr => a.
by rewrite /D' in_set => /andP[a_C /eqP ->]; rewrite mul0R.
by rewrite big_const iter_addR mulR0.
by apply/RleP; rewrite add0R lexx.
by rewrite /D' in_set => /andP[a_C /eqP ->]; rewrite mul0r.
by rewrite big1.
by rewrite add0r lexx.
apply: log_sum1 => // a.
rewrite /C1 in_set.
rewrite in_set.
case/andP => a_C fa_not_0.
case (f0 a) => // abs.
by rewrite abs eqxx in fa_not_0.
case :(f0 a) => // abs.
by rewrite lt_neqAle eq_sym fa_not_0.
Qed.

0 comments on commit ed54195

Please sign in to comment.