Commit dd0d7bab authored by Kaan Güney Keklikçi's avatar Kaan Güney Keklikçi

corrected shift function

parent 323654ae
...@@ -68,7 +68,7 @@ class MAF(object): ...@@ -68,7 +68,7 @@ class MAF(object):
def make_maf(self, data): def make_maf(self, data):
distribution = self.base_dist distribution = self.base_dist
sample_shape = self.get_dims(data) sample_shape = self.get_dims(data)
shift_scale_function = self.get_shift_scale_func(data) shift_scale_function = self.get_shift_scale_func()
bijector = tfb.MaskedAutoregressiveFlow(shift_scale_function) bijector = tfb.MaskedAutoregressiveFlow(shift_scale_function)
maf = tfd.TransformedDistribution(tfd.Sample(distribution, sample_shape), bijector) maf = tfd.TransformedDistribution(tfd.Sample(distribution, sample_shape), bijector)
return maf return maf
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment