Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
B
beta-vae-normalizing-flows
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Kaan Güney Keklikçi
beta-vae-normalizing-flows
Commits
d308c6cd
Commit
d308c6cd
authored
Aug 22, 2021
by
Kaan Güney Keklikçi
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
adding invertable radial script
parent
c011962a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
6 deletions
+8
-6
radial_execute.py
scripts/flows/radial/radial_execute.py
+8
-6
No files found.
scripts/flows/radial/radial_execute.py
View file @
d308c6cd
...
...
@@ -41,15 +41,15 @@ def main():
""" define the base distributon as bivariate gaussian """
base_dist
=
tfd
.
Independent
(
tfd
.
Normal
(
loc
=
[
0.
,
0.
],
scale
=
[
1.
,
1.
]),
base_dist
=
tfd
.
Independent
(
tfd
.
Normal
(
loc
=
[
2.
,
-
0.5
],
scale
=
[
1.
,
1.
]),
reinterpreted_batch_ndims
=
1
)
""" instantiate the bijector (a,b,x0) """
n
=
1000
a
=
2
.
b
=
-
1
.99
x0
=
np
.
array
([
0.
,
1.
])
.
astype
(
np
.
float32
)
.
reshape
(
-
1
,
2
)
a
=
10
.
b
=
-
1
0.
x0
=
np
.
array
([
-
0.5
,
1.
])
.
astype
(
np
.
float32
)
.
reshape
(
-
1
,
2
)
bijector
=
RadialFlow
(
a
,
b
,
x0
)
print
(
f
'x0 shape: {x0.shape}'
)
...
...
@@ -72,7 +72,7 @@ def main():
""" create transformed distribution """
tfd_dist
=
tfd
.
TransformedDistribution
(
distribution
=
base_dist
,
bijector
=
bijector
bijector
=
bijector
)
# prior training
...
...
@@ -105,6 +105,7 @@ def main():
name
=
'beta'
),
tf
.
Variable
(
x0
,
name
=
'ref'
))
# instantiate trainable distribution
...
...
@@ -164,10 +165,11 @@ def main():
### DO NOT CHANGE validate_args=True
### DOES NOT USE DATASET YET
### FOR VISUALIZATION PURPOSES IN
3
D
### FOR VISUALIZATION PURPOSES IN
2
D
### WILL INTEGRATE DATASET AFTER LEARNING THE DISTRIBUTION
if
__name__
==
'__main__'
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment