{
"cells": [
{
"cell_type": "markdown",
"id": "578f9d15-3df9-461f-befc-ab77c381a96d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"(Intro_PyTensor)=\n",
"\n",
"# What is PyTensor?\n",
":::{post} August 16, 2025 \n",
":tags: introduction, worked examples, tutorial\n",
":category: beginner, explanation \n",
":author: Jesse Grabowski, Ricardo Vieira\n",
":::\n",
"\n",
"A library to define, manipulate, and compile computational graphs.\n",
"\n",
"\n",
"## Let's break it apart\n",
"A library to (1.) define, (2.) manipulate, and (3.) compile (0.) computational graphs."
]
},
{
"cell_type": "markdown",
"id": "0427f803",
"metadata": {},
"source": [
"## Prepare Notebook"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "701a0893",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"plt.style.use(\"seaborn-v0_8\")\n",
"\n",
"%config InlineBackend.figure_format = \"retina\""
]
},
{
"cell_type": "markdown",
"id": "bdf5f119-c085-48a1-b6b0-733ed4921303",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## (0.) Computational graph\n",
"\n",
"Any program implies a computational graph. In PyTensor we're mostly focusing on static, array-based (i.e, numpy) programs with some branching and looping primitives. PyTensor is hackable, and can be easily extended to represent arbitrary types and operations. That said, its usefulness quickly vanishes as you venture out of its area of focus.\n",
"\n",
"Everyone here is likely familar with the idea of a computation graph, but let's look at a quick example anyway. Consider a program that computes $z = x + y$. Here we have:\n",
"\n",
"- Two symbolic inputs, $x$ and $y$\n",
"- An operator, $+$, that takes several inputs and maps them to a single output\n",
"- A symbolic output, $z$\n",
"\n",
"Visualized:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f1890c77",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"import graphviz as gr\n",
"\n",
"\n",
"def draw_graph(edge_list, node_props=None, edge_props=None, graph_direction=\"UD\"):\n",
" \"\"\"Utility to draw a causal (directed) graph\"\"\"\n",
" g = gr.Digraph(\n",
" graph_attr={\n",
" \"rankdir\": graph_direction,\n",
" \"ratio\": \"0.3\",\n",
" \"overlap\": \"vpsc\",\n",
" \"splines\": \"true\",\n",
" \"mode\": \"sgd\",\n",
" \"lheight\": \"4\",\n",
" },\n",
" engine=\"dot\",\n",
" )\n",
"\n",
" edge_props = {} if edge_props is None else edge_props\n",
" for e in edge_list:\n",
" props = edge_props[e] if e in edge_props else {}\n",
" g.edge(e[0], e[1], **props)\n",
"\n",
" if node_props is not None:\n",
" for name, props in node_props.items():\n",
" g.node(name=name, **props)\n",
" return g\n"
]
},
{
"cell_type": "markdown",
"id": "870d7457",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"While this DAG is boring and obvious, it highlights that there are two types of nodes:\n",
"\n",
"- Variables\n",
"- Operations\n",
"\n",
"We can also distinguish between \"root variables\", like $x$ and $y$, intermediate variables (we don't have any here :( ) and output variables ($z$). These differences will return in a few slides!"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e0396698",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"draw_graph([(\"x\", \"+\"), (\"y\", \"+\"), (\"+\", \"z\")])"
]
},
{
"cell_type": "markdown",
"id": "97944cd3",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## (1.) Definition \n",
"In PyTensor, you define a computational graph explicitly, starting with _placeholder_ input variables. From these inputs you build more intermediate variables by applying operators (like $+$), which can then be treated either as outputs, or as intermediate variables for further computation.\n",
"\n",
"To reduce the learning barrier, Pytensor was designed to look like numpy code. But be aware that it's not! Let' look at some differences."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ed6d4a86d7398eca",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([0. , 0.69314718, 1.31326169])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"\n",
"# Numpy\n",
"x = np.array([0, 1, np.e]) # Actual numbers\n",
"y = np.log(1 + x) # Actual computation\n",
"y # Actual result"
]
},
{
"cell_type": "markdown",
"id": "559a56aa",
"metadata": {},
"source": [
"Now, let's see how we can use PyTensor to define a computational graph for the same operation."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f10206d3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Log.0, pytensor.tensor.variable.TensorVariable)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pytensor\n",
"import pytensor.tensor as pt\n",
"\n",
"\n",
"# Pytensor\n",
"x = pt.tensor(shape=(3,), dtype=\"float64\") # Symbolic vector\n",
"y = pt.log(1 + x) # Symbolic computation\n",
"y, type(y)"
]
},
{
"cell_type": "markdown",
"id": "cd6ebfa4",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"The most important thing to grok here is that `y` is not a number! It's a symbol. Specifically, it's the output of a `Log` operation, so we see the name `Log.0`. The type is a `TensorVariable`, which is a basic unit of symbolic computation.\n",
"\n",
"Notice how Pytensor is straddling a line between a computer algebra system like sympy or maple, and a tensor library like numpy."
]
},
{
"cell_type": "markdown",
"id": "904410ca",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"A key tool when working with Pytensor is {func}`~pytensor.dprint`. It's short for \"debug print\", and it shows you a text representation of a graph. It's always a good idea to look at the graph pytensor is generating for your graph, because it shows you exactly what is going on. Admittedly, they take a bit of getting used to reading."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f6f5731649ea29c7",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log [id A]\n",
" └─ Add [id B]\n",
" ├─ ExpandDims{axis=0} [id C]\n",
" │ └─ 1 [id D]\n",
" └─ [id E]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.dprint()"
]
},
{
"cell_type": "markdown",
"id": "e913e973",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"PyTensor can also generate visual representations of your graph. You will be forgiven if you think this is a much better way to view a graph. It's great, but doesn't scale well at all. For large graphs, the output is completely unreadable. If anyone is interested in developing tensorboard-like tools that will allow for interactive investigation of a graph, PRs are accepted!\n",
"\n",
"That said, here's the visual representation of our graph. It's very similar to our $z = x + y$ graph, with a some extra computational accoutrement. "
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "25e22315",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import SVG\n",
"\n",
"from pytensor.printing import pydotprint\n",
"\n",
"\n",
"SVG(pydotprint(y, return_image=True, format=\"svg\"))"
]
},
{
"cell_type": "markdown",
"id": "8bf95ff2-1c44-4c98-9dbc-de76b94c57d2",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"For those curious: This kind of graph is a bi-partite, directed, acyclic graph composed of interconnected Variable -> Apply -> Variable nodes.\n",
"\n",
"Apply nodes connect input variables to output variables, via a specific operator. Variables have a type and can have an owner (the Apply node that creates them) or not (if they are root placeholder variables).\n",
"\n",
"Here is a schematic of the $z = x + y$ graph again, but this time with annotations that show how it is represented in Pytensor. I always felt it was upside-down; perhaps that feeling will help you to understand what you're looking at. Read it from top-to-bottom, but with the flow of computation running against the arrows:\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"id": "6ec992ea",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"We can see these definitions in action by looking at properties of `y`. We've already seen that `type(y)` is a `TensorVariable`. With the exception of root variables, all `TensorVariables` have an `owner`, which is the `Apply` node that created it. \n",
"\n",
"In this case, the owner of `y` is `Log(Add.0)`. `Add.0` is the name of the output of an `Add` `Op`, so `y` is the result of a `Log` applied to an `Add`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2c0a5be5591f936",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(Log(Add.0), pytensor.graph.basic.Apply)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.owner, type(y.owner)"
]
},
{
"cell_type": "markdown",
"id": "3ef053ec",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Looking at the schematic again, note that there is a conceptual difference between `Apply` (which is a big box that handles a bunch of stuff) and an `Op`, which is the actual type of computation being done. \n",
"\n",
"In this case, the specific Op is an `Elemwise`, which is a meta-Op that broadcasts a scalar computation (in this case `scalar_op = log`) across a tensor input. It makes it an elementwise operation... hence \"Elemwise Op\"."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "bc75e500fb9ece2f",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(Elemwise(scalar_op=log,inplace_pattern=),\n",
" pytensor.tensor.elemwise.Elemwise)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.owner.op, type(y.owner.op)"
]
},
{
"cell_type": "markdown",
"id": "5e2e3fac",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"We can also check the outputs of the `Apply`. This is, of course, `y` itself!"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f1bb3439391f2509",
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"text/plain": [
"([Log.0], True)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.owner.outputs, y.owner.outputs == [y]"
]
},
{
"cell_type": "markdown",
"id": "75685087",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Finally, there's the inputs. This is `Add.0`, which represents `1 + x`"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c6f97556431c818f",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[Add.0]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.owner.inputs"
]
},
{
"cell_type": "markdown",
"id": "b57e98dc",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"From here, the story begin again, and we can keep climbing up the graph until the root variable, `x`. I'll spare you, though"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "316c181a13627039",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add [id A]\n",
" ├─ ExpandDims{axis=0} [id B]\n",
" │ └─ 1 [id C]\n",
" └─ [id D]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.owner.inputs[0].dprint() # And the story begins again"
]
},
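{
"cell_type": "markdown",
"id": "9c2e4b7a",
"metadata": {},
"source": [
"The climb can also be automated. The following is a minimal sketch, not a PyTensor utility: `collect_roots` is a hypothetical helper that relies only on the `owner` and `inputs` attributes shown above, and it treats any variable without an owner (placeholder or constant) as a root.\n",
"\n",
"```python\n",
"import pytensor.tensor as pt\n",
"\n",
"\n",
"def collect_roots(var, roots=None):\n",
"    \"\"\"Recursively gather the ownerless variables that `var` depends on.\"\"\"\n",
"    roots = [] if roots is None else roots\n",
"    if var.owner is None:\n",
"        # No Apply node created this variable: a placeholder or a constant\n",
"        if not any(var is r for r in roots):\n",
"            roots.append(var)\n",
"        return roots\n",
"    for parent in var.owner.inputs:\n",
"        collect_roots(parent, roots)\n",
"    return roots\n",
"\n",
"\n",
"x_demo = pt.tensor(\"x_demo\", shape=(3,), dtype=\"float64\")\n",
"y_demo = pt.log(1 + x_demo)\n",
"roots = collect_roots(y_demo)\n",
"```\n",
"\n",
"Here `roots` should hold two variables: the constant `1` and the placeholder `x_demo`."
]
},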
{
"cell_type": "markdown",
"id": "3a916500-4dfe-4ef0-90ac-84f5de9f7ce2",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## (2) Manipulation\n",
"\n",
"So far, nothing we've seen is special. Maybe it's a nice API for writing functions, but it's not at all clear why the program `y = pt.log(1 + x)` is better than the following Python program:\n",
"\n",
"```py\n",
"def log1p(x):\n",
" return np.log(1 + x)\n",
" \n",
"y = log1p(x)\n",
"```\n",
"\n",
"Actually, this is the exact API that JAX offers! Setting aside JAX, we might even take this function and decorate it with `@numba.njit`, potentially getting big speedups. What does pytensor offer us?\n",
"\n",
"The answer is that PyTensor is able to **manipulate** the computation graph. Furthermore, these manipulations are done within Python. JAX, for example, is able to trace the above function to construct a computational graph. But the graph representation is not front-and-center, and it's no obvious how to use it. In PyTensor, you are meant to be working direcly on graphs!\n",
"\n",
"There are three important graph operations to consider:\n",
"\n",
"1. Replacement\n",
"2. Rewriting\n",
"3. Transformation\n",
"\n",
"We will look at each of these in turn."
]
},
{
"cell_type": "markdown",
"id": "6ec7dfd0",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Graph Replacement\n",
"\n",
"The simplest graph operation is a **replace**. We simply take one operation and swap it out for another. Consider the following graph:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "70351014",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add [id A]\n",
" ├─ ExpandDims{axis=0} [id B]\n",
" │ └─ 1 [id C]\n",
" └─ Sin [id D]\n",
" └─ Mul [id E]\n",
" ├─ ExpandDims{axis=0} [id F]\n",
" │ └─ 6.283185307179586 [id G]\n",
" └─ b [id H]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = pt.tensor(\"b\", shape=(None,))\n",
"b = 2 * pt.pi * a\n",
"c = pt.sin(b)\n",
"d = 1 + c\n",
"d.dprint()"
]
},
{
"cell_type": "markdown",
"id": "87f8371c",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"If, for some reason, we wanted to do replace `pt.sin(b)` with `pt.cos(b)`, we could do so with the `graph_replace` function. The subgraph we want to target is `c`, and the top-level output is `d`, so we do it as follows:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a20df8d6",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add [id A]\n",
" ├─ ExpandDims{axis=0} [id B]\n",
" │ └─ 1 [id C]\n",
" └─ Cos [id D]\n",
" └─ Mul [id E]\n",
" ├─ ExpandDims{axis=0} [id F]\n",
" │ └─ 6.283185307179586 [id G]\n",
" └─ b [id H]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pytensor.graph.replace import graph_replace\n",
"\n",
"\n",
"z2 = graph_replace(d, {c: pt.cos(b)})\n",
"z2.dprint()"
]
},
{
"cell_type": "markdown",
"id": "7ceffb26",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"This might seem like a silly example, but one-to-one transformations are quite common! For example, suppose we had a graph with a random variable, and we wanted to replace the random variable with a root variable, transforming it into a function. We can do that with graph replace too:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5856f780",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exp [id A]\n",
" └─ normal_rv{\"(),()->()\"}.1 [id B]\n",
" ├─ RNG() [id C]\n",
" ├─ [10] [id D]\n",
" ├─ ExpandDims{axis=0} [id E]\n",
" │ └─ 0 [id F]\n",
" └─ ExpandDims{axis=0} [id G]\n",
" └─ 1 [id H]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pt.random.normal(loc=0, scale=1, size=(10,))\n",
"y = pt.exp(x)\n",
"y.dprint()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0271d4bd",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exp [id A]\n",
" └─ x_input [id B]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_input = pt.tensor(\"x_input\", shape=(10,))\n",
"y2 = graph_replace(y, {x: x_input})\n",
"y2.dprint()"
]
},
{
"cell_type": "markdown",
"id": "08ccb932",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Vectorization\n",
"\n",
"A less trivial type of graph replacement is *vectorization*. This is also a one-to-one replacement, but this time things are somewhat more complicated, because it changes the meaning of the graph. Pytensor will automatically reason about the broadcasting operations that need to happen to accomplish the replacement.\n",
"\n",
"This time, I make a graph where all the inputs are specifically declared to be scalars. Notice how this information flows down the graph: since `x` is scalar, and the `Op`s `log` and `add` are elemwise, the output `y` is also a scalar."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "05efd44a",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log [id A] \n",
" └─ Add [id B] \n",
" ├─ 1 [id C] \n",
" └─ x [id D] \n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pt.dscalar(\"x\")\n",
"y = pt.log(1 + x)\n",
"y.dprint(print_type=True)"
]
},
{
"cell_type": "markdown",
"id": "bd514d05",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Suppose that instead of a scalar `x`, we want to use this function for a vector input. If we naively try to `graph_replace` here, we will get a shape error"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "2af1c20e",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cannot convert Type Vector(float64, shape=(?,)) (of Variable x_vec) into Type Scalar(float64, shape=()). You can try to manually convert x_vec into a Scalar(float64, shape=()).\n"
]
}
],
"source": [
"x_vec = pt.tensor(\"x_vec\", shape=(None,))\n",
"try:\n",
" graph_replace(y, {x: x_vec})\n",
"except Exception as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"id": "1bb040e9",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Instead, use `vectorize_graph`!"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "429d224f",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log [id A] \n",
" └─ Add [id B] \n",
" ├─ ExpandDims{axis=0} [id C] \n",
" │ └─ 1 [id D] \n",
" └─ x_vec [id E] \n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pytensor.graph.replace import vectorize_graph\n",
"\n",
"\n",
"y_vec = vectorize_graph(y, {x: x_vec})\n",
"y_vec.dprint(print_type=True)"
]
},
{
"cell_type": "markdown",
"id": "539a6d85",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Rewrites\n",
"\n",
"Rewrites are at the heart of pytensor's usefulness. Pytensor maintains large databases of useful graph transformations. These can be applied to acheive a number of goals."
]
},
{
"cell_type": "markdown",
"id": "ff07ef2c",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Scalarization\n",
"\n",
"We can even undo vectorization -- scalarization?\n",
"\n",
"Here's the original graph:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "82e05376",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Subtensor{i} [id A]\n",
" ├─ Log [id B]\n",
" │ └─ Add [id C]\n",
" │ ├─ ExpandDims{axis=0} [id D]\n",
" │ │ └─ 1 [id E]\n",
" │ └─ x_vec [id F]\n",
" └─ 0 [id G]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_vec[0].dprint()"
]
},
{
"cell_type": "markdown",
"id": "3a5bab67",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Notice how the indexing operation has been pushed down to the input!"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "e3c49b4c",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log [id A]\n",
" └─ Add [id B]\n",
" ├─ 1.0 [id C]\n",
" └─ Subtensor{i} [id D]\n",
" ├─ x_vec [id E]\n",
" └─ 0 [id F]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pytensor.graph import rewrite_graph\n",
"\n",
"\n",
"rewrite_graph(y_vec[0]).dprint()"
]
},
{
"cell_type": "markdown",
"id": "6dce4025",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Stabilization\n",
"\n",
"Calling `rewrite_graph` without any arguments will trigger many rewrites to be applied. If you have a more specific objective, you can ask for specific types of rewrites. For example, the `stabilize` tag includes rewrites that can transform your graph into a form that is more numerically stable"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "c1b3d7c7",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log [id A] \n",
" └─ Add [id B] \n",
" ├─ ExpandDims{axis=0} [id C] \n",
" │ └─ 1 [id D] \n",
" └─ x [id E] \n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pt.tensor(\"x\", shape=(None,))\n",
"y = pt.log(1 + x)\n",
"y.dprint(print_type=True)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "91174dfd",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log1p [id A]\n",
" └─ x [id B]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stable_y = rewrite_graph(y, include=(\"stabilize\",))\n",
"stable_y.dprint()"
]
},
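{
"cell_type": "markdown",
"id": "71f3c0ad",
"metadata": {},
"source": [
"Why does swapping `Log(Add(1, x))` for `Log1p(x)` matter? For very small `x`, the sum `1 + x` rounds to exactly `1.0` in double precision, so the naive form loses the answer entirely. A quick illustration in plain NumPy (used here as a stand-in; it says nothing about how PyTensor implements `Log1p` internally):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"tiny = 1e-20\n",
"\n",
"naive = np.log(1 + tiny)  # 1 + tiny rounds to 1.0, so the log is 0.0\n",
"stable = np.log1p(tiny)  # keeps the leading-order term, roughly tiny\n",
"```\n",
"\n",
"The naive value is exactly zero, while the stable one stays close to `tiny`."
]
},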
{
"cell_type": "markdown",
"id": "e65bb491",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"There are many more types of rewrites, and we will see them shortly. But first we need to do a Quinten Tarintino and show things somewhat out of order. "
]
},
{
"cell_type": "markdown",
"id": "290f088c",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Graph-to-graph transformations\n",
"\n",
"The most powerful feature of pytensor is to take a graph and return another graph. \n",
"\n",
"We've already seen one example of this in `vectorization`. We included it above because it \"looks like\" `graph_replace`, in the sense that you target a single node for replacement. But as we saw, `vectorize_graph` returns an entirely new graph, with new shapes.\n",
"\n",
"The canonical example of a graph-to-graph transformation is automatic differentiation. If we know the derivative of every `Op` in our graph, we work backwards and follow the chain rule to construct a gradient graph from the graph of a scalar loss function.\n",
"\n",
"The cryptic Op called `Second` means: keep the second input after broadcasting the shape with the first.\n",
"It's the same as `np.broadcast_arrays(x, y)[1]`"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "c455a4c9",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True_div [id A]\n",
" ├─ Second [id B]\n",
" │ ├─ Log1p [id C]\n",
" │ │ └─ x [id D]\n",
" │ └─ ExpandDims{axis=0} [id E]\n",
" │ └─ Second [id F]\n",
" │ ├─ Sum{axes=None} [id G]\n",
" │ │ └─ Log1p [id C]\n",
" │ │ └─ ···\n",
" │ └─ 1.0 [id H]\n",
" └─ Add [id I]\n",
" ├─ ExpandDims{axis=0} [id J]\n",
" │ └─ 1 [id K]\n",
" └─ x [id D]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pytensor.gradient import grad\n",
"\n",
"\n",
"grad_y = grad(stable_y.sum(), wrt=x)\n",
"grad_y.dprint()"
]
},
{
"cell_type": "markdown",
"id": "5c0484ee",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Quick digression: back to rewrites!\n",
"\n",
"Gradient graphs tend to be complex, so it's nice to simplify them. \n",
"\n",
"One type of simplification is *canonicalization*. It converts a graph into a \"standard\" form. Other rewrites can expect and reason from this form. I think you'll agree this form is much nicer."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "1bef9b01",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True_div [id A]\n",
" ├─ [1.] [id B]\n",
" └─ Add [id C]\n",
" ├─ [1.] [id B]\n",
" └─ x [id D]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rewrite_graph(grad_y, include=(\"canonicalize\",)).dprint()"
]
},
{
"cell_type": "markdown",
"id": "97a7c16b",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"We can also compose rewrites! In this next example, we both canonicalize, then apply specializations to get computational speedups.\n",
"\n",
"This also reveals the \"final form\" of the gradient: $\\log(1 + x) = \\frac{1}{1 + x}$"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0262e4ba",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reciprocal [id A]\n",
" └─ Add [id B]\n",
" ├─ [1.] [id C]\n",
" └─ x [id D]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rewrite_graph(grad_y, include=(\"canonicalize\", \"specialize\")).dprint()"
]
},
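{
"cell_type": "markdown",
"id": "c58d2e19",
"metadata": {},
"source": [
"As a sanity check, the symbolic gradient can be evaluated and compared with the closed form $1 / (1 + x)$. This is a sketch with illustrative names (`x_g`, `loss`), using `eval` for a quick numerical check:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pytensor.tensor as pt\n",
"\n",
"from pytensor.gradient import grad\n",
"\n",
"\n",
"x_g = pt.tensor(\"x_g\", shape=(None,), dtype=\"float64\")\n",
"loss = pt.log(1 + x_g).sum()  # scalar loss, as grad requires\n",
"grad_x = grad(loss, wrt=x_g)\n",
"\n",
"x_val = np.array([0.0, 1.0, 3.0])\n",
"numeric = grad_x.eval({x_g: x_val})\n",
"expected = 1 / (1 + x_val)  # closed-form derivative\n",
"```\n",
"\n",
"The two arrays should agree up to floating-point error."
]
},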
{
"cell_type": "markdown",
"id": "d492543e",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Back to transformations -- logp inference\n",
"\n",
"Another important graph-to-graph transformation is *automatic logp inference*. PyMC knows how to transform a generatve graph (forward draws) into a logp graph (backwards inference). Let's look at a simple example."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "37ed37be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"normal_rv{\"(),()->()\"}.1 [id A] 'z'\n",
" ├─ RNG() [id B]\n",
" ├─ NoneConst{None} [id C]\n",
" ├─ [0 0] [id D]\n",
" └─ [1 2] [id E]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pymc as pm\n",
"\n",
"\n",
"with pm.Model() as model:\n",
" z = pm.Normal(name=\"z\", mu=np.array([0, 0]), sigma=np.array([1, 2]))\n",
"\n",
"pytensor.dprint(z)"
]
},
{
"cell_type": "markdown",
"id": "43b8a6da",
"metadata": {},
"source": [
"We can now use {func}`~pymc.logp` to compute the log-probability of a value."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "8e3194d8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Check{sigma > 0} [id A] 'z_logprob'\n",
" ├─ Sub [id B]\n",
" │ ├─ Sub [id C]\n",
" │ │ ├─ Mul [id D]\n",
" │ │ │ ├─ ExpandDims{axis=0} [id E]\n",
" │ │ │ │ └─ -0.5 [id F]\n",
" │ │ │ └─ Pow [id G]\n",
" │ │ │ ├─ True_div [id H]\n",
" │ │ │ │ ├─ Sub [id I]\n",
" │ │ │ │ │ ├─ z_value [id J]\n",
" │ │ │ │ │ └─ [0 0] [id K]\n",
" │ │ │ │ └─ [1 2] [id L]\n",
" │ │ │ └─ ExpandDims{axis=0} [id M]\n",
" │ │ │ └─ 2 [id N]\n",
" │ │ └─ ExpandDims{axis=0} [id O]\n",
" │ │ └─ Log [id P]\n",
" │ │ └─ Sqrt [id Q]\n",
" │ │ └─ 6.283185307179586 [id R]\n",
" │ └─ Log [id S]\n",
" │ └─ [1 2] [id L]\n",
" └─ All{axes=None} [id T]\n",
" └─ MakeVector{dtype='bool'} [id U]\n",
" └─ All{axes=None} [id V]\n",
" └─ Gt [id W]\n",
" ├─ [1 2] [id L]\n",
" └─ ExpandDims{axis=0} [id X]\n",
" └─ 0 [id Y]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"z_value = pt.vector(name=\"z_value\")\n",
"z_logp = pm.logp(rv=z, value=z_value)\n",
"\n",
"pytensor.dprint(z_logp)"
]
},
{
"cell_type": "markdown",
"id": "942aadb1",
"metadata": {},
"source": [
"Observe that we still get a graph, we do not do any computation yet. To do a evaluation on a given value, we can use the `eval` method:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "1a0f1ff3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Check{sigma > 0}(Sub.0, All{axes=None}.0)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"z_logp.owner"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "a6aff350",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.91893853, -1.61208571])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"z_logp.eval({z_value: [0, 0]})"
]
},
{
"cell_type": "markdown",
"id": "43339908",
"metadata": {},
"source": [
"To \"verify\" the computation, we compute the same quantity using `scipy.stats.norm.logpdf`."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "3453b57f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.91893853, -1.61208571])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import scipy\n",
"\n",
"\n",
"scipy.stats.norm.logpdf(\n",
" x=np.array([0, 0]), loc=np.array([0, 0]), scale=np.array([1, 2])\n",
")"
]
},
{
"cell_type": "markdown",
"id": "65c3537c",
"metadata": {},
"source": [
"The values match as expected."
]
},
{
"cell_type": "markdown",
"id": "1e888b20",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Composability\n",
"\n",
"It is important to note that all of these transformation operations are *composable*. We already saw an example of this when we applied graph rewrites to a gradient graph. We can also vectorize gradients, or take the gradient of a vectorize graph. Or we can replace subgraphs before or after applying rewrites. Because all of these operations take in a graph and return a graph, we're always able to chain together graph operations. \n",
"\n",
"Sometimes, you'll be surprised by what you end up with:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "f44aed85",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log [id A]\n",
" └─ Add [id B]\n",
" ├─ ExpandDims{axis=0} [id C]\n",
" │ └─ 1 [id D]\n",
" └─ x [id E]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pt.tensor(\"x\", shape=(3,))\n",
"y = pt.log(1 + x)\n",
"y.dprint()"
]
},
{
"cell_type": "markdown",
"id": "fa1eb391",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"The graph we've been working with is $y = \\log(1 + x)$. Let's transform it to be $y = \\log(1 + \\exp(x))$, then rewrite the expression for stability.\n",
"\n",
"First replace $x$ by $\\exp(x)$:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "5dc415eb",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log [id A]\n",
" └─ Add [id B]\n",
" ├─ ExpandDims{axis=0} [id C]\n",
" │ └─ 1 [id D]\n",
" └─ Exp [id E]\n",
" └─ x [id F]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_y = graph_replace(y, replace={x: pt.exp(x)})\n",
"new_y.dprint()"
]
},
{
"cell_type": "markdown",
"id": "d418c581",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Then optimize the graph for stability:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "94ca84f9",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Scalar_softplus [id A]\n",
" └─ x [id B]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rewrite_graph(new_y, include=(\"stabilize\",)).dprint()"
]
},
{
"cell_type": "markdown",
"id": "97e2650b",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Did you guess the result would be just a single Op?"
]
},
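{
"cell_type": "markdown",
"id": "c1a7e0b2",
"metadata": {},
"source": [
"To see why this rewrite counts as a *stabilization*, here is a short worked identity. It is a standard property of the softplus function, not something printed by PyTensor, and the exact formula that the `Scalar_softplus` Op evaluates internally may differ:\n",
"\n",
"$$\\log(1 + e^{x}) = \\max(x, 0) + \\log\\left(1 + e^{-|x|}\\right)$$\n",
"\n",
"Evaluated naively, the left-hand side overflows once $e^{x}$ leaves the floating point range (around $x \\approx 710$ in float64), even though the true value is approximately $x$. The right-hand side only ever exponentiates a non-positive number, so it cannot overflow."
]
},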
{
"cell_type": "markdown",
"id": "93ed3570-45f3-4df4-bbfd-4ffa0497d30e",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## (3) Compilation\n",
"\n",
"All this is fun and dandy but only useful if we actually use it compute stuff! \n",
"\n",
"PyTensor provides a critical non-composable graph operation: `function`, which converts a pytensor graph into a callable python object that takes concrete inputs and returns concrete outputs. \n",
"\n",
"By default it runs an extensive database of rewrites to try and optimize the computational graph, and then compiles to C (technically a mix of C and Python if not all operations have a C implementation). See https://pytensor.readthedocs.io/en/latest/extending/pipeline.html for a bit more detail.\n",
"\n",
"As with anything remotely useful in Python, when it comes to work you want to [STAY OUT OF PYTHON](https://www.youtube.com/watch?v=vVUnCXKuNOg) as much as possible."
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "7d74a4d9-2079-4462-81a3-a4c17a271f29",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cos [id A]\n",
" └─ Squeeze{axes=[0, 1]} [id B]\n",
" └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id C]\n",
" ├─ ExpandDims{axis=0} [id D]\n",
" │ └─ Exp [id E]\n",
" │ └─ Sin [id F]\n",
" │ └─ x [id G]\n",
" └─ ExpandDims{axis=1} [id H]\n",
" └─ Exp [id E]\n",
" └─ ···\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pt.vector(\"x\", shape=(None,))\n",
"z = pt.exp(pt.sin(x))\n",
"out = pt.cos((z[None, :] @ z[:, None]).squeeze())\n",
"out.dprint()"
]
},
{
"cell_type": "markdown",
"id": "c9446b3e",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"As noted, use `pytensor.function` to compile a graph into an executable program. You need to pass a **list** of inputs (even if it has no inputs, you still need an empty list!), and outputs (this *can* be a list, but isn't required to be."
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "c98b61df45d6a2bb",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"y_fn = pytensor.function([x], out)"
]
},
{
"cell_type": "markdown",
"id": "5e2727ea",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"What we get back is a `Function` object. This is a wrapper around a call out to a compiled `C` program that now lives in some cache folder somewhere on your computer."
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "5a8f3eca-51bf-4a3d-a656-85df87135c34",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"text/plain": [
"pytensor.compile.function.types.Function"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(y_fn)"
]
},
{
"cell_type": "markdown",
"id": "b8f7e69c",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"And now we can use it like any other python function"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "90f520a5-eec6-42ed-abad-26204cac5868",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array(-0.72535991)"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_fn(np.random.randn(3))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "3badd33e-ddb3-4aca-a27f-49275b1b994b",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array(-0.98428505)"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_fn(np.random.randn(5))"
]
},
{
"cell_type": "markdown",
"id": "c7574ded",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"We can also inspect it to see the final graph that was compiled, after applying rewrites. Notice that `dot` became `CGemv`, so we're using the correct BLAS routine for the inputs provided. It also did some *loop fusion* by compiling a composite inner graph. That is, if we give an array input, rather than looping over it once to compute `exp` of each element, then looping over it again to compute `sin` of each element, we instead loop only once, and compute $\\exp \\circ \\sin$ of each element."
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "e300df411e0aff49",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cos [id A] d={0: [0]} 5\n",
" └─ Squeeze{axis=0} [id B] 4\n",
" └─ CGemv{inplace} [id C] d={0: [0]} 3\n",
" ├─ AllocEmpty{dtype='float64'} [id D] 1\n",
" │ └─ 1 [id E]\n",
" ├─ 1.0 [id F]\n",
" ├─ ExpandDims{axis=0} [id G] 2\n",
" │ └─ Composite{exp(sin(i0))} [id H] 0\n",
" │ └─ x [id I]\n",
" ├─ Composite{exp(sin(i0))} [id H] 0\n",
" │ └─ ···\n",
" └─ 0.0 [id J]\n",
"\n",
"Inner graphs:\n",
"\n",
"Composite{exp(sin(i0))} [id H]\n",
" ← exp [id K] 'o0'\n",
" └─ sin [id L]\n",
" └─ i0 [id M]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_fn.dprint(print_destroy_map=True) # Some memory aliasing optimizations"
]
},
{
"cell_type": "markdown",
"id": "0b4adffa",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"By passing `print_destroy_map=True` to `dprint`, we also get to see where memory buffers are being reused. This shows up in the graph with the line `d={x: [y]}`, where \"x\" is an output, and \"y\" is an input that will be re-used as a buffer. We can see this in two places:\n",
"\n",
"- In the `CGemv` line, we have `d={0: [0]}`. This means that the first output is being allocated to the memory used for the first input. In this case, we save allocation of a length 1 array, which isn't so impressive.\n",
"\n",
"- More interestingly, we see it at the top line, `Cos [id A] d = {0: [0]}`. This means that we are doing the `Cos` directly on the squeezed output of the CGemv operation! This is equivalent to `np.cos(x, out=x)`, which qucikly can become unreadble. "
]
},
{
"cell_type": "markdown",
"id": "5936a640-43ca-4a37-85e5-66feeda4558f",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"PyTensor can also delegate compilation to other libraries in town, namely Numba, JAX, and PyTorch (latter still under active development). \n",
"\n",
"Notice that this is another kind of graph-to-\"graph\" transformation, that we can only do because we have access to the whole static computation. Every individual `Op` knows what it means to become a Numba program, so we can just walk across the graph and generate the appropriate code.\n",
"\n",
"Sometimes, this means that certain rewrites aren't applied! In this case, Numba does it's own BLAS optimizations. So we don't rewrite `dot -> CGemv`, we just leave it as dot."
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "719fdfbe-790d-41c5-9d4a-398697f01926",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cos [id A] d={0: [0]} 5\n",
" └─ Squeeze{axes=[0, 1]} [id B] 4\n",
" └─ dot [id C] 3\n",
" ├─ ExpandDims{axis=0} [id D] 2\n",
" │ └─ Composite{exp(sin(i0))} [id E] 0\n",
" │ └─ x [id F]\n",
" └─ ExpandDims{axis=1} [id G] 1\n",
" └─ Composite{exp(sin(i0))} [id E] 0\n",
" └─ ···\n",
"\n",
"Inner graphs:\n",
"\n",
"Composite{exp(sin(i0))} [id E]\n",
" ← exp [id H] 'o0'\n",
" └─ sin [id I]\n",
" └─ i0 [id J]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_numba_fn = pytensor.function([x], out, mode=\"NUMBA\")\n",
"y_numba_fn.dprint(print_destroy_map=True)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "f9fcb852-49b3-4442-86e8-eccab65f3834",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 98 ms, sys: 8.58 ms, total: 107 ms\n",
"Wall time: 110 ms\n"
]
},
{
"data": {
"text/plain": [
"array(-0.97979355)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"y_numba_fn(\n",
" np.random.randn(3)\n",
") # first time takes long, jit compilation actually happening"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "b007790d-fa14-4ac4-823e-94cf2df15e35",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 136 μs, sys: 71 μs, total: 207 μs\n",
"Wall time: 243 μs\n"
]
},
{
"data": {
"text/plain": [
"array(0.99905205)"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"y_numba_fn(np.random.randn(5))"
]
},
{
"cell_type": "markdown",
"id": "eab1d41c",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## A Full Example: Logistic Regression with Gradient Descent\n",
"\n",
"To show everything in action together, let's look at how we would write a binary classification model, how we can train it using gradient descent, and how we can use the final program we get back.\n",
"\n",
"First, let's set up the symbolic inputs. These will be the input data $X$ and targets $y$, as well as the initial values for alpha and beta. Our model will be:\n",
"\n",
"$$ p = \\sigma^{-1}(\\alpha + X \\beta) $$\n",
"\n",
"Where $\\sigma$ is the logistic function.\n",
"\n",
"And we'll choose parameters $\\alpha$ and $\\beta$ to minimize the binary cross entropy between $p$ and the target labels, which will just be:\n",
"\n",
"$$\\mathcal{L} = -\\frac{1}{N} \\sum_{i=0}^N y_i \\log(p_i) + (1 - y_i) \\log(1 - p_i)$$\n",
"\n",
"If there are any Bayesians in the audience, you will also recognize this as the negative log-likelihood of a Bernoulli GLM"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "e448cd31",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"X_pt = pt.tensor(\"X\", shape=(None, None))\n",
"y_pt = pt.tensor(\"y\", shape=(None,))\n",
"alpha_pt = pt.tensor(\"alpha\", shape=())\n",
"beta_pt = pt.tensor(\"beta\", shape=(None,))\n",
"\n",
"p = pt.sigmoid(alpha_pt + X_pt @ beta_pt)\n",
"p.name = \"p_class_0\"\n",
"\n",
"loss = -(y_pt * pt.log(p) + (1 - y_pt) * pt.log(1 - p)).mean()\n",
"loss.name = \"cross_entropy\""
]
},
{
"cell_type": "markdown",
"id": "eb622fa5",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"To convince you that the text version of dprint is nicer, here's the big graphic plot for our loss function:"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "d52561b6",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"SVG(pydotprint(loss, return_image=True, format=\"svg\"))"
]
},
{
"cell_type": "markdown",
"id": "77b258df",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Comparsed to the text verison. It pays dividends to get used to reading these things!"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "e793a2bc",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Neg [id A] 'cross_entropy'\n",
" └─ True_div [id B] 'mean'\n",
" ├─ Sum{axes=None} [id C]\n",
" │ └─ Add [id D]\n",
" │ ├─ Mul [id E]\n",
" │ │ ├─ y [id F]\n",
" │ │ └─ Log [id G]\n",
" │ │ └─ Sigmoid [id H] 'p_class_0'\n",
" │ │ └─ Add [id I]\n",
" │ │ ├─ ExpandDims{axis=0} [id J]\n",
" │ │ │ └─ alpha [id K]\n",
" │ │ └─ Squeeze{axis=1} [id L]\n",
" │ │ └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id M]\n",
" │ │ ├─ X [id N]\n",
" │ │ └─ ExpandDims{axis=1} [id O]\n",
" │ │ └─ beta [id P]\n",
" │ └─ Mul [id Q]\n",
" │ ├─ Sub [id R]\n",
" │ │ ├─ ExpandDims{axis=0} [id S]\n",
" │ │ │ └─ 1 [id T]\n",
" │ │ └─ y [id F]\n",
" │ └─ Log [id U]\n",
" │ └─ Sub [id V]\n",
" │ ├─ ExpandDims{axis=0} [id W]\n",
" │ │ └─ 1 [id X]\n",
" │ └─ Sigmoid [id H] 'p_class_0'\n",
" │ └─ ···\n",
" └─ Subtensor{i} [id Y]\n",
" ├─ Cast{float64} [id Z]\n",
" │ └─ Shape [id BA]\n",
" │ └─ Add [id D]\n",
" │ └─ ···\n",
" └─ 0 [id BB]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss.dprint()"
]
},
{
"cell_type": "markdown",
"id": "33a6398f",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"To compute an update step, we will use a simple gradient descent algorithm, defined as:\n",
"\n",
"$$\\theta^\\prime = \\theta - \\eta \\nabla \\mathcal{L}(\\theta)$$"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "d24ad49c",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# The learning rate is another root variable we need to provide\n",
"eta_pt = pt.tensor(\"learning_rate\", shape=())\n",
"\n",
"\n",
"# We can compute gradients for a list of varibales too!\n",
"# It's also good to apply some graph simplification before taking gradients. PyMC does this\n",
"# by default, for example.\n",
"d_alpha, d_beta = grad(\n",
" rewrite_graph(loss, include=(\"canonicalize\", \"stabilize\")), wrt=[alpha_pt, beta_pt]\n",
")\n",
"\n",
"# Apply gradient updates\n",
"alpha_prime = alpha_pt - eta_pt * d_alpha\n",
"beta_prime = beta_pt - eta_pt * d_beta"
]
},
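{
"cell_type": "markdown",
"id": "9d3f5a61",
"metadata": {},
"source": [
"For reference, the gradients that `grad` builds symbolically can also be derived by hand for this model. The following is a sketch of the standard result for logistic regression with a mean cross-entropy loss, writing $p = \\sigma(\\alpha + X \\beta)$; it is not output produced by PyTensor:\n",
"\n",
"$$\\frac{\\partial \\mathcal{L}}{\\partial \\alpha} = \\frac{1}{N} \\sum_{i=1}^N (p_i - y_i), \\qquad \\nabla_\\beta \\mathcal{L} = \\frac{1}{N} X^\\top (p - y)$$\n",
"\n",
"Plugging these into the update rule gives\n",
"\n",
"$$\\alpha^\\prime = \\alpha - \\frac{\\eta}{N} \\sum_{i=1}^N (p_i - y_i), \\qquad \\beta^\\prime = \\beta - \\frac{\\eta}{N} X^\\top (p - y)$$\n",
"\n",
"which is what `alpha_prime` and `beta_prime` should evaluate to, up to floating point error."
]
},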
{
"cell_type": "markdown",
"id": "7e963af4",
"metadata": {},
"source": [
"Compile our functions. We can make one for training, and one for prediction. Here we see another example of a PyTensor superpower: we don't have to know ahead of time how we want to use different outputs. We do all the steps of computation symbolically, then only at the end decide what will be used where. This type of thinking ahead is especially important when you're writing programs in Numba. You will need to think about what units of computation can be decomposed and jitted, for reuse in later jitted functions. We see that PyTensor does this reasoning for us.\n",
"\n",
"\n",
"Getting to this point, we realized that we can output class probabilities, but not class labels. It's easy to make one more root variable for the prediction function, representing a prediction threshold for membership in class 1. Then we can make a `y_hat` variable and return it."
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "10903335",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"training_fn = pytensor.function(\n",
" [X_pt, y_pt, eta_pt, alpha_pt, beta_pt], [loss, alpha_prime, beta_prime]\n",
")\n",
"\n",
"threshold = pt.tensor(\"threshold\", shape=())\n",
"y_hat = (p > threshold).astype(int)\n",
"predict_fn = pytensor.function([X_pt, alpha_pt, beta_pt, threshold], [p, y_hat])"
]
},
{
"cell_type": "markdown",
"id": "1480f470",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"It's worth having a look at the dprint of `training_fn` to see a couple features of pytensor we haven't run into yet.\n",
"\n",
"1. Sub-computations are automatically re-used! The first output, `Composite{...}.1 [id A] 9` is the cross entropy. It gets computed once (it's the first return, after all), but then it is also used in the 2nd return, to compute `alpha_prime`.\n",
"2. `Composite{...}.0 [id I] 4` is the probability of class 0. That also gets computed once and re-used several times.\n",
"3. Unused computations get truncated! Notice that `p_class_0` never appears as such. Instead, we got a composite sub-graph that compute that quantity given inputs. The thing itself was never requested, and is thus never used."
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "5e899f55",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Composite{...}.1 [id A] 'cross_entropy' d={0: [1]} 9\n",
" ├─ Assert{msg='Could not broadcast dimensi...'} [id B] 8\n",
" │ ├─ Shape_i{0} [id C] 0\n",
" │ │ └─ X [id D] \n",
" │ └─ Eq [id E] 7\n",
" │ ├─ Shape_i{0} [id C] 0\n",
" │ │ └─ ···\n",
" │ └─ Shape_i{0} [id F] 6\n",
" │ └─ y [id G] \n",
" └─ Sum{axes=None} [id H] 5\n",
" └─ Composite{...}.0 [id I] d={0: [1]} 4\n",
" ├─ ExpandDims{axis=0} [id J] 3\n",
" │ └─ alpha [id K] \n",
" ├─ CGemv{inplace} [id L] d={0: [0]} 2\n",
" │ ├─ AllocEmpty{dtype='float64'} [id M] 1\n",
" │ │ └─ Shape_i{0} [id C] 0\n",
" │ │ └─ ···\n",
" │ ├─ 1.0 [id N] \n",
" │ ├─ X [id D] \n",
" │ ├─ beta [id O] \n",
" │ └─ 0.0 [id P] \n",
" └─ y [id G] \n",
"Composite{(i2 - (i0 * i1))} [id Q] d={0: [1]} 13\n",
" ├─ learning_rate [id R] \n",
" ├─ Sum{axes=None} [id S] 12\n",
" │ └─ Composite{((i2 / i1) + (i0 / i1))} [id T] d={0: [0]} 11\n",
" │ ├─ Composite{...}.1 [id I] d={0: [1]} 4\n",
" │ │ └─ ···\n",
" │ ├─ ExpandDims{axis=0} [id U] 10\n",
" │ │ └─ Composite{...}.0 [id A] d={0: [1]} 9\n",
" │ │ └─ ···\n",
" │ └─ Composite{...}.2 [id I] d={0: [1]} 4\n",
" │ └─ ···\n",
" └─ alpha [id K] \n",
"CGemv{no_inplace} [id V] 16\n",
" ├─ beta [id O] \n",
" ├─ Neg [id W] 15\n",
" │ └─ learning_rate [id R] \n",
" ├─ Transpose{axes=[1, 0]} [id X] 'X.T' 14\n",
" │ └─ X [id D] \n",
" ├─ Composite{((i2 / i1) + (i0 / i1))} [id T] d={0: [0]} 11\n",
" │ └─ ···\n",
" └─ 1.0 [id N] \n",
"\n",
"Inner graphs:\n",
"\n",
"Composite{...} [id A] d={0: [1]}\n",
" ← Cast{float64} [id Y] 'o0'\n",
" └─ i0 [id Z] \n",
" ← neg [id BA] 'o1'\n",
" └─ true_div [id BB] \n",
" ├─ i1 [id BC] \n",
" └─ Cast{float64} [id Y] 'o0'\n",
" └─ ···\n",
"\n",
"Composite{...} [id I] d={0: [1]}\n",
" ← add [id BD] 'o0'\n",
" ├─ mul [id BE] \n",
" │ ├─ -1.0 [id BF] \n",
" │ ├─ i2 [id BG] \n",
" │ └─ scalar_softplus [id BH] \n",
" │ └─ neg [id BI] 't11'\n",
" │ └─ add [id BJ] 't3'\n",
" │ ├─ i0 [id BK] \n",
" │ └─ i1 [id BL] \n",
" └─ mul [id BM] \n",
" ├─ -1.0 [id BF] \n",
" ├─ sub [id BN] 't14'\n",
" │ ├─ 1.0 [id BO] \n",
" │ └─ i2 [id BG] \n",
" └─ scalar_softplus [id BP] \n",
" └─ add [id BJ] 't3'\n",
" └─ ···\n",
" ← mul [id BQ] 'o1'\n",
" ├─ sub [id BN] 't14'\n",
" │ └─ ···\n",
" └─ sigmoid [id BR] \n",
" └─ add [id BJ] 't3'\n",
" └─ ···\n",
" ← mul [id BS] 'o2'\n",
" ├─ -1.0 [id BF] \n",
" ├─ i2 [id BG] \n",
" └─ sigmoid [id BT] \n",
" └─ neg [id BI] 't11'\n",
" └─ ···\n",
"\n",
"Composite{(i2 - (i0 * i1))} [id Q] d={0: [1]}\n",
" ← sub [id BU] 'o0'\n",
" ├─ i2 [id BV] \n",
" └─ mul [id BW] \n",
" ├─ i0 [id BX] \n",
" └─ i1 [id BY] \n",
"\n",
"Composite{((i2 / i1) + (i0 / i1))} [id T] d={0: [0]}\n",
" ← add [id BZ] 'o0'\n",
" ├─ true_div [id CA] \n",
" │ ├─ i2 [id CB] \n",
" │ └─ i1 [id CC] \n",
" └─ true_div [id CD] \n",
" ├─ i0 [id CE] \n",
" └─ i1 [id CC] \n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training_fn.dprint(print_type=True, print_destroy_map=True)"
]
},
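{
"cell_type": "markdown",
"id": "e84b2c7f",
"metadata": {},
"source": [
"The `scalar_softplus` nodes in the inner graph of `Composite{...} [id I]` can look surprising, since we wrote the loss with `log` and `sigmoid`. A short derivation, using the standard definitions $\\sigma(z) = 1 / (1 + e^{-z})$ and $\\operatorname{softplus}(z) = \\log(1 + e^{z})$, shows where they come from. The symbol $z_i$ for the linear predictor $\\alpha + x_i^\\top \\beta$ is notation introduced here for illustration:\n",
"\n",
"$$\\log \\sigma(z) = -\\operatorname{softplus}(-z), \\qquad \\log(1 - \\sigma(z)) = -\\operatorname{softplus}(z)$$\n",
"\n",
"so each term of the loss becomes\n",
"\n",
"$$y_i \\log(p_i) + (1 - y_i) \\log(1 - p_i) = -y_i \\operatorname{softplus}(-z_i) - (1 - y_i) \\operatorname{softplus}(z_i)$$\n",
"\n",
"This matches the `'o0'` output of that composite: two `mul` nodes, each with a `-1.0` factor and a `scalar_softplus`, added together. The sigmoid is never evaluated on the way to the loss, which avoids taking the log of a number that has underflowed to zero."
]
},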
{
"cell_type": "markdown",
"id": "6c880ffa",
"metadata": {},
"source": [
"The last point about truncation is even more clear when we look at the graph for `predict_fn`. Now the loss is nowhere to be seen! We just compute exactly what was requested.\n",
"\n",
"Note that we also have a nice example of re-using computation. We compute `p_class_0` exactly once, then re-use it in the 2nd output."
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "1a50f419",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Composite{...}.0 [id A] 'p_class_0' 5\n",
" ├─ ExpandDims{axis=0} [id B] 4\n",
" │ └─ alpha [id C] \n",
" ├─ CGemv{inplace} [id D] 3\n",
" │ ├─ AllocEmpty{dtype='float64'} [id E] 2\n",
" │ │ └─ Shape_i{0} [id F] 1\n",
" │ │ └─ X [id G] \n",
" │ ├─ 1.0 [id H] \n",
" │ ├─ X [id G] \n",
" │ ├─ beta [id I] \n",
" │ └─ 0.0 [id J] \n",
" └─ ExpandDims{axis=0} [id K] 0\n",
" └─ threshold [id L] \n",
"Composite{...}.1 [id A] 5\n",
" └─ ···\n",
"\n",
"Inner graphs:\n",
"\n",
"Composite{...} [id A]\n",
" ← sigmoid [id M] 'o0'\n",
" └─ add [id N] \n",
" ├─ i0 [id O] \n",
" └─ i1 [id P] \n",
" ← Cast{int64} [id Q] 'o1'\n",
" └─ GT [id R] \n",
" ├─ sigmoid [id M] 'o0'\n",
" │ └─ ···\n",
" └─ i2 [id S] \n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_fn.dprint(print_type=True)"
]
},
{
"cell_type": "markdown",
"id": "615f13ef",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Training"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "b57e8e42",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"\n",
"\n",
"seed = sum(map(ord, \"I <3 Pytensor\"))\n",
"rng = np.random.default_rng(seed)\n",
"\n",
"X, y = make_classification(\n",
" n_samples=100,\n",
" n_features=5,\n",
" n_redundant=0,\n",
" n_informative=5,\n",
" n_classes=2,\n",
" random_state=seed,\n",
")\n",
"\n",
"# Initial values for parameteres\n",
"beta = rng.normal(0, 1, size=(5,))\n",
"alpha = rng.normal(0, 1)\n",
"\n",
"learning_rate = 1e-1"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "f4400eb1",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"n_steps = 1_000\n",
"histories = [np.empty(n_steps), np.empty(n_steps), np.empty((n_steps, 5))]\n",
"\n",
"for t in range(n_steps):\n",
" loss_val, alpha, beta = training_fn(X, y, learning_rate, alpha, beta)\n",
" histories[0][t] = loss_val\n",
" histories[1][t] = alpha\n",
" histories[2][t] = beta"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "58052478",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {
"image/png": {
"height": 1011,
"width": 811
}
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(\n",
" nrows=3,\n",
" ncols=1,\n",
" figsize=(8, 10),\n",
" sharex=True,\n",
" layout=\"constrained\",\n",
")\n",
"for axis, data, name in zip(\n",
" fig.axes,\n",
" histories,\n",
" [\"cross-entropy\", \"alpha\", \"beta\"],\n",
"):\n",
" axis.plot(data)\n",
" axis.set(title=name)\n",
"\n",
"axis.set(xlabel=\"steps\")\n",
"\n",
"fig.suptitle(\"Training\", fontsize=18, fontweight=\"bold\");"
]
},
{
"cell_type": "markdown",
"id": "af29f666",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"`make_classification` doesn't return true parameters to check against, so we'll just look at the confusion matrix. Seems like we learned something. It's a toy problem anyway!"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "cacdee5b",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[36, 14],\n",
" [11, 39]])"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"\n",
"confusion_matrix(y, predict_fn(X, alpha, beta, 0.5)[1])"
]
},
{
"cell_type": "markdown",
"id": "7fbb667d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Taking a step back\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "f1338584",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### How does it compare with alternative frameworks\n",
"\n",
"* Graph is built explicitly with placeholder inputs (common source of confusion for users)\n",
"* It is focused on array (tensor) operations (dense and sparse). Tries to look almost like numpy / scipy, (until the abstraction breaks).\n",
" * There is narrow / hidden support for other types like scalars, lists, slices, random Generators, strings, None (although easy to extend)\n",
"* Functional design (there is no variable mutation when defining graphs)\n",
"* Strong focus on hackability / graph manipulation\n",
"* Evolved from:\n",
" 1. Theano which strongly inspired Tensorflow 1.x and JAX. Many concepts stood the test of time. Others have aged and provide some drag.\n",
" 2. Aesara, which cleaned up the codebase, added alternative backends (Numba and JAX) and proved there's some interest out there in a library like this."
]
},
{
"cell_type": "markdown",
"id": "12420067",
"metadata": {},
"source": [
"We hope you enjoyed this introduction to PyTensor. This is just the beginning of what you can do with it. Please explore the documentation and the gallery for more examples and applications."
]
},
{
"cell_type": "markdown",
"id": "556202cf",
"metadata": {},
"source": [
"## Authors\n",
"\n",
"- Jesse Grabowski and Ricardo Vieira in August 2025"
]
},
{
"cell_type": "markdown",
"id": "f06cf7a2",
"metadata": {},
"source": [
"## References\n",
"\n",
":::{bibliography} :filter: docname in docnames"
]
},
{
"cell_type": "markdown",
"id": "4b3d194f",
"metadata": {},
"source": [
"## Watermark "
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "8277fb12",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last updated: Sun Aug 17 2025\n",
"\n",
"Python implementation: CPython\n",
"Python version : 3.12.11\n",
"IPython version : 9.4.0\n",
"\n",
"pytensor: 2.31.7\n",
"\n",
"pymc : 5.25.1\n",
"scipy : 1.16.1\n",
"pytensor : 2.31.7\n",
"IPython : 9.4.0\n",
"sklearn : 1.7.1\n",
"matplotlib: 3.10.5\n",
"graphviz : 0.21\n",
"numpy : 2.2.6\n",
"\n",
"Watermark: 2.5.0\n",
"\n"
]
}
],
"source": [
"%load_ext watermark\n",
"%watermark -n -u -v -iv -w -p pytensor"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}