Stochastic Gradient Descent with Importance Sampling

Rachel Ward
UT Austin

Joint work with
Deanna Needell (Claremont McKenna College)
and Nathan Srebro (TTIC / Technion)
Recall: Gradient descent

Problem: Minimize $F(x) = \frac{1}{n}\sum_{i=1}^n f_i(x)$, with $n$ very large.

Gradient descent: Initialize $x_0$ and iterate
$$x^{(j+1)} = x^{(j)} - \gamma \nabla F(x^{(j)}) = x^{(j)} - \frac{\gamma}{n}\sum_{i=1}^n \nabla f_i(x^{(j)}).$$

Not practical when $n$ is huge: every step touches all $n$ gradients!
Stochastic gradient descent

Minimize $F(x) = \frac{1}{n}\sum_{i=1}^n f_i(x)$.

Initialize $x_0$. For $j = 1, 2, \dots$:
draw an index $i = i_j$ at random, and set
$$x^{(j+1)} = x^{(j)} - \gamma \nabla f_i(x^{(j)}).$$

Goal: nonasymptotic bounds on $\mathbb{E}\|x^{(j)} - x^*\|^2$.
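For concreteness, here is a minimal NumPy sketch of this update loop. The function name `sgd_uniform` and the gradient-oracle interface (each $f_i$ represented by a callable returning its gradient) are mine, not from the talk.

```python
import numpy as np

def sgd_uniform(grad_fs, x0, step, n_iters, seed=None):
    """Plain SGD: at each step, pick i uniformly at random and move along -step * grad f_i(x)."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    n = len(grad_fs)
    for _ in range(n_iters):
        i = rng.integers(n)               # draw index i_j uniformly at random
        x = x - step * grad_fs[i](x)      # x^(j+1) = x^(j) - gamma * grad f_i(x^(j))
    return x

# Tiny usage example: n = 3 quadratics f_i(x) = (x - c_i)^2 / 2, so grad f_i(x) = x - c_i.
cs = [0.0, 1.0, 2.0]
grads = [lambda x, c=c: x - c for c in cs]
print(sgd_uniform(grads, x0=[5.0], step=0.1, n_iters=500, seed=0))  # hovers near the mean of the c_i
```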
Minimize $F(x) = \frac{1}{n}\sum_{i=1}^n f_i(x) = \mathbb{E}\, f_i(x)$.

We begin by assuming
1. Each $\nabla f_i$ is Lipschitz: $\|\nabla f_i(x) - \nabla f_i(y)\|_2 \le L_i \|x - y\|_2$,
2. $F$ is strongly convex: $\langle x - y, \nabla F(x) - \nabla F(y)\rangle \ge \mu \|x - y\|_2^2$.

The convergence rate of SGD should depend on
1. A condition number, e.g.
$$\kappa_{\mathrm{av}} = \frac{\frac{1}{n}\sum_i L_i}{\mu}, \qquad \kappa_{\max} = \frac{\max_i L_i}{\mu}, \qquad \text{or} \qquad \frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu},$$
2. Consistency: $\sigma^2 = \mathbb{E}\|\nabla f_i(x^*)\|^2$.
SGD - convergence rates

(Needell, Srebro, W., 2013):
Under the smoothness and strong convexity assumptions, the SGD iterates with constant step size $\gamma$ satisfy
$$\mathbb{E}\|x^{(k)} - x^*\|^2 \le \bigl[1 - 2\gamma\mu\bigl(1 - \gamma \sup_i L_i\bigr)\bigr]^k \|x_0 - x^*\|^2 + \frac{\gamma \sigma^2}{\mu\bigl(1 - \gamma \sup_i L_i\bigr)}.$$

Corollary: $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$ after
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\sup_i L_i}{\mu} + \frac{\sigma^2}{\mu^2 \varepsilon}\right)$$
SGD iterations, with optimized constant step-size.
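As a concrete reading of the corollary, the snippet below just plugs problem constants into this iteration bound; the numbers are made up for illustration, with $\varepsilon_0 = \|x_0 - x^*\|^2$ the initial error.

```python
import math

def sgd_iters_uniform(L_max, mu, sigma2, eps, eps0):
    """Iteration bound from the corollary: k = 2 log(eps0/eps) * (L_max/mu + sigma2/(mu^2 eps))."""
    return 2.0 * math.log(eps0 / eps) * (L_max / mu + sigma2 / (mu**2 * eps))

# Illustrative constants (not from the talk):
print(sgd_iters_uniform(L_max=100.0, mu=1.0, sigma2=0.5, eps=1e-3, eps0=1.0))
```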
SGD - convergence rates

We showed:
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\sup_i L_i}{\mu} + \frac{\sigma^2}{\mu^2 \varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$.

Compare to (Bach and Moulines, 2011):
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\left(\frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}\right)^2 + \frac{\sigma^2}{\mu^2 \varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$.

That is, $k \propto \dfrac{\sup_i L_i}{\mu}$ vs. $k \propto \dfrac{\frac{1}{n}\sum_i L_i^2}{\mu^2}$ steps.
The difference in the proof is that we use the co-coercivity lemma for smooth functions with Lipschitz gradient,
$$\|\nabla f_i(x) - \nabla f_i(y)\|^2 \le L_i\, \langle x - y, \nabla f_i(x) - \nabla f_i(y)\rangle,$$
in the expansion of the SGD step:
$$
\begin{aligned}
\|x^{(k+1)} - x^*\|^2 &= \|x^{(k)} - x^* - \gamma \nabla f_i(x^{(k)})\|^2 \\
&\le \|x^{(k)} - x^*\|^2 - 2\gamma \langle x^{(k)} - x^*, \nabla f_i(x^{(k)})\rangle + 2\gamma^2 \|\nabla f_i(x^{(k)}) - \nabla f_i(x^*)\|^2 + 2\gamma^2 \|\nabla f_i(x^*)\|^2 \\
&\le \|x^{(k)} - x^*\|^2 - 2\gamma \langle x^{(k)} - x^*, \nabla f_i(x^{(k)})\rangle + 2\gamma^2 L_i \langle x^{(k)} - x^*, \nabla f_i(x^{(k)}) - \nabla f_i(x^*)\rangle + 2\gamma^2 \|\nabla f_i(x^*)\|^2.
\end{aligned}
$$
These convergence rates are tight

Consider the least squares case:
$$F(x) = \frac{1}{2}\sum_{i=1}^n \bigl(\langle a_i, x\rangle - b_i\bigr)^2 = \frac{1}{2}\|Ax - b\|^2, \qquad A = \begin{pmatrix} a_1 \\ a_2 \\ a_3 \\ \vdots \\ a_n \end{pmatrix},$$
so that $f_i(x) = \frac{n}{2}(\langle a_i, x\rangle - b_i)^2$, $L_i = n\|a_i\|^2$, and $\mu = 1/\|A^\dagger\|^2$.

Assume consistency: $Ax^* = b$, so $\sigma = 0$. Then
$$\frac{\sup_i L_i}{\mu} = \bigl(n \sup_i \|a_i\|^2\bigr)\,\|A^\dagger\|^2.$$
These convergence rates are tight

Consider the system
$$\begin{pmatrix} 1 & 0 \\ 0 & 1/\sqrt{n} \\ 0 & 1/\sqrt{n} \\ \vdots & \vdots \\ 0 & 1/\sqrt{n} \end{pmatrix} \begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \\ 0 \\ \vdots \\ 0 \end{pmatrix}.$$

Here,
$$\frac{\sup_i L_i}{\mu}(A) = n \sup_i \|a_i\|^2\, \|A^\dagger\|^2 = n.$$

In this example, we need $k = n$ steps to get any accuracy.
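A quick NumPy check of this example (the helper name `tightness_example` is mine), assuming the scaling $f_i(x) = \frac{n}{2}(\langle a_i, x\rangle - b_i)^2$ so that $F(x) = \frac{1}{2}\|Ax - b\|^2$:

```python
import numpy as np

def tightness_example(n):
    """Build the n x 2 system from the slides: a_1 = e_1, a_i = e_2/sqrt(n) for i >= 2, b = e_1."""
    A = np.zeros((n, 2))
    A[0, 0] = 1.0
    A[1:, 1] = 1.0 / np.sqrt(n)
    b = np.zeros(n)
    b[0] = 1.0
    return A, b

n = 1000
A, b = tightness_example(n)
L = n * np.sum(A**2, axis=1)                          # L_i = n ||a_i||^2
mu = np.linalg.svd(A, compute_uv=False)[-1] ** 2      # mu = sigma_min(A)^2 = 1 / ||A^dagger||^2
print(L.max() / mu)     # ~ n: uniform sampling needs on the order of n steps
print(L.mean() / mu)    # ~ 2: the averaged condition number is O(1)
```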
Better convergence rates using
weighted sampling strategies?
SGD with weighted sampling

Observe:
$$F(x) = \frac{1}{n}\sum_i f_i(x) = \mathbb{E}^{(w)}\left[\frac{1}{w_i} f_i(x)\right],$$
given weights $w_i$ such that $\sum_i w_i = n$.

SGD unbiased update with weighted sampling: let $\tilde f_i = \frac{1}{w_i} f_i$ and iterate
$$x^{(j+1)} = x^{(j)} - \gamma \nabla \tilde f_{i_j}(x^{(j)}) = x^{(j)} - \frac{\gamma}{w_{i_j}} \nabla f_{i_j}(x^{(j)}),$$
where $\mathbb{P}(i_j = i) = \dfrac{w_i}{\sum_j w_j} = \dfrac{w_i}{n}$.
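A minimal NumPy sketch of this reweighted update, under the same gradient-oracle interface as the earlier sketch; the function name `sgd_weighted` is mine, not from the talk.

```python
import numpy as np

def sgd_weighted(grad_fs, weights, x0, step, n_iters, seed=None):
    """SGD with importance sampling: draw i with probability w_i / n and rescale the
    gradient by 1/w_i, so each update is still an unbiased estimate of a full gradient step."""
    rng = np.random.default_rng(seed)
    w = np.asarray(weights, dtype=float)
    w = w * len(w) / w.sum()                     # normalize so that sum_i w_i = n, as on the slide
    probs = w / w.sum()                          # P(i_j = i) = w_i / n
    x = np.array(x0, dtype=float)
    for _ in range(n_iters):
        i = rng.choice(len(w), p=probs)
        x = x - (step / w[i]) * grad_fs[i](x)    # x^(j+1) = x^(j) - (gamma / w_i) grad f_i(x^(j))
    return x
```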
SGD with weighted sampling

$$F(x) = \frac{1}{n}\sum_i f_i(x) = \mathbb{E}^{(w)}\left[\frac{1}{w_i} f_i(x)\right],$$
given weights $w_i$ such that $\sum_i w_i = n$.

Weighted sampling strategies in stochastic optimization are not new:
• (Strohmer, Vershynin 2009): randomized Kaczmarz algorithm
• (Nesterov 2010; Lee, Sidford 2013): biased sampling for accelerated stochastic coordinate descent
SGD with weighted sampling

Our previous result: for $F(x) = \mathbb{E}\, f_i(x)$ and $x_{k+1} = x_k - \gamma \nabla f_i(x_k)$,
$$\mathbb{E}\|x_k - x^*\|_2^2 \le \varepsilon \quad \text{after} \quad k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}[f_i]}{\mu} + \frac{\sigma^2[f_i]}{\mu^2 \varepsilon}\right) \text{ steps.}$$

Corollary for weighted sampling:
$$\mathbb{E}^{(w)}\|x_k - x^*\|_2^2 \le \varepsilon \quad \text{after} \quad k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}\bigl[\tfrac{1}{w_i} f_i\bigr]}{\mu} + \frac{\sigma^2\bigl[\tfrac{1}{w_i} f_i\bigr]}{\mu^2 \varepsilon}\right) \text{ steps.}$$
Choice of weights

For $F(x) = \mathbb{E}^{(w)}\bigl[\tfrac{1}{w_i} f_i(x)\bigr]$,
$$\mathbb{E}^{(w)}\|x_k - x^*\|_2^2 \le \varepsilon \quad \text{after} \quad k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}\bigl[\tfrac{1}{w_i} f_i\bigr]}{\mu} + \frac{\sigma^2\bigl[\tfrac{1}{w_i} f_i\bigr]}{\mu^2 \varepsilon}\right) \text{ steps.}$$

If $\sigma^2 = 0$, choose the weights to minimize $L_{\max}\bigl[\tfrac{1}{w_i} f_i\bigr]$:
$$\mathbb{E}^{(w)}\|x_k - x^*\|_2^2 \le \varepsilon \quad \text{after} \quad k = 2\log(\varepsilon_0/\varepsilon)\,\frac{\frac{1}{n}\sum_i L_i}{\mu} \text{ steps, using weights } w_i = \frac{n L_i}{\sum_j L_j}.$$
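A one-line computation of these weights (the helper name is illustrative), together with a check that $\sup_i L_i/w_i$ collapses to the average $\frac{1}{n}\sum_i L_i$:

```python
import numpy as np

def lipschitz_weights(L):
    """Weights w_i = n L_i / sum_j L_j, which minimize L_max[(1/w_i) f_i] when sigma = 0."""
    L = np.asarray(L, dtype=float)
    return len(L) * L / L.sum()

# Example with made-up Lipschitz constants:
L = np.array([10.0, 1.0, 1.0, 1.0])
w = lipschitz_weights(L)
print(w, w.sum())               # the weights sum to n
print((L / w).max(), L.mean())  # sup_i L_i / w_i equals the average (1/n) sum_i L_i
```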
Improved convergence rate with weighted sampling

Recall the example:
$$\begin{pmatrix} 1 & 0 \\ 0 & 1/\sqrt{n} \\ 0 & 1/\sqrt{n} \\ \vdots & \vdots \\ 0 & 1/\sqrt{n} \end{pmatrix} \begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \\ 0 \\ \vdots \\ 0 \end{pmatrix}.$$

Since this system is consistent, $\sigma^2 = 0$, and
$$\frac{L_{\max}}{\mu}(A) = n \;\Rightarrow\; O(n) \text{ steps using uniform sampling,}$$
$$\frac{\frac{1}{n}\sum_i L_i}{\mu}(A) = 2 \;\Rightarrow\; O(1) \text{ steps using biased sampling.}$$
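To illustrate this gap, here is a small toy experiment of my own (not from the talk): it runs the reweighted SGD update on the example system with uniform weights and with $w_i \propto L_i$; the constants ($n$, number of steps, number of trials, step size rule) are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n, steps, trials = 1000, 100, 20
A = np.zeros((n, 2)); A[0, 0] = 1.0; A[1:, 1] = 1.0 / np.sqrt(n)
b = np.zeros(n); b[0] = 1.0
x_star = np.array([1.0, 0.0])                  # consistent system: A x_star = b

L = n * np.sum(A**2, axis=1)                   # L_i = n ||a_i||^2 for f_i = (n/2)(<a_i,x> - b_i)^2
w_biased = n * L / L.sum()                     # w_i = n L_i / sum_j L_j

def mean_error(weights):
    probs = weights / weights.sum()
    gamma = 1.0 / (L / weights).max()          # step of order 1 / sup_i (L_i / w_i)
    errs = []
    for _ in range(trials):
        x = np.zeros(2)
        for _ in range(steps):
            i = rng.choice(n, p=probs)
            grad = n * (A[i] @ x - b[i]) * A[i]       # grad f_i(x)
            x = x - (gamma / weights[i]) * grad
        errs.append(np.linalg.norm(x - x_star))
    return np.mean(errs)

# With only ~n/10 steps, uniform sampling rarely draws the one informative row,
# while weighted sampling draws it about half the time per step.
print("uniform sampling :", mean_error(np.ones(n)))   # typically stays near 1
print("weighted sampling:", mean_error(w_biased))     # typically already tiny
```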
Improved convergence rate with weighted sampling

Recall the example above, where the system $Ax^* = b$ is consistent.
But what if the system is not consistent?

$$\mathbb{E}^{(w)}\|x_k - x^*\|_2^2 \le \varepsilon \quad \text{after} \quad k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}\bigl[\tfrac{1}{w_i} f_i\bigr]}{\mu} + \frac{\sigma^2\bigl[\tfrac{1}{w_i} f_i\bigr]}{\mu^2 \varepsilon}\right) \text{ steps.}$$

Choosing weights $w_i = 1$ gives $\sigma^2\bigl[\tfrac{1}{w_i} f_i\bigr] = \sigma^2[f_i]$.
Choosing weights $w_i = \frac{n L_i}{\sum_j L_j}$ gives $L_{\max}\bigl[\tfrac{1}{w_i} f_i\bigr] = \frac{1}{n}\sum_i L_i$.

Partially-biased sampling: choosing weights $w_i = \frac{1}{2} + \frac{1}{2}\,\frac{n L_i}{\sum_j L_j}$ gives
$$L_{\max}\Bigl[\tfrac{1}{w_i} f_i\Bigr] \le 2\cdot\frac{1}{n}\sum_i L_i \qquad \text{and} \qquad \sigma^2\Bigl[\tfrac{1}{w_i} f_i\Bigr] \le 2\,\sigma^2[f_i].$$

Partially biased sampling gives a strictly better convergence rate, up to a factor of 2.
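A small sketch of the partially-biased weights, with a numerical check of the two bounds above; the Lipschitz constants are made up and the helper name is illustrative.

```python
import numpy as np

def partially_biased_weights(L):
    """w_i = 1/2 + (1/2) * n L_i / sum_j L_j  (still sums to n, and each w_i >= 1/2)."""
    L = np.asarray(L, dtype=float)
    return 0.5 + 0.5 * len(L) * L / L.sum()

L = np.array([50.0, 2.0, 1.0, 1.0, 1.0])
w = partially_biased_weights(L)
print(w.sum())                          # = n
print((L / w).max(), 2 * L.mean())      # L_max[(1/w_i) f_i] <= 2 * (1/n) sum_i L_i
print((1.0 / w).max(), 2.0)             # sigma^2[(1/w_i) f_i] <= 2 sigma^2[f_i], since 1/w_i <= 2
```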
SGD - convergence rates

Uniform sampling:

(Bach and Moulines, 2011):
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\left(\frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}\right)^2 + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$.

(Needell, Srebro, W., 2013):
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\sup_i L_i}{\mu} + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$.

(Partially) biased sampling:
$$k = 4\log(\varepsilon_0/\varepsilon)\left(\frac{\frac{1}{n}\sum_i L_i}{\mu} + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$.
We have been operating in the setting
1. Each $\nabla f_i$ has Lipschitz constant $L_i$: $\|\nabla f_i(x) - \nabla f_i(y)\|_2 \le L_i\|x - y\|_2$,
2. $F$ has strong convexity parameter $\mu$: $\langle x - y, \nabla F(x) - \nabla F(y)\rangle \ge \mu\|x - y\|_2^2$.

Other settings and weaker assumptions:
- Removing the strong convexity assumption
- Non-smooth $f_i$
Smoothness, but no strong convexity:
• Each $\nabla f_i$ has Lipschitz constant $L_i$: $\|\nabla f_i(x) - \nabla f_i(y)\|_2 \le L_i\|x - y\|_2$.
• (Srebro et al. 2010): Number of iterations of SGD:
$$k = O\left(\frac{(\sup_i L_i)\|x^*\|^2}{\varepsilon}\cdot\frac{F(x^*) + \varepsilon}{\varepsilon}\right).$$
• (Foygel and Srebro 2011): Cannot replace $\sup_i L_i$ with $\frac{1}{n}\sum_i L_i$.

Using weights $w_i = \frac{n L_i}{\sum_j L_j}$,
$$k = O\left(\frac{\bigl(\frac{1}{n}\sum_i L_i\bigr)\|x^*\|^2}{\varepsilon}\cdot\frac{F(x^*) + \varepsilon}{\varepsilon}\right).$$
Even less restrictive, we now assume only that
1. Each $f_i$ has Lipschitz constant $G_i$: $|f_i(x) - f_i(y)| \le G_i\|x - y\|_2$.

Recall, in the smooth (but not strongly convex) setting:
• (Srebro et al. 2010): the number of SGD iterations still depends linearly on $\sup_i L_i$.
• (Foygel and Srebro 2011): this dependence cannot be replaced with $\frac{1}{n}\sum_i L_i$ (using uniform sampling).
• Using weights $w_i = \frac{n L_i}{\sum_j L_j}$, the dependence is replaced by $\frac{1}{n}\sum_i L_i$.
Less restrictive: $|f_i(x) - f_i(y)| \le G_i\|x - y\|_2$.

Here, the SGD convergence rate depends linearly on $\frac{1}{n}\sum_i G_i^2$.

Using weights $w_i = \frac{n G_i}{\sum_j G_j}$, the dependence is reduced to
$$\left(\frac{1}{n}\sum_i G_i\right)^2 \le \frac{1}{n}\sum_i G_i^2,$$
and in fact
$$\frac{1}{n}\sum_i G_i^2 = \left(\frac{1}{n}\sum_i G_i\right)^2 + \mathrm{Var}[G_i].$$

(Zhao, Zhang 2013) also consider importance sampling in this setting.
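A quick numerical confirmation of the identity above, with made-up constants $G_i$:

```python
import numpy as np

# Check (1/n) sum_i G_i^2 = ((1/n) sum_i G_i)^2 + Var[G_i] on arbitrary values.
G = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
lhs = np.mean(G**2)
rhs = np.mean(G)**2 + np.var(G)   # np.var uses the population (1/n) convention
print(lhs, rhs)                   # equal: the gain from reweighting is exactly Var[G_i]
```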
Future work:
• Strategies for sampling blocks of indices at each iteration.
Optimal block size?
• Optimal adaptive sampling strategies given limited
or no information about the Li ?
• Simulations on real data
Thanks!
References
[1] Needell, Srebro, Ward. “Stochastic gradient descent and the randomized
Kaczmarz algorithm.” arXiv preprint arXiv:1310.5715 (2013).