{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "e4cfac6e-a4f5-4505-afb2-d4e84b98dcf1", "metadata": {}, "outputs": [], "source": [ "# import\n", "from progeval import ProgEval" ] }, { "cell_type": "markdown", "id": "0bcff5fd-aa2e-47dc-912b-53b45390df21", "metadata": { "tags": [] }, "source": [ "# How-to guide\n", "\n", "Let's consider the following computation:\n", "\n", "![toy computation](toy-computation.svg)" ] }, { "cell_type": "markdown", "id": "a3226152-fbab-485a-9dcb-01273a5ae66f", "metadata": {}, "source": [ "## Dynamically construct graph\n", "The most straightforward way to define the computational graph is by assigning to a `ProgEval()` object.\n", "If a callable object is assigned, it is automatically interpreted as specifying how to compute the given quantity.\n", "The quantities it depends on are detected from the argument names of the function." ] }, { "cell_type": "code", "execution_count": null, "id": "f6640e52-d0c1-4e51-a7e0-422724a0f416", "metadata": {}, "outputs": [], "source": [ "def compute_alpha(x):\n", "    print(f'computing alpha = 2 * {x} + 1')\n", "    return 2 * x + 1\n", "\n", "def compute_beta(y):\n", "    print(f'computing beta = {y} * {y}')\n", "    return y * y" ] }, { "cell_type": "code", "execution_count": null, "id": "86e573f6-a3a1-40c6-a99c-1d18601a6377", "metadata": {}, "outputs": [], "source": [ "graph = ProgEval()\n", "\n", "graph.alpha = compute_alpha\n", "graph.beta = compute_beta\n", "# any callable object works\n", "graph.gamma = lambda alpha, beta, y: alpha * beta - y" ] }, { "cell_type": "markdown", "id": "b36ae47d-99c7-4f8e-be52-848f91537d77", "metadata": {}, "source": [ "Having constructed the graph, we can set input values and compute the outputs:" ] }, { "cell_type": "code", "execution_count": null, "id": "06c3e956-8c84-48ec-810a-17f0877c6999", "metadata": {}, "outputs": [], "source": [ "graph.x, graph.y = 3, 4\n", "graph.gamma" ] }, { "cell_type": "markdown", "id": "66b5b6f3-8a06-4f8a-ab46-9fab4c44aa8e", 
"metadata": {}, "source": [ "If we request intermediate values now, they are not computed again (note there is no printed message)!" ] }, { "cell_type": "code", "execution_count": null, "id": "abf29122-fa90-469e-9099-19130e259b80", "metadata": {}, "outputs": [], "source": [ "graph.beta" ] }, { "cell_type": "markdown", "id": "add7a55f-b55c-4e6a-b47b-4b018225554b", "metadata": {}, "source": [ "### Evaluate everything\n", "We can evluate and collect all quantities in the graph by invoking `compute_all_quantities`:" ] }, { "cell_type": "code", "execution_count": null, "id": "81c597d2-ae31-42fa-a000-48362cfedd89", "metadata": {}, "outputs": [], "source": [ "graph.compute_all_quantities()" ] }, { "cell_type": "markdown", "id": "56e0d259-24c6-4785-b914-e647d1fa6892", "metadata": {}, "source": [ "### Clear cache\n", "By removing all saved intermediate values, we can force the computational graph to be recomputed in full." ] }, { "cell_type": "code", "execution_count": null, "id": "63493c4d-9f4a-4894-a7d1-ef8bafde14c6", "metadata": {}, "outputs": [], "source": [ "graph.clear_cache()\n", "graph.compute_all_quantities()" ] }, { "cell_type": "markdown", "id": "a6a77385-21b1-43b6-9cd7-0cccfa169344", "metadata": {}, "source": [ "### Changing the computational graph\n", "When we override the input values, only those quantities that depend on the changes will be re-computed (no printed message for alpha):" ] }, { "cell_type": "code", "execution_count": null, "id": "13635357-ae95-4066-8a8b-a40a1424bf29", "metadata": {}, "outputs": [], "source": [ "graph.y = 8\n", "graph.gamma" ] }, { "cell_type": "markdown", "id": "b94d31b4-f382-48c8-8c71-923a1b646875", "metadata": {}, "source": [ "Besides re-assigning values to the inputs, we can also change the structure of the graph itself." 
] }, { "cell_type": "code", "execution_count": null, "id": "8802f343-6288-4302-8d0e-d3eab776220e", "metadata": {}, "outputs": [], "source": [ "graph.gamma = lambda alpha, beta: alpha - beta\n", "graph.gamma" ] }, { "cell_type": "markdown", "id": "d6724ace-22ba-4c1b-939a-0fd0fcdbcca4", "metadata": {}, "source": [ "### Disabling recomputation\n", "If for any reason the re-computation of values is not desired, it can be disabled by specifying `track_dependence=False`.\n", "In that case, the graph no longer registers which quantities are requested in the different computations." ] }, { "cell_type": "code", "execution_count": null, "id": "c27aa50a-1549-42a1-88e3-5cbf0976cdb0", "metadata": {}, "outputs": [], "source": [ "graph = ProgEval(track_dependence=False)\n", "\n", "graph.alpha = compute_alpha\n", "graph.beta = compute_beta\n", "graph.gamma = lambda alpha, beta, y: alpha * beta - y\n", "\n", "graph.x, graph.y = 3, 4\n", "graph.gamma" ] }, { "cell_type": "code", "execution_count": null, "id": "292c5111-082c-4507-9f51-622c32108dc5", "metadata": {}, "outputs": [], "source": [ "graph.y = 8\n", "graph.gamma # now, no change" ] }, { "cell_type": "markdown", "id": "b1690cd1-5ab7-4c8f-8a5b-57b685d74131", "metadata": {}, "source": [ "### Specifying input arguments\n", "Above, the inputs to the node functions are derived from their call signature.\n", "Instead, it is also possible to explicitly pass their names." 
] }, { "cell_type": "code", "execution_count": null, "id": "51a0f353-1c59-4dd6-85ac-58a18ca3860c", "metadata": {}, "outputs": [], "source": [ "def prod(a, b):\n", "    return a * b\n", "\n", "graph = ProgEval()\n", "graph.register('beta', prod, ['y', 'y'])\n", "\n", "graph.y = 5\n", "graph.beta" ] }, { "cell_type": "markdown", "id": "eacb2b1c-38b2-49b3-8e2d-f7f1502c1e3c", "metadata": { "tags": [] }, "source": [ "## Define computations as classes" ] }, { "cell_type": "markdown", "id": "5f17c004-faa4-49fd-a379-1b2bbf9b2ca9", "metadata": {}, "source": [ "Instead of defining computational graphs by assigning nodes to a `ProgEval` object, we can also define a new class that represents the computation.\n", "This can be nice for two reasons:\n", "1. All functions/quantities are in one place and are registered automatically.\n", "2. We can easily specify all input values and efficiently create the corresponding graph.\n", "\n", "The only thing we need to do is to subclass `ProgEval`." ] }, { "cell_type": "code", "execution_count": null, "id": "53e11010-730c-4eb1-a5d0-30c5f0015414", "metadata": {}, "outputs": [], "source": [ "class MyComputation(ProgEval):\n", "    \n", "    # this says the function below does not have a `self` argument\n", "    @staticmethod\n", "    def alpha(x):\n", "        print(f'computing alpha = 2 * {x} + 1')\n", "        return 2 * x + 1\n", "\n", "    @staticmethod\n", "    def beta(y):\n", "        print(f'computing beta = {y} * {y}')\n", "        return y * y\n", "    \n", "    @staticmethod\n", "    def gamma(y, alpha, beta):\n", "        print(f'computing gamma = {alpha} * {beta} - {y}')\n", "        return alpha * beta - y" ] }, { "cell_type": "code", "execution_count": null, "id": "31aee617-286c-40bf-ab4a-cae398a177e3", "metadata": {}, "outputs": [], "source": [ "comp = MyComputation()" ] }, { "cell_type": "markdown", "id": "b193ae64-9254-4dbd-862f-262a1a007a5f", "metadata": {}, "source": [ "To evaluate it, we must assign the input values `x` and `y`:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"0e958d31-b7c7-4618-a694-ddfbb51e3929", "metadata": {}, "outputs": [], "source": [ "comp.x, comp.y = 5, 3\n", "comp.gamma" ] }, { "cell_type": "markdown", "id": "3ec218dd-c70f-42fd-a5c6-8edb3ae33cde", "metadata": {}, "source": [ "The strucuture can be made even cleaner by taking the inputs of the computations as inputs when creating the graph." ] }, { "cell_type": "code", "execution_count": null, "id": "cc2796cd-c984-4f78-8fdd-49a36dce5dec", "metadata": {}, "outputs": [], "source": [ "class MyComputation(ProgEval):\n", " \n", " def __init__(self, x, y):\n", " super().__init__(x=x, y=y)\n", " \n", " @staticmethod\n", " def alpha(x):\n", " print(f'computing alpha = 2 * {x} + 1')\n", " return 2 * x + 1\n", "\n", " @staticmethod\n", " def beta(y):\n", " print(f'computing beta = {y} * {y}')\n", " return y * y\n", " \n", " @staticmethod\n", " def gamma(y, alpha, beta):\n", " print(f'computing gamma = {alpha} * {beta} - {y}')\n", " return alpha * beta - y" ] }, { "cell_type": "code", "execution_count": null, "id": "c62c0783-64bd-4dee-b37e-a7d113785c09", "metadata": {}, "outputs": [], "source": [ "MyComputation(5, 3).gamma" ] }, { "cell_type": "code", "execution_count": null, "id": "8e41f84e-059c-4caa-a147-8eb60972716f", "metadata": {}, "outputs": [], "source": [ "MyComputation(4, 5).gamma" ] }, { "cell_type": "markdown", "id": "32bf60e3-e6ce-461d-86d1-3645164f9755", "metadata": {}, "source": [ "In this setting, recomputation may not be required (since one would just call with different inputs instead of replacing `x` and `y`).\n", "Dependency tracking can be turned off by setting `class MyComputation(ProgEval, track_dependence=False)` in the first line." 
] }, { "cell_type": "markdown", "id": "14e09752-7312-4ba1-9197-299354b056fe", "metadata": {}, "source": [ "### Accessing quantities as attributes of ``self``" ] }, { "cell_type": "markdown", "id": "1711308d-7466-4823-8e53-696679f22fa6", "metadata": {}, "source": [ "In the above examples, the dependencies of quantities were made explicit by the arguments the functions take.\n", "It is also possible to have methods that are not a `staticmethod`, i.e. that access quantities as attributes of `self`.\n", "However:" ] }, { "cell_type": "markdown", "id": "f5bf840d-1742-4889-8890-5a01484d9922", "metadata": {}, "source": [ "```{eval-rst}\n", ".. warning::\n", " If a method accesses computational quantities as attributes of ``self`` (instead of explicit arguments), the dependencies in the computational graph can currently not be tracked.\n", " That means quantities are not properly recomputed when intermediate values are changed.\n", " This is only a problem if the computational graph is changed, i.e. 
if nodes are replaced or deleted, after it was created.\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "fbfb2a34-7767-4aab-83eb-e3dcc2825d8a", "metadata": {}, "outputs": [], "source": [ "class MyComputation(ProgEval):\n", "    \n", "    def __init__(self, x, y):\n", "        super().__init__(x=x, y=y)\n", "\n", "    def alpha(self):\n", "        print(f'computing alpha = 2 * {self.x} + 1')\n", "        return 2 * self.x + 1\n", "\n", "    def beta(self):\n", "        print(f'computing beta = {self.y} * {self.y}')\n", "        return self.y * self.y\n", "    \n", "    def gamma(self):\n", "        print(f'computing gamma = {self.alpha} * {self.beta} - {self.y}')\n", "        return self.alpha * self.beta - self.y" ] }, { "cell_type": "code", "execution_count": null, "id": "006da673-69be-49b4-a43b-017345376076", "metadata": {}, "outputs": [], "source": [ "MyComputation(5, 3).gamma" ] }, { "cell_type": "code", "execution_count": null, "id": "5f0c597e-4a4b-4315-a832-128af3cfb12c", "metadata": {}, "outputs": [], "source": [ "MyComputation(4, 5).gamma" ] }, { "cell_type": "markdown", "id": "01e7dfff-23d7-4742-83f4-379aaa7ec810", "metadata": {}, "source": [ "## Advanced: transforming functions\n", "It is possible to specify an optional `transformer` when constructing the computational graph, which can modify the node functions before they are added.\n", "It must take three arguments: `transformer(function, static, name)`.\n", "`function` is the function used to compute the quantity with the given `name`.\n", "`static` is a boolean that is `False` if the function takes `self` as its first argument.\n", "\n", "The output should be a function with the same signature. If the signature changes, the output must instead be a tuple of the transformed function and the new signature, given as an `inspect.Signature` instance.\n", "\n", "Below are two examples of how this can be used.\n", "They require [JAX](https://github.com/google/jax/) and [Dask](https://docs.dask.org/) to be installed, respectively." 
] }, { "cell_type": "markdown", "id": "db984420-a2b4-43f3-b7e0-af78c7d2070b", "metadata": {}, "source": [ "### Just-in-time compilation with JAX" ] }, { "cell_type": "code", "execution_count": null, "id": "6e676dc8-c0c5-4d12-9729-1aa613f6c614", "metadata": {}, "outputs": [], "source": [ "import jax" ] }, { "cell_type": "code", "execution_count": null, "id": "c40466f2-f008-4d36-be8a-705855dec3d7", "metadata": {}, "outputs": [], "source": [ "def jit_if_static(fun, static, name):\n", "    # only jit compile if the function doesn't depend on self\n", "    if static:\n", "        return jax.jit(fun)\n", "    return fun\n", "\n", "\n", "class Computation(ProgEval):\n", "    \n", "    def __init__(self, x, y):\n", "        super().__init__(x=x, y=y)\n", "    \n", "    @staticmethod\n", "    def alpha(x, y):\n", "        return jax.numpy.trace(x @ x) * jax.numpy.trace(y)\n", "    \n", "    @staticmethod\n", "    def beta(x, y, alpha):\n", "        return jax.numpy.trace(x @ y) * alpha\n", "    \n", "    @staticmethod\n", "    def total(alpha, beta):\n", "        return (alpha + beta) / alpha.size" ] }, { "cell_type": "markdown", "id": "b9cd7597-b09e-4e91-b124-627a641ec412", "metadata": {}, "source": [ "The above construction only makes a noticeable difference if the individual functions are sufficiently costly.\n", "Another useful pattern with JAX is that we can define efficient functions for the parts of the computational graph we are interested in, without repeating code." 
] }, { "cell_type": "code", "execution_count": null, "id": "dbf97003-5907-40e7-bd6d-1d6fba42f9d9", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def compute_alpha(x, y):\n", " return Computation(x, y).alpha\n", "\n", "@jax.jit\n", "def compute_beta(x, y):\n", " return Computation(x, y).beta\n", "\n", "@jax.jit\n", "def compute_total(x, y):\n", " return Computation(x, y).total" ] }, { "cell_type": "code", "execution_count": null, "id": "41ce6007-a065-4941-abe8-385f8f32dcc6", "metadata": {}, "outputs": [], "source": [ "rng = jax.random.PRNGKey(0)\n", "x, y = jax.random.normal(rng, (2, 32, 32))\n", "\n", "%timeit compute_alpha(x, y).block_until_ready()" ] }, { "cell_type": "code", "execution_count": null, "id": "0d1e3a6a-578c-4fb8-af55-b388277dcfe9", "metadata": {}, "outputs": [], "source": [ "%timeit compute_beta(x, y).block_until_ready()" ] }, { "cell_type": "code", "execution_count": null, "id": "1fda84d1-75a9-4e66-a680-8c4df5018cf2", "metadata": {}, "outputs": [], "source": [ "%timeit compute_total(x, y).block_until_ready()" ] }, { "cell_type": "markdown", "id": "f8939c16-8832-4d33-aa41-57c664600fbb", "metadata": {}, "source": [ "Comparing this with a manual implementation, we see that the construction via the computational graph has virtually no cost after jit-compilation." 
] }, { "cell_type": "code", "execution_count": null, "id": "765ec2f0-1be2-4d22-8ee7-86e7e618f5b5", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def computation(x, y):\n", " alpha = jax.numpy.trace(x @ x) * jax.numpy.trace(y)\n", " beta = jax.numpy.trace(x @ y) * alpha\n", " total = (alpha + beta) / alpha.size\n", " return total" ] }, { "cell_type": "code", "execution_count": null, "id": "39bf402c-3f81-441a-8f39-e339ed316eff", "metadata": {}, "outputs": [], "source": [ "%timeit computation(x, y).block_until_ready()" ] }, { "cell_type": "markdown", "id": "05eeca91-f509-4932-a81c-adb6d211b557", "metadata": {}, "source": [ "### Generating Dask delayed objects" ] }, { "cell_type": "code", "execution_count": null, "id": "c792f1b9-86d4-42df-a144-02bec06d32d2", "metadata": {}, "outputs": [], "source": [ "import dask" ] }, { "cell_type": "code", "execution_count": null, "id": "c452fcaa-c352-4a20-a57a-53eda82233ad", "metadata": {}, "outputs": [], "source": [ "def delay(fun, _, name):\n", " return dask.delayed(fun, name)\n", " \n", "comp = ProgEval(transformer=delay)\n", "\n", "def inc(a):\n", " return a + 1\n", "\n", "def add(a, b):\n", " return a + b\n", "\n", "comp.x = 5\n", "comp.y = 3\n", "comp.register('a', inc, 'x') \n", "comp.register('b', inc, 'y') \n", "comp.register('total', add)\n", "\n", "# comp.total.visualize()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }