Commit fd8a159e authored by Kaan Güney Keklikçi's avatar Kaan Güney Keklikçi

added flows for noisy moons sklearn data

parent 2418085b
This source diff could not be displayed because it is too large. You can view the blob instead.
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "5394ec2c",
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
"import numpy as np \n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline \n",
"from sklearn.datasets import make_moons\n",
"from sklearn.preprocessing import StandardScaler\n",
"import tensorflow as tf \n",
"import tensorflow_probability as tfp \n",
"from tensorflow.keras.layers import Input\n",
"from tensorflow.keras import Model\n",
"from tensorflow.keras.callbacks import LambdaCallback\n",
"\n",
"tfd = tfp.distributions\n",
"tfb = tfp.bijectors\n",
"\n",
"color_list = ['#bcad', '#dacb']"
]
},
{
"cell_type": "markdown",
"id": "021893ba",
"metadata": {},
"source": [
"### Define the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e96e5477",
"metadata": {},
"outputs": [],
"source": [
"n_samples = 1000\n",
"niosy_moons = make_moons(n_samples=n_samples, noise=.05)\n",
"X, y = niosy_moons\n",
"X_data = StandardScaler().fit_transform(X)\n",
"xlim, ylim = [-2, 2], [-2, 2]\n",
"\n",
"y_label = y.astype(np.bool)\n",
"X_train, Y_train = X_data[:,0], X_data[:,1]\n",
"plt.scatter(X_train[y_label], Y_train[y_label], s=10, color=color_list[0])\n",
"plt.scatter(X_train[y_label == False], Y_train[y_label == False], s=10, color=color_list[1])\n",
"plt.legend(['label: 1', 'label: 0'])\n",
"plt.xlim(xlim)\n",
"plt.ylim(ylim)"
]
},
{
"cell_type": "markdown",
"id": "f7989c77",
"metadata": {},
"source": [
"### Define the base distribution"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7205510b",
"metadata": {},
"outputs": [],
"source": [
"mvn = tfd.MultivariateNormalDiag(loc=[0.,0.], scale_diag=[1.,1.])\n",
"mvn"
]
},
{
"cell_type": "markdown",
"id": "91c03525",
"metadata": {},
"source": [
"### MAF initialization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3f9ffa1",
"metadata": {},
"outputs": [],
"source": [
"def masked_autoregressive_flow(hidden_units=[16,16], event_shape=[2], activation='relu'):\n",
" network = tfb.AutoregressiveNetwork(params=2, \n",
" hidden_units=hidden_units,\n",
" event_shape=event_shape,\n",
" activation=activation\n",
" )\n",
" return tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=network)"
]
},
{
"cell_type": "markdown",
"id": "d2bb7fcd",
"metadata": {},
"source": [
"### Transformed distribution"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f776f68d",
"metadata": {},
"outputs": [],
"source": [
"trainable_dist = tfd.TransformedDistribution(distribution=mvn,\n",
" bijector=masked_autoregressive_flow(\n",
" activation='sigmoid'))\n",
"trainable_dist"
]
},
{
"cell_type": "markdown",
"id": "923ac546",
"metadata": {},
"source": [
"### Initialize samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a2aa1ea",
"metadata": {},
"outputs": [],
"source": [
"x = mvn.sample(sample_shape=(n_samples, 2))\n",
"names = [mvn.name, trainable_dist.bijector.name]\n",
"samples = [x, trainable_dist.bijector.forward(x)]"
]
},
{
"cell_type": "markdown",
"id": "84cfabcf",
"metadata": {},
"source": [
"### Plot routine"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6071af1c",
"metadata": {},
"outputs": [],
"source": [
"def _plot(results, rows=1, legend=False, plot_color: str=color_list[0]):\n",
" cols = int(len(results) / rows)\n",
" f, arr = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))\n",
" i = 0\n",
" for r in range(rows):\n",
" for c in range(cols):\n",
" res = results[i]\n",
" X, Y = res[..., 0].numpy(), res[..., 1].numpy()\n",
" if rows == 1:\n",
" p = arr[c]\n",
" else:\n",
" p = arr[r, c]\n",
" p.scatter(X, Y, s=10, color=plot_color)\n",
" p.set_xlim([-5, 5])\n",
" p.set_ylim([-5, 5])\n",
" p.set_title(names[i])\n",
" \n",
" i += 1"
]
},
{
"cell_type": "markdown",
"id": "129b234d",
"metadata": {},
"source": [
"### Prior training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64f81179",
"metadata": {},
"outputs": [],
"source": [
"_plot(samples)"
]
},
{
"cell_type": "markdown",
"id": "35a4a728",
"metadata": {},
"source": [
"### Training routine"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71bc15fe",
"metadata": {},
"outputs": [],
"source": [
"def train_dist_routine(trainable_distribution, n_epochs=200, batch_size=None, n_disp=100):\n",
" x_ = Input(shape=(2,), dtype=tf.float32)\n",
" log_prob_ = trainable_distribution.log_prob(x_)\n",
" model = Model(x_, log_prob_)\n",
"\n",
" model.compile(optimizer=tf.optimizers.Adam(),\n",
" loss=lambda _, log_prob: -log_prob)\n",
"\n",
" ns = X_data.shape[0]\n",
" if batch_size is None:\n",
" batch_size = ns\n",
"\n",
" # Display the loss every n_disp epoch\n",
" epoch_callback = LambdaCallback(\n",
" on_epoch_end=lambda epoch, logs: \n",
" print('\\n Epoch {}/{}'.format(epoch+1, n_epochs, logs),\n",
" '\\n\\t ' + (': {:.4f}, '.join(logs.keys()) + ': {:.4f}').format(*logs.values()))\n",
" if epoch % n_disp == 0 else False \n",
" )\n",
"\n",
"\n",
" history = model.fit(x=X_data,\n",
" y=np.zeros((ns, 0), dtype=np.float32),\n",
" batch_size=batch_size,\n",
" epochs=n_epochs,\n",
" validation_split=0.2,\n",
" shuffle=True,\n",
" verbose=0,\n",
" callbacks=[epoch_callback])\n",
" return history"
]
},
{
"cell_type": "markdown",
"id": "b6260147",
"metadata": {},
"source": [
"### Train the distribution"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e32cf88",
"metadata": {},
"outputs": [],
"source": [
"history = train_dist_routine(trainable_distribution=trainable_dist,\n",
" n_epochs=600,\n",
" n_disp=200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf7eaf63",
"metadata": {},
"outputs": [],
"source": [
"train_losses = history.history['loss']\n",
"valid_losses = history.history['val_loss']\n",
"\n",
"plt.plot(train_losses, label='train', c=color_list[0])\n",
"plt.plot(valid_losses, label='valid', c=color_list[1])\n",
"plt.legend()\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Negative log likelihood\")\n",
"plt.title(\"Training and validation loss curves\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "7be786e2",
"metadata": {},
"source": [
"### After training -- w/o stacking & permutation & batch normalization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe88f36c",
"metadata": {},
"outputs": [],
"source": [
"x = mvn.sample(sample_shape=(n_samples,2))\n",
"names = [mvn.name, trainable_dist.bijector.name]\n",
"samples = [x, trainable_dist.bijector.forward(x)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "936bde0c",
"metadata": {},
"outputs": [],
"source": [
"_plot(samples, rows=1)"
]
},
{
"cell_type": "markdown",
"id": "dec4652e",
"metadata": {},
"source": [
"### Complex model -- apply bijectors in reverse order"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "892d601a",
"metadata": {},
"outputs": [],
"source": [
"num_layers = 4\n",
"flow_bijector = []\n",
"\n",
"for i in range(num_layers):\n",
" flow_i = masked_autoregressive_flow(hidden_units=[256,256])\n",
" flow_bijector.append(flow_i) \n",
" flow_bijector.append(tfb.Permute([1,0]))\n",
"# discard the last permute layer \n",
"flow_bijector = tfb.Chain(list(reversed(flow_bijector[:-1])))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "174b9850",
"metadata": {},
"outputs": [],
"source": [
"trainable_dist = tfd.TransformedDistribution(distribution=mvn,\n",
" bijector=flow_bijector)"
]
},
{
"cell_type": "markdown",
"id": "015eb594",
"metadata": {},
"source": [
"### Generate samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e73726cb",
"metadata": {},
"outputs": [],
"source": [
"def make_samples():\n",
" x = mvn.sample((n_samples, 2))\n",
" samples = [x]\n",
" names = [mvn.name]\n",
" for bijector in reversed(trainable_dist.bijector.bijectors):\n",
" x = bijector.forward(x)\n",
" samples.append(x)\n",
" names.append(bijector.name)\n",
" return names, samples\n",
"\n",
"names, samples = make_samples()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a507799",
"metadata": {},
"outputs": [],
"source": [
"_plot(samples, rows=2, plot_color=color_list[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17d3488a",
"metadata": {},
"outputs": [],
"source": [
"def visualize_training_data(samples, plot_color: str=color_list[0]):\n",
" f, arr = plt.subplots(1, 2, figsize=(20,5))\n",
" names = ['Data', 'Trainable']\n",
" samples = [tf.constant(X_data), samples[-1]]\n",
"\n",
" for i in range(2):\n",
" res = samples[i]\n",
" X, Y = res[..., 0].numpy(), res[..., 1].numpy()\n",
" arr[i].scatter(X, Y, s=10, color=plot_color)\n",
" arr[i].set_xlim([-2, 2])\n",
" arr[i].set_ylim([-2, 2])\n",
" arr[i].set_title(names[i])"
]
},
{
"cell_type": "markdown",
"id": "11c7b853",
"metadata": {},
"source": [
"### Visualization -- before training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4791a426",
"metadata": {},
"outputs": [],
"source": [
"visualize_training_data(samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2adde55",
"metadata": {},
"outputs": [],
"source": [
"history = train_dist_routine(trainable_distribution=trainable_dist,\n",
" n_epochs=600,\n",
" n_disp=200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8e8ea4d",
"metadata": {},
"outputs": [],
"source": [
"train_losses = history.history['loss']\n",
"valid_losses = history.history['val_loss']\n",
"\n",
"plt.plot(train_losses, label='train')\n",
"plt.plot(valid_losses, label='valid')\n",
"plt.legend()\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Negative log likelihood\")\n",
"plt.title(\"Training and validation loss curves\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bc43dae",
"metadata": {},
"outputs": [],
"source": [
"names, samples = make_samples()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41c676d2",
"metadata": {},
"outputs": [],
"source": [
"_plot(samples, rows=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffe5a97d",
"metadata": {},
"outputs": [],
"source": [
"visualize_training_data(samples)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment