Commit 5d08ddd0 authored by Kaan Güney Keklikçi's avatar Kaan Güney Keklikçi

add radial flow class and training scripts

parent 4b8f228e
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import os
import matplotlib.pyplot as plt
plt.style.use('seaborn')
tfd = tfp.distributions
tfb = tfp.bijectors
# experimental -- inspired by
# "Variational Inference with Normalizing Flows", Rezende & Mohamed (2015)
# do not use for this setting: it distorts the base distribution heavily and does not converge
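# for reference, the radial flow from the paper:
#   f(x) = x + beta * h(alpha, r) * (x - x0), with r = |x - x0| and h(alpha, r) = 1 / (alpha + r)
# its Jacobian determinant in d dimensions is
#   det = (1 + beta * h)^(d-1) * (1 + beta * h + beta * h' * r)
# and invertibility requires beta >= -alpha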
# softplus, log(1 + exp(b)); the paper keeps beta unconstrained by
# reparametrizing it as -alpha + softplus(beta), which enforces beta >= -alpha
def parametrize(b):
    b = tf.cast(b, tf.float32)
    return tf.math.log1p(tf.exp(b)).numpy()
# h(alpha, r) = 1 / (alpha + r) and its derivative with respect to r
def h(a, r):
    return 1 / (a + r)

def h_prime(a, r):
    return -1 / (a + r)**2
class RadialFlow(tfb.Bijector):
    def __init__(self, a, b, x0, validate_args=True, name='radial-flow'):
        self.a = tf.cast(a, tf.float32)
        self.b = tf.cast(b, tf.float32)
        self.x0 = tf.cast(x0, tf.float32)
        super(RadialFlow, self).__init__(validate_args=validate_args,
                                         forward_min_event_ndims=0,
                                         name=name)
        if validate_args:
            # invertibility condition from the paper: beta >= -alpha
            assert tf.reduce_all(tf.math.greater_equal(self.b, -self.a))
    def _forward(self, x):
        # f(x) = x + beta * h(alpha, r) * (x - x0), applied elementwise
        # (abs rather than the Euclidean norm of the paper)
        r = tf.abs(x - self.x0)
        zhat = (x - self.x0) / r
        y = self.x0 + r * zhat + r * zhat * self.b * h(self.a, r)
        return y
    def _inverse(self, y):
        # closed-form inverse: with s = |y - x0|, r = |x - x0| solves
        # r**2 + (alpha + beta - s) * r - alpha * s = 0,
        # then x = x0 + (y - x0) * r / s
        s = tf.abs(y - self.x0)
        c = self.a + self.b - s
        r = 0.5 * (-c + tf.sqrt(c**2 + 4. * self.a * s))
        return self.x0 + (y - self.x0) * r / s
    def _forward_log_det_jacobian(self, x):
        try:
            n_dims = x.shape[1]
        except IndexError:
            raise RuntimeError('Input is one dimensional!')
        r = tf.abs(x - self.x0)
        dh = h_prime(self.a, r)
        hh = h(self.a, r)
        # log|det J| = (d - 1) * log(1 + beta*h) + log(1 + beta*h + beta*h'*r)
        return (n_dims - 1) * tf.math.log(1. + self.b * hh) + \
               tf.math.log(1. + self.b * hh + self.b * dh * r)
    def _inverse_log_det_jacobian(self, y):
        return -self._forward_log_det_jacobian(self._inverse(y))
\ No newline at end of file
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
plt.style.use('seaborn-paper')
from data_loader import load_data
from data_preprocesser import preprocess_data
from radial import RadialFlow
def main():
    tfd = tfp.distributions
    tfb = tfp.bijectors
    """ load data """
    filename = 'prostate.xls'
    directory = '/Users/kaanguney.keklikci/Data/'
    loader = load_data(filename, directory)
    loader.create_directory(directory)
    data = loader.read_data(directory, filename)
    print('Data successfully loaded...\n')
    """ preprocess data """
    fillna_vals = ['sz', 'sg', 'wt']
    dropna_vals = ['ekg', 'age']
    drop_vals = ['patno', 'sdate']
    preprocesser = preprocess_data(StandardScaler(), fillna_vals, dropna_vals, drop_vals)
    data = preprocesser.dropna_features(data)
    data = preprocesser.impute(data)
    data = preprocesser.drop_features(data)
    data = preprocesser.encode_categorical(data)
    data = preprocesser.scale(data)
    print('Data successfully preprocessed...\n')
""" define the base distributon as bivariate gaussian """
base_dist = tfd.Independent(tfd.Normal(loc=[0., 0.], scale=[1.,1.]),
reinterpreted_batch_ndims=1)
""" instantiate the bijector (a,b,x0) """
n = 1000
a = 2.
b = -1.99
x0 = np.array([0.,1.]).astype(np.float32).reshape(-1, 2)
bijector = RadialFlow(a, b, x0)
print(f'x0 shape: {x0.shape}')
x = np.linspace(start=-4, stop=4, num=n).astype(np.float32).reshape(-1, 2)
    plt.title('Forward transformation')
    plt.plot(x, bijector.forward(x))
    plt.show()
    plt.plot(x, bijector.inverse(x))
    plt.title('Inverse transformation')
    plt.show()
    plt.plot(x, bijector.forward_log_det_jacobian(x))
    plt.title('Forward log det Jacobian')
    plt.show()
    plt.plot(x, bijector.inverse_log_det_jacobian(x))
    plt.title('Inverse log det Jacobian')
    plt.show()
""" create transformed distribution """
tfd_dist = tfd.TransformedDistribution(distribution=base_dist,
bijector=bijector
)
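    # TransformedDistribution evaluates densities with the change-of-variables
    # formula, log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|,
    # so the bijector's _inverse and _inverse_log_det_jacobian drive
    # the log_prob calls in the training loop below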
    # before training
    plt.figure(figsize=(12, 5))
    plt.plot(tfd_dist.prob(x), label='Trainable')
    plt.plot(base_dist.prob(x), label='Base')
    plt.title('Trainable and base distribution')
    plt.legend()
    plt.show()
    # sample and batch -- train, validation
    x_train = base_dist.sample(10000)
    x_train = tf.data.Dataset.from_tensor_slices(x_train)
    x_train = x_train.batch(int(n/2))
    x_valid = base_dist.sample(1000)
    x_valid = tf.data.Dataset.from_tensor_slices(x_valid)
    x_valid = x_valid.batch(int(n/2))
    print(x_train.element_spec)
    print(x_valid.element_spec)
    print()
    # instantiate trainable bijector
    trainable_bijector = RadialFlow(tf.Variable(a, name='alpha'),
                                    tf.Variable(b, name='beta'),
                                    tf.Variable(x0, name='ref'))
    # instantiate trainable distribution
    trainable_dist = tfd.TransformedDistribution(tfd_dist,
                                                 trainable_bijector)
    # train the bijector
    num_epochs = 10
    opt = tf.keras.optimizers.Adam()
    train_losses = []
    valid_losses = []
    norm = 1e3
    for epoch in range(num_epochs):
        print("Epoch {}...".format(epoch))
        train_loss = tf.keras.metrics.Mean()
        val_loss = tf.keras.metrics.Mean()
        for train_batch in x_train:
            with tf.GradientTape() as tape:
                tape.watch(trainable_bijector.trainable_variables)
                loss = -trainable_dist.log_prob(train_batch)
            train_loss(loss)
            grads = tape.gradient(loss, trainable_bijector.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, norm)
            # alternative: grads = tf.reshape(tf.nn.softmax(grads[-1], axis=1), [2])
            # note that both alternatives behave almost identically
            opt.apply_gradients(zip(grads, trainable_bijector.trainable_variables))
        train_losses.append(train_loss.result().numpy())
        # validation
        for valid_batch in x_valid:
            loss = -trainable_dist.log_prob(valid_batch)
            val_loss(loss)
        valid_losses.append(val_loss.result().numpy())
    # plot the learning curves
    plt.plot(train_losses, label='train')
    plt.plot(valid_losses, label='valid')
    plt.legend()
    plt.xlabel("Epochs")
    plt.ylabel("Negative log likelihood")
    plt.title("Training and validation loss curves")
    plt.show()
    # plot the data and learned distributions
    plt.figure(figsize=(12, 5))
    plt.plot(trainable_dist.prob(x), label='Learned')
    plt.plot(base_dist.prob(x), label='Data')
    plt.legend()
    plt.show()
    ### best result obtained with tuning is displayed above
    ### radial flows converge fast; more epochs overfit
    ### DO NOT CHANGE validate_args=True
    ### the dataset is not used yet -- base samples are kept for visualization purposes in 3D
    ### the dataset will be integrated after the distribution is learned
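    ### sketch of that planned integration (hypothetical column slice; the
    ### preprocessed `data` is assumed to be a scaled numpy array):
    ###   features = data[:, :2].astype(np.float32)
    ###   dataset = tf.data.Dataset.from_tensor_slices(features).batch(int(n/2))
    ###   then iterate over `dataset` in place of x_train in the loop above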
if __name__ == '__main__':
    main()