{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fitting a Model to Data\n", "\n", "In class we talked about two key ingredients for writing a computer recipe for optimizing a model to match data. We need a goodness-of-fit metric, and we need a way to explore the range of parameters that are out there. \n", "\n", "## Linear Least Squares\n", "For some classes of model fitting problems (\"linear least squares\" problems), there are some tools that can identify the best-fit parameters almost instantaneously. For example, `np.polyfit` does some magical matrix math to instantly fit a polynomial to (x,y) data. \n", "\n", "## General Fitting Problems\n", "You won't always have a tidy \"linear\" problem, so it's useful to understand how fitting might work more abstractly. Here are a few examples that help visualize the exploration and optimization process. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np, matplotlib.pyplot as plt\n", "import henrietta as hsl" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's generate a simulated transit. It's useful to know the parameters we injected, so we can see how good our fits look in the end." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create a simulated light curve\n", "actual = dict(period=1.58, radius=0.116, a=15.2, b=0.385)\n", "lc = hsl.simulate_transit_data(tmin=-0.1, duration=0.2, cadence=1/60/24., N=1e7, **actual)\n", "lc.scatter();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We're going to use some of the [astropy.modeling](http://docs.astropy.org/en/stable/modeling/) tools. You don't need to know much about them now, but the basic step we need to take is to set up a \"model\" object that we can play with. 
We wrote a wrapper to help with this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = hsl.setup_transit_model(period=1.58,\n", "                                radius=[0.0, 0.3],\n", "                                a=[3.0, 50.0],\n", "                                b=0.385,\n", "                                baseline=1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you give the `setup_transit_model` function one value for a parameter, it will treat that parameter as fixed. If you give it a range of values, it will treat that parameter as unknown but within that range. All of these are stored inside of our model object." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(model.param_names)\n", "print(model.parameters)\n", "print(model.fixed)\n", "print(model.bounds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An astropy `model` stores its parameters internally, but it can also be called as a function to generate a model curve:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = lc.time\n", "y = lc.flux\n", "plt.plot(x, y, '.k')\n", "plt.plot(x, model(x));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hmmmm...that wasn't a great match. To play with fitting our model to improve its parameters, we need to define a goodness-of-fit function. 
Let's try a common one:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sumofsquares(residuals):\n", "    '''\n", "    This calculates a goodness-of-fit from an array of residuals.\n", "    (a lower value implies a better fit)\n", "\n", "    Parameters\n", "    ----------\n", "    residuals : `~numpy.ndarray`\n", "        Array of values for (data-model)/sigma\n", "\n", "    Returns\n", "    -------\n", "    gof : float\n", "        A single goodness-of-fit metric\n", "        (in this case, sum of squares)\n", "    '''\n", "    return np.sum(residuals**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With this, we can set up a \"fitter\", which will help us figure out good values for our parameters. By setting the `goodness` keyword, we can tell the fitter what function to use for calculating the goodness of fit from the residuals. The fitter will try to *minimize* the goodness value, so keep that in mind when you define your `goodness` function. (You could make a reasonable argument we should be calling this \"badness\" instead. That's fair.)\n", "\n", "### Exploration #1: Guess and Check\n", "Let's start with an inefficient, but conceptually simple, algorithm. We'll just try random sets of parameters and see which one looks best. First, we create a \"fitter\" object that handles this exploration." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fitter = hsl.VisualizedGuessNCheckFitter(goodness=sumofsquares)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we can run this fitter to see what happens. It will return the best model, but it will also make an animation that shows us the process it took to get there." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bestmodel = fitter.visualize(model, x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This will save an animation out to the file `guessncheck-sumofsquares.mp4`. Open it up to take a look!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exploration #2: Simplex\n", "\n", "Our guessing and checking method seems a little inefficient. Here's an example of the `simplex` method for minimizing functions (in this case our goodness of fit). What's going on with this one?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "faster = hsl.VisualizedSimplexFitter(goodness=sumofsquares)\n", "bestmodel = faster.visualize(model, x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`simplex` will do a better job of fitting some parameters than others. Play around with which parameters you allow to vary in `setup_transit_model`, and by what ranges. Does `simplex` do a good job of fitting all of them? Based on your animations, do you have any ideas why or why not?" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }