Commit 547dae16 authored by Kaan Güney Keklikçi

planar flow with 2 bijector layers

parent dd6b65e1
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tf.compat.v1.disable_eager_execution()
import os
import matplotlib.pyplot as plt
plt.style.use('seaborn')
tfd = tfp.distributions
tfb = tfp.bijectors
class Planar(tfb.Bijector, tf.Module):
    """Planar flow bijector built on ``f(z) = z + u * tanh(w . z + b)``.

    The same planar map is exposed either as the bijector's forward direction
    (``case='sampling'``) or as its inverse direction
    (``case='density_estimation'``); the opposite direction has no closed form
    and raises ``NotImplementedError``.

    Trainable parameters:
        u: vector of length ``input_dimensions``.
        w: vector of length ``input_dimensions``.
        b: vector of length 1.
    """

    def __init__(self, input_dimensions, case='density_estimation', validate_args=False, name='planar_flow'):
        """Create the layer and its randomly initialised parameters.

        Args:
            input_dimensions: size of the event (last) axis; cast with ``int``.
            case: ``'density_estimation'`` or ``'sampling'``.
            validate_args: forwarded to ``tfb.Bijector``.
            name: forwarded to ``tfb.Bijector``.

        Raises:
            ValueError: if ``case`` is not one of the two supported options.
        """
        # The previous check (`!= a or != b`) was always true and guarded an
        # assert with `except ValueError`, so bad values slipped through.
        if case not in ('density_estimation', 'sampling'):
            raise ValueError(
                'Case is not defined. Available options for case: density_estimation, sampling')

        super(Planar, self).__init__(
            forward_min_event_ndims=1,
            inverse_min_event_ndims=1,
            validate_args=validate_args,
            name=name)

        self.event_ndims = 1
        self.case = case

        size = int(input_dimensions)
        self.u = tf.Variable(np.random.uniform(-1., 1., size=size), name='u', dtype=tf.float32, trainable=True)
        self.w = tf.Variable(np.random.uniform(-1., 1., size=size), name='w', dtype=tf.float32, trainable=True)
        self.b = tf.Variable(np.random.uniform(-1., 1., size=1), name='b', dtype=tf.float32, trainable=True)

    def h(self, y):
        """Return the activation ``tanh(y)``."""
        return tf.math.tanh(y)

    def h_prime(self, y):
        """Return the derivative of the activation, ``1 - tanh(y)^2``."""
        return 1.0 - tf.math.tanh(y) ** 2.0

    def alpha(self):
        """Return ``m(w . u) - w . u`` with ``m(x) = -1 + softplus(x)``."""
        wu = tf.tensordot(self.w, self.u, 1)
        m = -1.0 + tf.nn.softplus(wu)
        return m - wu

    def _u(self):
        """Shift ``u`` along ``w`` when ``w . u <= -1`` and return the assign op.

        After the shift ``w . u`` equals ``m(w . u) > -1``, the invertibility
        condition for a tanh planar flow (Rezende & Mohamed 2015, App. A.1).
        This requires dividing by ``||w||^2``; the previous code divided by
        ``||w||``. The branch is expressed with ``tf.where`` because a Python
        ``if`` on a tensor is not allowed once eager execution is disabled.

        Returns:
            The result of ``self.u.assign_add``; it must be run or evaluated
            for the update to take effect in graph mode.
        """
        wu = tf.tensordot(self.w, self.u, 1)
        correction = self.alpha() * self.w / tf.reduce_sum(self.w ** 2.0)
        delta = tf.where(wu <= -1.0, correction, tf.zeros_like(correction))
        return self.u.assign_add(delta)

    def _forward_func(self, zk):
        """Apply the planar map ``zk + u * h(zk . w + b)``.

        ``tensordot(zk, w, 1)`` contracts the last axis of ``zk`` with ``w``;
        the order-0 tensordot with ``u`` is an outer product that restores
        the event axis.
        """
        inter_1 = self.h(tf.tensordot(zk, self.w, 1) + self.b)
        return tf.add(zk, tf.tensordot(inter_1, self.u, 0))

    def _forward(self, zk):
        """Planar map in the forward direction (``case='sampling'`` only)."""
        if self.case == 'sampling':
            return self._forward_func(zk)
        raise NotImplementedError('_forward is not implemented for density_estimation')

    def _inverse(self, zk):
        """Planar map in the inverse direction (``case='density_estimation'`` only)."""
        if self.case == 'density_estimation':
            return self._forward_func(zk)
        raise NotImplementedError('_inverse is not implemented for sampling')

    def _log_det_jacobian(self, zk):
        """Return ``log|1 + psi(zk) . u|`` with ``psi = h'(zk . w + b) * w``."""
        psi = tf.tensordot(self.h_prime(tf.tensordot(zk, self.w, 1) + self.b), self.w, 0)
        det = tf.math.abs(1.0 + tf.tensordot(psi, self.u, 1))
        return tf.math.log(det)

    def _forward_log_det_jacobian(self, zk):
        """Log-determinant of the forward map (``case='sampling'`` only).

        ``_forward`` applies the planar map itself, so its log-det-Jacobian is
        ``+log|1 + psi . u|``; the earlier negation gave the wrong sign.
        """
        if self.case == 'sampling':
            return self._log_det_jacobian(zk)
        raise NotImplementedError('_forward_log_det_jacobian is not implemented for density_estimation')

    def _inverse_log_det_jacobian(self, zk):
        """Log-determinant of the planar map evaluated at ``zk``.

        NOTE(review): this is the Jacobian of ``_inverse`` only when
        ``case='density_estimation'`` (where ``_inverse`` is the planar map);
        it is returned unconditionally, as before.
        """
        return self._log_det_jacobian(zk)
