Hello again. Next up in the Data Science and Analytics specialist track here at PyCon Australia 2021, we're welcoming Mat Kelcey, who's a fixture at the MLAI meetup here in Melbourne. During the day he works as a machine learning research engineer at Edge Impulse, a development platform for machine learning at the edge, because not everything has to be centralised on big servers belonging to someone else. Mat has worked across a range of machine learning domains over the last 20 years, including work at ThoughtWorks (who are sponsoring today), Google Brain, Wavii and AWS. Mat blogs at matpalm.com.

Today Mat will be talking to us about JAX, which provides automatic differentiation, allowing extremely high performance machine learning on modern accelerators, all from Python. Mat will be showing us the fundamentals of JAX and an intro to some of the libraries being developed on top. So please, a big Venueless round of applause for Mat Kelcey and High Performance Machine Learning with JAX.

Awesome, thank you so much. Yeah, this is a fun talk. I've been on a bit of a JAX bandwagon for about a year and a half now and it's just getting better and better, so I'm very excited to share this tool. This is about high performance machine learning, so there's quite a bit of focus on the things JAX provides in terms of making things go particularly fast. There are going to be three parts to this talk. There's going to be some fundamentals, where we talk about what building blocks JAX fundamentally gives us. I'm going to talk a little bit about what it provides around multiple hosts, because really, scaling up over many machines is the best way we have of running at
very large scale. And because JAX is quite simple in what it does, it's very fundamental, we'll also touch a little on the higher level libraries that we often use, because JAX by itself may not be enough.

Okay, so we're going to start with: what is JAX? That's a big question, and I'm actually going to start with an overly simplified view of machine learning. Machine learning is a very, very broad field, but for a lot of the work we do it does boil down to these supervised learning type problems, where we have some set of inputs that we're trying to map to some outputs, and we're really interested in these functions; we're trying to learn these functions f. And even though it's not the complete domain of machine learning, a lot of supervised learning uses gradient descent these days; it's a big thing.

So there are a couple of things we want out of a tool like JAX for gradient descent based methods. The first one, fundamentally, is gradients: we need to take these functions and calculate the gradients they describe. You really can think of JAX as the next generation of a very old tool called Autograd (maybe not very old, it's about ten years old now), because some of the original authors of Autograd, which really pioneered a bunch of this symbolic stuff, are actually the authors of JAX as well. So it's fair to say that JAX is the next generation of Autograd. And we need things to go really fast; this is very data intensive, very compute intensive, and JAX is unashamedly just a really great Python interface for another tool called XLA, the accelerated linear algebra compiler, which was developed by Google. We'll talk a little bit more
about this, but it's not unfair to say that JAX is just a front end for XLA, because, to be honest, a lot of the heavy lifting is done by that compiler tech, for which JAX provides a very elegant Python interface.

Okay, so let's talk about these two things. But before we get to those, I just want to say that, if nothing else, the takeaway from this talk is that JAX is a good implementation of the NumPy API. So if I think about some NumPy stuff we might want to do, here's an arbitrary calculation where I've got a dot product between a couple of matrices and I'm adding something. We can actually substitute in, rather than NumPy, an implementation of these things provided by jax.numpy. I've done this quite a bit: I've got some signal processing work I do as part of my job that is all described in SciPy and NumPy, and it's been pretty good at literally porting it to JAX by swapping the NumPy imports to jax.numpy; with a couple of caveats and small changes, I'm able to make my programs run under JAX instead. Why do I care? Well, all the things we're going to talk about in this talk around high performance are suddenly available to me, and they're essentially transparent. So if nothing else, walk away from this talk right now with an experiment you might want to try, where you just literally swap in jax.numpy instead of NumPy.
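To make that concrete, here's a minimal sketch of the kind of swap being described; the arrays and the calculation are just stand-ins, not the exact ones from the slide.

```python
import numpy as np
import jax.numpy as jnp

a = np.random.rand(3, 4)
b = np.random.rand(4, 2)
c = 1.5

# plain NumPy version of an arbitrary calculation
result_np = np.dot(a, b) + c

# the same expression, but every op now goes through JAX (and hence XLA)
result_jnp = jnp.dot(a, b) + c

# one of the small caveats: JAX defaults to float32, so compare loosely
print(np.allclose(result_np, np.asarray(result_jnp), atol=1e-5))  # True
```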
But let's talk in a little more detail about some of the things we get. So, autograd: this is the first thing I talked about, wasn't it? Here's a function, a pretty simple quadratic: 2x squared plus 3x plus 3. If this were the complicated machine learning function I was trying to work with, and I was using gradient descent, I might want to get some gradients for this function. So the first thing I want to introduce from JAX is a little function called grad. What grad gives us is the gradient of a function: grad takes a function and returns another function, and when I call this g of f I get the gradient with respect to the parameter there. So let's see: 2x squared, I bring down the two, that's 4x, plus three; so the gradient at three is 15. This is good.

Now this illustrates a real key point about how JAX fundamentally works. This thing that JAX has given us, grad, takes a function and returns another function with the same signature. We're going to see this time and time again: the main thing that JAX does is these functional transformations. JAX looks at functions and gives us new functions back that are in some way changed, based on what we wanted to do.
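As a minimal sketch of that grad call on the quadratic from the slide:

```python
import jax

def f(x):
    return 2 * x**2 + 3 * x + 3

# grad takes a function and returns another function with the same signature
g = jax.grad(f)

print(f(3.0))  # 30.0
print(g(3.0))  # 15.0, since df/dx = 4x + 3
```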
And how does JAX do this? Well, it's based on a fundamental idea called tracing. So here's another thing that JAX has: these make-jaxpr expressions. What I can do is say: here's my function f again, and I'm going to pass it to JAX to say, make me a new, traced version of this function. When I call that traced function it doesn't evaluate f directly; the return is actually JAX's view of what is happening inside the function. You can see here we've got this symbolic representation of the compute, and we don't need to dive too much into it, but you can see there are some things being multiplied and some stuff being added. So JAX is really about doing these things by tracing.

By tracing this function and having a symbolic view of things, we get a bit of insight into how the gradient is being done. If I do that same tracing again, this time on the gradient, I see a different function, and this is fundamentally how autograd works: we trace these functions, we do this forward pass, this symbolic representation, which is then manipulated to turn into a gradient. Again, we don't need to talk about the specifics here (there are some numbers and complicated ops floating around), but there's a subtle point I want to make, which is interesting, about functions. This make-jaxpr expression takes a function, as does grad, which also takes a function. So you can see we've already got this nice composability: I take my function f, I wrap it with grad, which turns f into its gradient, and I can then pass that to make-jaxpr. JAX is really about these building blocks, building blocks that take functions and give us new functions that we can compose in interesting ways. So this composition of functions is a really big, important thing.
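Here's a small sketch of that tracing and composition, using jax.make_jaxpr on the same quadratic:

```python
import jax

def f(x):
    return 2 * x**2 + 3 * x + 3

# tracing f records a symbolic view (a "jaxpr") of the compute
print(jax.make_jaxpr(f)(3.0))

# the transformations compose: wrap f in grad, then trace the resulting gradient function
print(jax.make_jaxpr(jax.grad(f))(3.0))
```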
Okay, so the second part we talked about was going fast, and the first thing that JAX provides us is some just-in-time compilation. So again, here's our function, 2x squared plus 3x plus 3. One thing that JAX gives us is jit: we can take our f function and just-in-time compile it to get a new version of the function that has the same signature, can be run with the same values, and gives us back the same result. In many ways this is the most anticlimactic slide ever written, because it looks pretty boring, but there's a lot of really interesting stuff that happens under the hood with this jit. Fundamentally, an important part of this is that the computation being done here would now be done on whatever accelerator I have. So even though I've described this compute in terms of just vanilla Python, or using that jax.numpy, if I had a GPU this would be running immediately on my GPU, which is an important part of how the compilation works for jit.

And how does jit work? Well, again, I can look at the jaxpr for what has actually happened when I've called this jitted function. We've got our f like we had before, and it doesn't really matter what this f is (it's quite simple here, but it could be a very complicated hundred-layer neural network or something), and what's subtly different here is that JAX is now saying: I'm going to make a call out to this tool, XLA. So what JAX has done is had a look at this function, run through it, built this trace, and then passed it to this compiler under the hood. That compiler is very high performance; it can do all the classic things we want from a compiler, a couple of key things like fusing ops together or reordering things, whatever it wants to do, but also, most importantly, converting things to kernels that will run on the device I have.
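Here's a minimal sketch of that jit call on the same quadratic:

```python
import jax

def f(x):
    return 2 * x**2 + 3 * x + 3

# jit takes a function and returns a compiled version with the same signature
f_jit = jax.jit(f)

print(f_jit(3.0))                  # 30.0, the same result as f(3.0)
print(jax.make_jaxpr(f_jit)(3.0))  # the trace now shows a single call handed off to XLA
```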
So this would natively run on the accelerator: XLA will make sure the compute is then done on my GPU, or whatever I might have (TPUs and stuff we'll talk about later). It looks like such a simple function, but it's doing so much, and it's incredible how much it does. So that's a fundamental thing that we'll be dealing with.

Another really important thing for making things go fast, which we're probably all familiar with, is vectorisation. So here's a function that looks a bit scary, but it's arbitrary; I just made it up, and it does some weird stuff, so don't try to reason too much about what it's actually doing. It's got a couple of matrices, it does some along-axis mins and maxes, it does a dot product, and then it adds a scalar. It doesn't actually mean anything; I just wanted some compute you might want to do, and here's the result we get down the bottom. Now, if we wanted to run this function over lots and lots of data, the common way we would do things is to vectorise, and by vectorising we normally think about this idea of adding a leading dimension to our data, so that when we pass it to the compute, a lot of the linear algebra libraries underneath will work a lot faster, because they can do things in batches. We're probably all familiar with the idea that when we do add these leading dimensions, we just have to make sure that our code is aware of them and can handle them; a rough sketch of that kind of change follows.
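This is an illustration of the manual bookkeeping being described, with a made-up stand-in function rather than the one on the slide:

```python
import numpy as np

def f(a, b):
    # unbatched: a is (3, 4), b is (4,)
    return np.min(a, axis=0) + np.dot(a, b).max()

def f_batched(a, b):
    # hand-vectorised: a is now (batch, 3, 4) and b is (batch, 4),
    # so every reduction axis has to shift by one and the dot product
    # becomes a batched matmul, with broadcasting handled explicitly
    return np.min(a, axis=1) + np.einsum("bij,bj->bi", a, b).max(axis=1, keepdims=True)

a = np.random.rand(5, 3, 4)   # batch of 5
b = np.random.rand(5, 4)
print(f_batched(a, b).shape)  # (5, 4)
```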
So if I was to vectorise this function myself, I would go through and make sure a couple of things hold. I can see these axes here: now that I've got another dimension, a batch dimension, I have to make sure these axes move around. I've got a dot product, and I know from experience that the dot product is okay with having leading dimensions, so I don't need to change anything there. Then I've got an addition of a scalar, so I might be a little bit worried about what the broadcasting is going to be, whether I'm going to make a mistake somewhere. But I think if you've done a bunch of stuff with NumPy, or any other equivalent compute, you get a sense of how to modify this function to make it batch aware.

But we don't need to do this any more. One of the things JAX provides is this awesome tool called vmap. And what is vmap? Again, it's a transformation of functions: I take my function f, I run it through vmap, and JAX returns a new version of my function. What it does, by using this tracing idea, is re-plumb the compute expressed by f to automatically handle the leading batch dimension. So now, if I have this vectorised f, I can call it with an a, b and c which have got this leading dimension: I've added an extra two here, basically, and I've turned the scalar into a short vector. Now, for that particular f we were just looking at, it's not terribly exciting, but as things get more and more complicated, this idea that JAX will be responsible for the plumbing of the leading dimension is a really big thing, and it can make a huge difference.
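Against the same stand-in function from the sketch above, vmap does that re-plumbing for us:

```python
import jax
import jax.numpy as jnp

def f(a, b):
    # written for a single example: a is (3, 4), b is (4,)
    return jnp.min(a, axis=0) + jnp.dot(a, b).max()

# vmap takes f and returns a new function that handles a leading batch dimension
vectorised_f = jax.vmap(f)

a = jnp.ones((5, 3, 4))          # batch of 5
b = jnp.ones((5, 4))
print(vectorised_f(a, b).shape)  # (5, 4)
```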
Here's an example where I used this, which I thought was really interesting. I had some data, some satellite imaging, and I wanted to produce a bunch of augmentations of each image. With satellite imaging, one of the classic augmentations we can do is sets of 90 degree rotations, and then potentially left-right flips. So I've made this little augment function: it rotates by some number of 90 degree turns, and then it either flips or doesn't flip, like a zero or one. Here's an example: I pass an image in, I've said rotate one lot of 90 degrees and don't flip, and basically I get this image down here.

I can use vmap to express the compute to do all of these in parallel. So what I'm going to do is say: here's my augment function, and what does vmap do? It adds an extra leading dimension to all the arguments. So I'm going to say, here are the rotations I want to do, 0, 1, 2, 3, 0, 1, 2, 3, and I want to zip those together with all the possible combos of flips, so these eight pairings here between the two correspond to all the possible augmentations. Then I'm going to ask JAX to make a version of augment that handles batching, but one little thing I've said here is that I only want JAX to add that leading dimension to the second and third arguments, the rotation and the flip; I don't want it to plumb that extra dimension through the first argument, so I've said None there, and what JAX will do is basically broadcast that image multiple times. Now when I call this function I get this result, because what's happening under the hood is that JAX is zipping across these, and it's going all the way down through these functions to make them vectorised, which I think is really cool. So now when we call this all-combos augment, rather than just getting one image back, we get eight images back.
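Here's a minimal sketch of that in_axes pattern. The real augment function isn't shown here, so this uses a toy square-image version; the lax.switch is just one way to make the rotation count traceable under vmap.

```python
import jax
import jax.numpy as jnp

def augment(image, num_rotations, flip):
    # rotate a square image by num_rotations * 90 degrees, then optionally flip left-right
    rotated = jax.lax.switch(
        num_rotations,
        [lambda x: x,
         lambda x: jnp.rot90(x, k=1),
         lambda x: jnp.rot90(x, k=2),
         lambda x: jnp.rot90(x, k=3)],
        image)
    return jnp.where(flip == 1, jnp.fliplr(rotated), rotated)

# vmap over the rotation and flip arguments only; broadcast the single image (in_axes=None)
all_combos_augment = jax.vmap(augment, in_axes=(None, 0, 0))

image = jnp.ones((32, 32))                   # one square image
rots  = jnp.array([0, 1, 2, 3, 0, 1, 2, 3])  # the eight rotation/flip pairings
flips = jnp.array([0, 0, 0, 0, 1, 1, 1, 1])
print(all_combos_augment(image, rots, flips).shape)  # (8, 32, 32)
```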
Now, I might actually also want to do this whole thing in batch as well: I want to start with a set of images, not just one image; I might have ten images, and I want all eight combos for each of those ten. So I've got my all-combos augment, which was vmapped, and I'll just vmap it again, because what does vmap do? It takes an arbitrary function and just adds another batch dimension at the front. And because I know that this is the level of compute I want to deal with, I might ask JAX at this point to also jit: go through and take this compute, which has been expressed by a vmap over a vmap with these fixed values, and do whatever you need to do to compile it and make it really fast. And this is blindingly fast; the code it ends up producing runs very, very quickly. This is a pattern I've seen a number of times now: if I'm vmapping something and extra batch dimensions appear, I don't need to be scared, I just keep wrapping with vmaps and everything is fine, which is really cool. So fundamentally, vmap does this by re-plumbing and completely rewriting what my function is. Very handy.
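Continuing the sketch above (reusing the toy all_combos_augment, rots and flips), the extra batch dimension and the jit look something like this:

```python
# wrap with another vmap for a batch of images, broadcasting the rotation/flip arguments,
# then jit the whole composed thing so XLA compiles it end to end
batched_augment = jax.jit(jax.vmap(all_combos_augment, in_axes=(0, None, None)))

images = jnp.ones((10, 32, 32))                    # ten square images
print(batched_augment(images, rots, flips).shape)  # (10, 8, 32, 32)
```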
A quick side note on TPUs. TPUs are Google's custom-designed ASIC, a sort of replacement for their in-house GPU farms. TPUs come in a couple of different models. For the third generation, the smallest TPU board you can get, if you're renting one from the cloud, is this v3-8, and the 8 corresponds to the fact that a board like this has four chips, each with two cores, so eight cores in total. What's interesting is that when you rent these TPUs, that's the minimum you get; you've got no choice. Each of those eight cores is pretty powerful, but you can't go any lower than that. Google provides these TPUs either on boards like this, or in huge combinations of boards with this crazy interconnect, called a pod, where you have hundreds of TPU boards put together to make effectively 2048 cores.

The reason I bring this up is that the development of the TPUs was done hand in hand with the development of XLA, the compilation technology we talked about. XLA was introduced at the same time, as an intermediate between the code being written and it being run on either GPUs or TPUs, so XLA is very much about making high performance code for TPUs, and JAX is in many ways a front end for XLA. So we're going to get a huge benefit from using JAX, very simply, across these very large compute clusters; it's quite astounding what you can do nowadays that even a couple of years ago was quite difficult to orchestrate.

So I'm going to introduce another concept called pmap; this is parallelisation rather than vectorisation. If you go right now to Google Colab and start a notebook, you're allowed to just get a free TPU: if you pick the TPU device and you import JAX (maybe there's some setup you need to do, but it's basically import jax), JAX has a view over those eight cores. So rather than just a single device we've got eight, and JAX provides us some tools for making use of this.

Here's an example of how pmap works. Here's a function, very simple again: a, b, c, we'll just add them up. And we're going to pmap this function, like we did with the vmap: pmap takes a function and gives us a new version of it that can run in parallel, and again the idea is that there's a leading dimension being put on top of things. So if I have some a and b which are eight by three,
and I'm on this TPU machine which has eight cores, then when I run this compute I'm actually running it in parallel, each of these rows being run on its own device. What JAX is doing here is a little bit different to vmap. With vmap we actually re-plumb the code; pmap takes the function, compiles it, and ships it out to the devices we have. So if I have this v3-8, which has eight devices, when I call this pmap I'm actually compiling this code, sharding the input, and sending it out to the devices, and all the devices do things in parallel. It's not until I actually instantiate the result in some way, in this case printing it to my screen, that it is brought back from those devices and assembled, and that's why we see this sharded device array here: until I actually do something and act on it, each of these rows lives in the memory of its own device. This is really important, because if you've done much with multiple machines, you really need to make sure that the data being shipped around, versus the code being shipped around, is under your control, because sometimes you don't want to be wasting time moving memory around when you don't need to. So pmap gives us a lot of control over how we're sharding data versus the actual compute.
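Here's a minimal sketch of that pmap call; it assumes a host where jax.device_count() is eight, such as the free Colab TPU mentioned above (on a single-device machine it still runs, just with a leading dimension of one).

```python
import jax
import jax.numpy as jnp

def f(a, b, c):
    # a very simple function: just add the three inputs
    return a + b + c

n = jax.device_count()     # 8 on a v3-8 board
a = jnp.ones((n, 3))
b = jnp.ones((n, 3))
c = jnp.ones((n, 3))

parallel_f = jax.pmap(f)   # each row of the inputs is run on its own device
out = parallel_f(a, b, c)

print(type(out))           # a sharded array: each row still lives on its own device
print(out)                 # only materialised on the host when we actually use it
```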
And pmap gives us more than that: it also gives us a bunch of collective operators. I'll go through this one pretty quickly; there are a couple of details here we don't care too much about. If I'm doing some machine learning over a large number of machines, again this huge cluster of 2000 machines, I might have some sort of loss function where I'm passing in parameters and I've got my x and y. Reintroducing some of the concepts we've talked about, I might want to calculate gradients for these parameters with respect to this input and output. How would I do that? Well, I can take my loss function, I can ask JAX to give us the gradients, and I can ask JAX to please make that fast in whatever way is required. And at this point, if I've pmapped this function and passed a shard of x and y to each device, each device would be looking at a different set of gradients. The thing that pmap is really about is this parallelisation, so this piece would all be running in parallel across the machines, and each device, as long as its x and y were separate, would have gradients for a different chunk of my data.

Now, in a very simple way, JAX provides other operators which work across the mesh, in this case psum. By simply running this psum here, what JAX does is collect the gradients together across all the devices and sum them up. So if I had 2000 machines and each of them had individual gradients, at this point, collecting them together, I would actually sum all those gradients up, either as a psum or a pmean depending on how I wanted to use them. Behind the scenes JAX is doing all the work to have all these machines interchange, and to sum and put things together, and we just get it, magically. This is huge, because trying to do this stuff yourself is a complete pain and requires so much code, but JAX makes it absolutely trivial with these types of operators, and it just scales from the TPU v3-8 up to these TPU pod slices with literally thousands of cores, without any code change. It's quite astounding.
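Here's a rough sketch of that pattern. The loss and parameters are stand-ins, and the axis_name wiring is the detail glossed over above; it's how the pmap and the psum/pmean are tied together.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # a stand-in linear model with a squared error loss
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def grads_fn(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # collect the per-device gradients across the whole mesh (psum would sum instead)
    return jax.lax.pmean(grads, axis_name="devices")

# pmap compiles grads_fn and runs it on every device, each on its own shard of x and y
p_grads_fn = jax.pmap(grads_fn, axis_name="devices", in_axes=(None, 0, 0))

n = jax.device_count()
params = {"w": jnp.zeros((4,)), "b": jnp.zeros(())}
x = jnp.ones((n, 32, 4))          # a shard of 32 examples per device
y = jnp.ones((n, 32))
grads = p_grads_fn(params, x, y)  # already-averaged gradients, identical on every device
```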
And you know, pmap and vmap sometimes feel like they're quite specific, and JAX is very much a moving target. I love this: the next version of pmap, potentially, is called xmap, and its docstring is literally aspirational; JAX really is bleeding edge in some of this stuff. But they're providing more and more ways to make use of this crazy compute, and to give hints about how you want to make use of these meshes. It's quite astounding.

Okay, so everything we've talked about here is quite low level: vectorisation and parallelisation. It's fun stuff to do, but it's nothing to do with how we usually do things. A lot of people ask me, do you like JAX or Keras better, and it's a bit of a mismatch; you can't really make that comparison. For me, Keras is two fundamental things. It's a way of describing model definitions, a really rich way of assembling what they call layers into models, plus the optimizers that go with those layers and models; that's a really handy thing Keras has. And the other thing Keras is very good at is having an opinionated fit loop: this idea that, given some data, an optimizer and a model, it just runs in batches over it and does things, with control from callbacks. So Keras really has those two things, and none of those has anything to do with JAX. JAX couldn't care less about fit loops; it couldn't care less about optimizers, although it does have some reference optimizers in the experimental branches and modules. Fundamentally, JAX is about making use of XLA to do this great compute.
So if you provide an awesome library for doing this sort of stuff, but you don't have high level libraries on top, what happens? You get the Cambrian explosion of libraries, and this is something we see time and time again. If you want to do anything with JAX, you probably don't want to do it directly, but with one of the fifty million different libraries that come out all the time. They have different flavours and different separations of concerns: RLax is interesting for reinforcement learning, Trax is interested in transformers, JAX MD is about using these meshes to do complicated molecular dynamics. So it's hard to decide which library you might want to use on top; the only thing you can really do is explore and experiment in a couple of different ways.

The one that I have settled on myself is the combination of Haiku and Optax, and I just want to give you a quick flavour of those. So, Haiku. As I mentioned with Keras, one of the things Keras gives you is a way of assembling layers into models; that's a really big thing, and Haiku gives us a bunch of that stuff. Here's the simplest model I can basically make, y equals mx plus b: it's a linear layer that Haiku provides, and some function which describes that model. But the function I've got here could be all sorts of things, multiple convolutions, all sorts of weird self-attention stuff; Haiku provides a whole bunch of what we think of as those layers. And what Haiku is doing is taking these functions and, in the same way JAX does, transforming them into other things that we
can use. In particular, the main one around the model definition is that Haiku will change this function into a model object which has two important things for us. One is initialisation: given this model, and through tracing, what does this model do and what parameters does it use? So the first thing we want to do is initialise the model to get a sense of what parameters are described inside. And then, once we have those parameters, we can apply them, along with some input, to get an output. You can certainly see this is very much the functional way of doing things, isn't it? Usually, when we think about an object-oriented approach, we would bake these parameters into the model, but Haiku continues with that functional approach of saying, well, everything is with respect to parameters every time we make these calls. So if you're more of a functional sort of person and you want to use the JAX stuff more directly, which is fundamentally functional, then Haiku is maybe better; whereas if you want to hide things more and you're not as interested in that, there might be another library that suits you better.
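Here's a minimal sketch of that Haiku pattern, a single linear layer wrapped with hk.transform; the exact model on the slide is a stand-in, but init and apply work as described.

```python
import jax
import jax.numpy as jnp
import haiku as hk

def model_fn(x):
    # the simplest possible model: y = mx + b, via Haiku's Linear layer
    return hk.Linear(1)(x)

# hk.transform turns the function into an object with init and apply
model = hk.transform(model_fn)

x = jnp.ones((8, 3))
rng = jax.random.PRNGKey(0)

# init traces the function to discover the parameters it uses
params = model.init(rng, x)

# apply is purely functional: the parameters are passed in explicitly on every call
y = model.apply(params, rng, x)
print(y.shape)  # (8, 1)
```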
to look after this state and to 804 00:26:01,279 --> 00:26:04,880 do things with it 805 00:26:02,640 --> 00:26:07,120 and in particular here's the classic 806 00:26:04,880 --> 00:26:08,720 update rule we might do 807 00:26:07,120 --> 00:26:11,600 this is i'll go through this slide a bit 808 00:26:08,720 --> 00:26:13,679 more carefully so i've got some sort of 809 00:26:11,600 --> 00:26:15,520 loss function that describes 810 00:26:13,679 --> 00:26:18,080 given a set of x and y and a set of 811 00:26:15,520 --> 00:26:19,600 parameters you know how well am i doing 812 00:26:18,080 --> 00:26:22,880 and the way i would calculate that loss 813 00:26:19,600 --> 00:26:24,880 is to apply my model with the parameters 814 00:26:22,880 --> 00:26:27,279 in x get predicted values and then 815 00:26:24,880 --> 00:26:30,799 calculate some sort of loss 816 00:26:27,279 --> 00:26:32,159 then what i can do is uh use jacks to 817 00:26:30,799 --> 00:26:34,080 get gradients 818 00:26:32,159 --> 00:26:35,600 of that parameter of that function with 819 00:26:34,080 --> 00:26:37,760 respect to parameters which give me 820 00:26:35,600 --> 00:26:40,000 gradients and then what i can do is pass 821 00:26:37,760 --> 00:26:42,000 the gradients to the optimizer 822 00:26:40,000 --> 00:26:43,520 to say basically 823 00:26:42,000 --> 00:26:46,320 given these gradients and your current 824 00:26:43,520 --> 00:26:47,520 state what's your new state and are 825 00:26:46,320 --> 00:26:49,679 there any updates to make to the 826 00:26:47,520 --> 00:26:51,200 parameters and then again given the 827 00:26:49,679 --> 00:26:52,960 parameters and those updates what are 828 00:26:51,200 --> 00:26:54,320 the new set of parameters 829 00:26:52,960 --> 00:26:56,080 so this is a classic sort of functional 830 00:26:54,320 --> 00:26:58,400 approach again where we're passing in 831 00:26:56,080 --> 00:27:00,559 states and getting new states back so 832 00:26:58,400 --> 00:27:02,720 all these things are 100 immutable and 833 00:27:00,559 --> 00:27:04,640 we just sort of have to do this classic 834 00:27:02,720 --> 00:27:06,400 thing of here's a set of parameters and 835 00:27:04,640 --> 00:27:07,600 through this update i get a new set of 836 00:27:06,400 --> 00:27:09,760 parameters 837 00:27:07,600 --> 00:27:11,600 and um we're using all these jax 838 00:27:09,760 --> 00:27:13,679 fundamentals like grad or you know 839 00:27:11,600 --> 00:27:15,279 jitting this entire thing but all these 840 00:27:13,679 --> 00:27:16,799 things that i'm putting in here 841 00:27:15,279 --> 00:27:18,480 depending on how i want to do things can 842 00:27:16,799 --> 00:27:20,720 have a sprinkling of the v map and the p 843 00:27:18,480 --> 00:27:22,240 map depending on how i want to run 844 00:27:20,720 --> 00:27:23,360 across different machines so you 845 00:27:22,240 --> 00:27:24,799 remember in that slide a little while 846 00:27:23,360 --> 00:27:25,679 ago we were talking about this update 847 00:27:24,799 --> 00:27:27,679 being 848 00:27:25,679 --> 00:27:29,919 done with a pmap instead of just a 849 00:27:27,679 --> 00:27:31,200 directly with an update and i think this 850 00:27:29,919 --> 00:27:33,039 idea of being able to introduce these 851 00:27:31,200 --> 00:27:35,360 things as we go along is super powerful 852 00:27:33,039 --> 00:27:36,880 because what it means is that we can um 853 00:27:35,360 --> 00:27:39,360 sort of get things working in a simple 854 00:27:36,880 --> 00:27:42,000 way on a sort of you know single machine 855 00:27:39,360 --> 
856 00:27:42,000 --> 00:27:45,440 and then eventually we can scale 857 00:27:43,600 --> 00:27:47,760 things out more and more. 858 00:27:45,440 --> 00:27:49,039 this idea of jitting an update rule and 859 00:27:47,760 --> 00:27:50,960 then basically going through and 860 00:27:49,039 --> 00:27:52,960 just pumping that update 861 00:27:50,960 --> 00:27:54,720 across the data set, where we pass in 862 00:27:52,960 --> 00:27:56,080 parameters, we get the next parameters, we 863 00:27:54,720 --> 00:27:58,000 pass in the state, we get the next state, 864 00:27:56,080 --> 00:27:59,520 and just run that, 865 00:27:58,000 --> 00:28:01,440 is very, very fast, because if these are 866 00:27:59,520 --> 00:28:04,559 sharded across 867 00:28:01,440 --> 00:28:05,600 the devices, then 868 00:28:04,559 --> 00:28:07,840 the only thing we're having to 869 00:28:05,600 --> 00:28:09,360 do is really move the data around, so 870 00:28:07,840 --> 00:28:12,159 this can be very high performance 871 00:28:09,360 --> 00:28:14,159 at a very high scale. 872 00:28:12,159 --> 00:28:15,520 so, just a recap of the 873 00:28:14,159 --> 00:28:17,200 things we talked about. i talked 874 00:28:15,520 --> 00:28:19,039 about the jax fundamentals, that was 875 00:28:17,200 --> 00:28:20,559 the main thing, and again the takeaway: 876 00:28:19,039 --> 00:28:22,240 just try jax's numpy if nothing else 877 00:28:20,559 --> 00:28:23,760 and see how you go. it might just work 878 00:28:22,240 --> 00:28:25,600 and surprise you. but 879 00:28:23,760 --> 00:28:28,000 fundamentally there are these 880 00:28:25,600 --> 00:28:29,039 building blocks that jax provides 881 00:28:28,000 --> 00:28:30,559 around 882 00:28:29,039 --> 00:28:32,399 the classic sort of composition: 883 00:28:30,559 --> 00:28:33,520 gradients, i want to jit, i want to vmap 884 00:28:32,399 --> 00:28:35,360 things. 885 00:28:33,520 --> 00:28:37,039 there are some multi-host specific 886 00:28:35,360 --> 00:28:39,600 things that come from the 887 00:28:37,039 --> 00:28:41,760 heritage of why jax even came out, around 888 00:28:39,600 --> 00:28:43,840 xla, and so there are things like pmap and 889 00:28:41,760 --> 00:28:45,279 psum. and i'm scratching 890 00:28:43,840 --> 00:28:46,960 the surface here in terms of that mesh 891 00:28:45,279 --> 00:28:48,880 computation; there are many, many 892 00:28:46,960 --> 00:28:51,279 complicated things we can do there. 893 00:28:48,880 --> 00:28:53,279 i've just dropped into, for a 894 00:28:51,279 --> 00:28:56,399 quick talk, this idea of parallelization 895 00:28:53,279 --> 00:28:58,159 with pmap, and this operator psum. 896 00:28:56,399 --> 00:28:59,440 and then maybe jax by itself isn't 897 00:28:58,159 --> 00:29:01,679 exactly what you want, you want 898 00:28:59,440 --> 00:29:03,440 higher-level support, and all i can say is 899 00:29:01,679 --> 00:29:05,120 just query what the latest neural 900 00:29:03,440 --> 00:29:07,200 network library for jax is, because it 901 00:29:05,120 --> 00:29:09,039 will have changed by the time this talk 902 00:29:07,200 --> 00:29:11,200 is finished. it's a very vibrant field 903 00:29:09,039 --> 00:29:13,279 with a lot of things changing, and 904 00:29:11,200 --> 00:29:15,840 i think it's okay, in 905 00:29:13,279 --> 00:29:16,880 fields moving this fast, to try 906 00:29:15,840 --> 00:29:19,039 a couple of different ones.
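Pulling the earlier illustrative pieces together, here is a minimal sketch of jitting the update rule and pumping it across a data set; the batches iterable is a hypothetical stand-in for whatever data pipeline you have:

@jax.jit
def update(params, opt_state, x, y):
    # loss and gradients in one pass, then the optax update step
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# pump the update across the data set: parameters in, next parameters out
for x_batch, y_batch in batches:      # 'batches' is a hypothetical iterable
    params, opt_state, loss = update(params, opt_state, x_batch, y_batch)

One common way to shard the same update across devices, as described earlier in the talk, is to decorate it with jax.pmap instead of jax.jit and combine gradients inside it with jax.lax.psum.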
907 00:29:16,880 --> 00:29:21,520 just see which one represents the 908 00:29:19,039 --> 00:29:22,960 way that you think about the problems, 909 00:29:21,520 --> 00:29:25,039 and you might get lucky on the first one, 910 00:29:22,960 --> 00:29:26,799 or you might have to change around a bit. 911 00:29:25,039 --> 00:29:28,960 okay, so that was basically 912 00:29:26,799 --> 00:29:30,960 my introduction. 913 00:29:28,960 --> 00:29:32,880 there's a two-hour version of 914 00:29:30,960 --> 00:29:34,640 this talk with a bit more tutorial stuff 915 00:29:32,880 --> 00:29:37,120 on my website, but otherwise here's 916 00:29:34,640 --> 00:29:38,240 a bunch of links which correspond to 917 00:29:37,120 --> 00:29:39,279 the sort of things i've talked about 918 00:29:38,240 --> 00:29:41,840 so far. 919 00:29:39,279 --> 00:29:41,840 thank you 920 00:29:46,080 --> 00:29:49,840 thank you, matt. 921 00:29:48,960 --> 00:29:51,919 so 922 00:29:49,840 --> 00:29:53,520 we have time for one question, and 923 00:29:51,919 --> 00:29:55,120 it's going to be 924 00:29:53,520 --> 00:29:57,760 a quick one; maybe you have a quick 925 00:29:55,120 --> 00:30:01,120 answer. you've mentioned tpus, but most 926 00:29:57,760 --> 00:30:03,520 people have gpus and no tpus. what kind 927 00:30:01,120 --> 00:30:05,360 of problems are better trained on tpus, 928 00:30:03,520 --> 00:30:08,080 like they would require you to go the 929 00:30:05,360 --> 00:30:09,279 tpu way, and would your jax code change 930 00:30:08,080 --> 00:30:11,039 for that? 931 00:30:09,279 --> 00:30:13,760 yeah, so the jax code would not change at 932 00:30:11,039 --> 00:30:15,039 all. they've been very clear that 933 00:30:13,760 --> 00:30:16,480 xla 934 00:30:15,039 --> 00:30:18,480 supports 935 00:30:16,480 --> 00:30:20,799 gpus as much as tpus, and so all the 936 00:30:18,480 --> 00:30:22,880 things i talked about, pmap, psum, they 937 00:30:20,799 --> 00:30:25,600 work on my desktop here with a couple 938 00:30:22,880 --> 00:30:27,200 of gpus without any change at all. 939 00:30:25,600 --> 00:30:29,679 the tpus are, 940 00:30:27,200 --> 00:30:31,600 yeah, the tpus i think are more suited 941 00:30:29,679 --> 00:30:33,440 when, in terms of the benchmarks i've seen, 942 00:30:31,600 --> 00:30:34,480 the tpus give the best performance when 943 00:30:33,440 --> 00:30:37,440 you're really starting to get into the 944 00:30:34,480 --> 00:30:39,440 large-data regime. so 945 00:30:37,440 --> 00:30:41,120 generally the way they work is that, like 946 00:30:39,440 --> 00:30:42,799 a lot of high-performance 947 00:30:41,120 --> 00:30:44,640 accelerators, they work well once they've 948 00:30:42,799 --> 00:30:46,320 got a real chance to spin up, so if 949 00:30:44,640 --> 00:30:48,320 you're working on a small bit of data 950 00:30:46,320 --> 00:30:50,480 you might find that 951 00:30:48,320 --> 00:30:52,720 jax on tpus isn't actually as 952 00:30:50,480 --> 00:30:53,840 performant as some other things, and 953 00:30:52,720 --> 00:30:55,440 it's not until you get to these bigger 954 00:30:53,840 --> 00:30:56,960 data sets that 955 00:30:55,440 --> 00:30:59,919 it really makes a difference; 956 00:30:56,960 --> 00:31:01,039 particularly batch size is a big one. 957 00:30:59,919 --> 00:31:01,840 thank you. 958 00:31:01,039 --> 00:31:04,480 so, 959 00:31:01,840 --> 00:31:06,240 thank you so much, matt, for being part 960 00:31:04,480 --> 00:31:08,559 of the data science and analytics track today.
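A quick, illustrative way to check which backend and devices that same unchanged code will run on; the comments show examples of what you might get on a desktop with a couple of gpus, not guaranteed output:

import jax

print(jax.default_backend())     # 'cpu', 'gpu' or 'tpu', whichever jax was installed for
print(jax.devices())             # e.g. a list of two gpu devices on a two-gpu desktop
print(jax.local_device_count())  # how many local devices pmap can shard work across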
961 00:31:06,240 --> 00:31:12,000 yeah, no worries, thank you 962 00:31:08,559 --> 00:31:12,880 and coming up at 11 we'll have jyotika 963 00:31:12,000 --> 00:31:14,640 singh, 964 00:31:12,880 --> 00:31:16,720 who will be talking to us about 965 00:31:14,640 --> 00:31:18,399 classifying audio into types using 966 00:31:16,720 --> 00:31:21,399 python. see you then, 967 00:31:18,399 --> 00:31:21,399 bye