\ No newline at end of file
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
from data_loader import load_data
from data_preprocesser import preprocess_data
from planar import Planar
def train(session, loss, optimizer, steps=int(1e5), record_every=100, log_every=int(1e4)):
    """Run the optimisation loop and collect a sub-sampled loss history.

    Args:
        session: object exposing ``run(fetches)``; called once per step with
            ``[optimizer, loss]`` and expected to return a 2-item result whose
            second item is the loss value.
        loss: loss fetch passed to ``session.run``.
        optimizer: training-op fetch passed to ``session.run``.
        steps: number of optimisation steps to run.
        record_every: the loss is stored on steps divisible by this value.
        log_every: the loss is printed on steps divisible by this value;
            a falsy value (``0`` / ``None``) disables printing.

    Returns:
        List of the recorded loss values, in step order.

    Raises:
        ValueError: if ``record_every`` is smaller than 1.
    """
    if record_every < 1:
        raise ValueError('record_every must be a positive integer')

    recorded_losses = []
    for i in range(steps):
        _, loss_per_iteration = session.run([optimizer, loss])
        if i % record_every == 0:
            recorded_losses.append(loss_per_iteration)
        if log_every and i % log_every == 0:
            print(f'Iteration {i}: {loss_per_iteration}')
    return recorded_losses
def plot_results(recorded_losses):
    """Scatter the recorded losses and overlay a least-squares trend line.

    Args:
        recorded_losses: sequence of loss values; the x axis is the position
            of each value in this sequence.

    The linear fit is skipped when fewer than two points are available,
    because ``np.polyfit`` cannot fit a line to an empty or single-point
    series.
    """
    print('Displaying results...')
    plt.figure(figsize=(10, 5))
    x = np.arange(len(recorded_losses))
    y = recorded_losses
    plt.scatter(x, y, s=10, alpha=0.3)
    if len(x) >= 2:
        m, b = np.polyfit(x, y, 1)
        plt.plot(x, m * x + b, c="r")
    plt.title('Loss per 100 iteration')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.tight_layout()
    plt.show()
def create_tensor(data, batch_size):
    """Build the input pipeline and return its next-batch tensor.

    The data is cast to float32, then repeated, shuffled with a buffer as
    large as the number of rows, prefetched and batched, in that order.

    Args:
        data: array-like exposing ``astype`` and ``shape``.
        batch_size: number of rows per batch.

    Returns:
        The ``get_next()`` tensor of a one-shot iterator over the pipeline.
    """
    pipeline = (
        tf.data.Dataset.from_tensor_slices(data.astype(np.float32))
        .repeat()
        .shuffle(buffer_size=data.shape[0])
        .prefetch(2 * batch_size)
        .batch(batch_size)
    )
    return tf.compat.v1.data.make_one_shot_iterator(pipeline).get_next()
"""
If TensorFlow raises an error complaining about tf.float32,
try the following (one of them is probably enough):
** downgrade keras to 2.3.1
** replace tf.float32 with np.float32
"""
def check_version():
    """Print the installed TensorFlow, TensorFlow Probability and Keras versions."""
    print('Tensorflow version:', tf.__version__)
    print('Tensorflow-probability version:', tfp.__version__)
    # Trailing blank line separates the version banner from later output.
    print('Keras version:', tf.keras.__version__, end='\n\n')
def main():
    """Load and preprocess the prostate data, then fit a chain of planar flows.

    Steps: read the spreadsheet, clean and scale it, build a
    ``TransformedDistribution`` over a stack of ``Planar`` bijectors, minimise
    the mean negative log-likelihood with Adam, and plot the recorded losses.
    """
    # --- load data ---
    filename = 'prostate.xls'
    directory = '/Users/kaanguney.keklikci/Data/'
    loader = load_data(filename, directory)
    loader.create_directory(directory)
    data = loader.read_data(directory, filename)
    print('Data successfully loaded...\n')

    # --- preprocess data ---
    fillna_vals = ['sz', 'sg', 'wt']
    dropna_vals = ['ekg', 'age']
    drop_vals = ['patno', 'sdate']
    preprocesser = preprocess_data(StandardScaler(), fillna_vals, dropna_vals, drop_vals)
    data = preprocesser.dropna_features(data)
    data = preprocesser.impute(data)
    data = preprocesser.drop_features(data)
    data = preprocesser.encode_categorical(data)
    data = preprocesser.scale(data)
    print('Data successfully preprocessed...\n')

    # --- set Planar parameters ---
    tfd = tfp.distributions
    tfb = tfp.bijectors
    batch_size = 32
    dtype = np.float32
    layers = 2
    dims = data.shape[1]
    learning_rate = 1e-4

    # The graph-level seed only influences ops created after it is set, so it
    # has to come before the input pipeline and the model are built.
    tf.compat.v1.set_random_seed(42)

    # multivariate normal for base distribution
    base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros(shape=dims, dtype=dtype))

    # --- initialize samples ---
    samples = create_tensor(data, batch_size)

    # --- make Planar ---
    # The imported bijector class is `Planar`; `PlanarFlow` was an undefined name.
    bijectors = [
        Planar(input_dimensions=dims, case='density_estimation')
        for _ in range(layers)
    ]
    bijector = tfb.Chain(bijectors=list(reversed(bijectors)), name='chain_of_planar')
    planar_flow = tfd.TransformedDistribution(
        distribution=base_dist,
        bijector=bijector
    )

    loss = -tf.reduce_mean(planar_flow.log_prob(samples))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(loss)

    # The context manager closes the session even if training raises.
    with tf.compat.v1.Session() as session:
        session.run(tf.compat.v1.global_variables_initializer())
        print('Optimizer and loss successfully defined...\n')

        # --- start training ---
        recorded_losses = train(session, loss, optimizer)
    print('Training finished...\n')

    # --- display results ---
    plot_results(recorded_losses)
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment