-
Notifications
You must be signed in to change notification settings - Fork 3
/
math_.py
2159 lines (1742 loc) · 57.9 KB
/
math_.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
r"""
:math:`\kappa`-Stereographic math module.
The functions for the mathematics in gyrovector spaces are taken from the
following resources:
[1] Ganea, Octavian, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic
neural networks." Advances in neural information processing systems.
2018.
[2] Bachmann, Gregor, Gary Bécigneul, and Octavian-Eugen Ganea. "Constant
Curvature Graph Convolutional Networks." arXiv preprint
arXiv:1911.05076 (2019).
[3] Skopek, Ondrej, Octavian-Eugen Ganea, and Gary Bécigneul.
"Mixed-curvature Variational Autoencoders." arXiv preprint
arXiv:1911.08411 (2019).
[4] Ungar, Abraham A. Analytic hyperbolic geometry: Mathematical
foundations and applications. World Scientific, 2005.
[5] Ungar, Abraham Albert. Barycentric calculus in Euclidean and
hyperbolic geometry: A comparative introduction. World Scientific,
2010.
"""
import functools
from typing import List, Optional
import torch.jit
from ...utils import clamp_abs, drop_dims, list_range, sabs, sign
r"""
:math:`\kappa`-Stereographic math module.
The functions for the mathematics in gyrovector spaces are taken from the
following resources:
[1] Ganea, Octavian, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic
neural networks." Advances in neural information processing systems.
2018.
[2] Bachmann, Gregor, Gary Bécigneul, and Octavian-Eugen Ganea. "Constant
Curvature Graph Convolutional Networks." arXiv preprint
arXiv:1911.05076 (2019).
[3] Skopek, Ondrej, Octavian-Eugen Ganea, and Gary Bécigneul.
"Mixed-curvature Variational Autoencoders." arXiv preprint
arXiv:1911.08411 (2019).
[4] Ungar, Abraham A. Analytic hyperbolic geometry: Mathematical
foundations and applications. World Scientific, 2005.
[5] Ungar, Abraham Albert. Barycentric calculus in Euclidean and
hyperbolic geometry: A comparative introduction. World Scientific,
2010.
"""
@torch.jit.script
def tanh(x):
    """
    Hyperbolic tangent with the argument clamped to ``[-15, 15]`` first.

    Parameters
    ----------
    x : tensor
        input values

    Returns
    -------
    tensor
        ``tanh`` of the clamped input
    """
    bounded = torch.clamp(x, -15, 15)
    return torch.tanh(bounded)
@torch.jit.script
def artanh(x: torch.Tensor):
    """
    Inverse hyperbolic tangent computed as ``0.5 * (log(1 + x) - log(1 - x))``.

    The input is clamped to ``[-1 + 1e-7, 1 - 1e-7]`` beforehand so both
    logarithms receive a positive argument.

    Parameters
    ----------
    x : tensor
        input values

    Returns
    -------
    tensor
        ``artanh`` of the clamped input
    """
    clipped = torch.clamp(x, -1 + 1e-7, 1 - 1e-7)
    log_gap = torch.log(1 + clipped) - torch.log(1 - clipped)
    return log_gap * 0.5
@torch.jit.script
def arsinh(x: torch.Tensor):
    """
    Inverse hyperbolic sine computed as ``log(x + sqrt(1 + x^2))``.

    The logarithm argument is clamped from below at ``1e-15`` and the result
    is cast back to the dtype of ``x``.

    Parameters
    ----------
    x : tensor
        input values

    Returns
    -------
    tensor
        ``arsinh(x)`` with the same dtype as ``x``
    """
    root = torch.sqrt(1 + x.pow(2))
    log_arg = torch.clamp_min(x + root, 1e-15)
    return torch.log(log_arg).to(x.dtype)
@torch.jit.script
def tanh_half_arsinh(x: torch.Tensor):
    """
    Compute ``tanh(arsinh(x) / 2)`` in closed form.

    Uses the identity ``tanh(arsinh(x) / 2) = x / (1 + sqrt(1 + x^2))``.
    This is algebraically equal to ``(sqrt(1 + x^2) - 1) / x`` but has a
    denominator that is always at least 2, so it needs no clamping, is
    correct for negative inputs (the function is odd) and does not suffer
    from cancellation for small ``|x|``.

    Parameters
    ----------
    x : tensor
        input values (any sign)

    Returns
    -------
    tensor
        ``tanh(arsinh(x) / 2)``, elementwise
    """
    return x / (1 + (x.pow(2) + 1).sqrt())
@torch.jit.script
def abs_zero_grad(x):
    # Product of ``x`` with the project helper ``sign(x)``.
    # Original author's note: this op has derivative equal to 1 at zero.
    return torch.mul(x, sign(x))
@torch.jit.script
def tan_k_zero_taylor(x: torch.Tensor, k: torch.Tensor, order: int = -1):
    """
    Truncated power series in ``k`` used for ``tan_k`` near zero curvature.

    The polynomial is
    ``x + k x^3/3 + 2 k^2 x^5/15 + 17 k^3 x^7/315 + 62 k^4 x^9/2835
    + 1382 k^5 x^11/155925`` cut after the ``k^order`` term. ``k`` is passed
    through ``abs_zero_grad`` before use (skipped when ``order == 0``).

    Parameters
    ----------
    x : tensor
        argument
    k : tensor
        curvature
    order : int
        highest power of ``k`` to keep, 0..5; ``-1`` means 5

    Returns
    -------
    tensor
        value of the truncated series

    Raises
    ------
    RuntimeError
        if ``order`` is outside ``[-1, 5]``
    """
    if order == 0:
        return x
    k = abs_zero_grad(k)
    if order < -1 or order > 5:
        raise RuntimeError("order not in [-1, 5]")
    # -1 is an alias for the highest implemented order
    last = 5 if order == -1 else order
    # terms are added left to right, exactly as in the written-out sums
    acc = x + 1 / 3 * k * x**3
    if last >= 2:
        acc = acc + 2 / 15 * k**2 * x**5
    if last >= 3:
        acc = acc + 17 / 315 * k**3 * x**7
    if last >= 4:
        acc = acc + 62 / 2835 * k**4 * x**9
    if last >= 5:
        acc = acc + 1382 / 155925 * k**5 * x**11
    return acc
@torch.jit.script
def artan_k_zero_taylor(x: torch.Tensor, k: torch.Tensor, order: int = -1):
    """
    Truncated power series in ``k`` used for ``artan_k`` near zero curvature.

    The polynomial is
    ``x - k x^3/3 + k^2 x^5/5 - k^3 x^7/7 + k^4 x^9/9 - k^5 x^11/11`` cut
    after the ``k^order`` term. ``k`` is passed through ``abs_zero_grad``
    before use (skipped when ``order == 0``).

    Parameters
    ----------
    x : tensor
        argument
    k : tensor
        curvature
    order : int
        highest power of ``k`` to keep, 0..5; ``-1`` means 5

    Returns
    -------
    tensor
        value of the truncated series

    Raises
    ------
    RuntimeError
        if ``order`` is outside ``[-1, 5]``
    """
    if order == 0:
        return x
    k = abs_zero_grad(k)
    if order < -1 or order > 5:
        raise RuntimeError("order not in [-1, 5]")
    # -1 is an alias for the highest implemented order
    last = 5 if order == -1 else order
    # alternating-sign terms, accumulated left to right
    acc = x - 1 / 3 * k * x**3
    if last >= 2:
        acc = acc + 1 / 5 * k**2 * x**5
    if last >= 3:
        acc = acc - 1 / 7 * k**3 * x**7
    if last >= 4:
        acc = acc + 1 / 9 * k**4 * x**9
    if last >= 5:
        acc = acc - 1 / 11 * k**5 * x**11
    return acc
@torch.jit.script
def arsin_k_zero_taylor(x: torch.Tensor, k: torch.Tensor, order: int = -1):
    """
    Truncated power series in ``k`` used for ``arsin_k`` near zero curvature.

    The polynomial is
    ``x + k x^3/6 + 3 k^2 x^5/40 + 5 k^3 x^7/112 + 35 k^4 x^9/1152
    + 63 k^5 x^11/2816`` cut after the ``k^order`` term. ``k`` is passed
    through ``abs_zero_grad`` before use (skipped when ``order == 0``).

    Parameters
    ----------
    x : tensor
        argument
    k : tensor
        curvature
    order : int
        highest power of ``k`` to keep, 0..5; ``-1`` means 5

    Returns
    -------
    tensor
        value of the truncated series

    Raises
    ------
    RuntimeError
        if ``order`` is outside ``[-1, 5]``
    """
    if order == 0:
        return x
    k = abs_zero_grad(k)
    if order < -1 or order > 5:
        raise RuntimeError("order not in [-1, 5]")
    # -1 is an alias for the highest implemented order
    last = 5 if order == -1 else order
    acc = x + k * x**3 / 6
    if last >= 2:
        acc = acc + 3 / 40 * k**2 * x**5
    if last >= 3:
        acc = acc + 5 / 112 * k**3 * x**7
    if last >= 4:
        acc = acc + 35 / 1152 * k**4 * x**9
    if last >= 5:
        acc = acc + 63 / 2816 * k**5 * x**11
    return acc
@torch.jit.script
def sin_k_zero_taylor(x: torch.Tensor, k: torch.Tensor, order: int = -1):
    """
    Truncated power series in ``k`` used for ``sin_k`` near zero curvature.

    The polynomial is
    ``x - k x^3/6 + k^2 x^5/120 - k^3 x^7/5040 + k^4 x^9/362880
    - k^5 x^11/39916800`` cut after the ``k^order`` term. ``k`` is passed
    through ``abs_zero_grad`` before use (skipped when ``order == 0``).

    Parameters
    ----------
    x : tensor
        argument
    k : tensor
        curvature
    order : int
        highest power of ``k`` to keep, 0..5; ``-1`` means 5

    Returns
    -------
    tensor
        value of the truncated series

    Raises
    ------
    RuntimeError
        if ``order`` is outside ``[-1, 5]``
    """
    if order == 0:
        return x
    k = abs_zero_grad(k)
    if order < -1 or order > 5:
        raise RuntimeError("order not in [-1, 5]")
    # -1 is an alias for the highest implemented order
    last = 5 if order == -1 else order
    # alternating-sign terms, accumulated left to right
    acc = x - k * x**3 / 6
    if last >= 2:
        acc = acc + k**2 * x**5 / 120
    if last >= 3:
        acc = acc - k**3 * x**7 / 5040
    if last >= 4:
        acc = acc + k**4 * x**9 / 362880
    if last >= 5:
        acc = acc - k**5 * x**11 / 39916800
    return acc
@torch.jit.script
def tan_k(x: torch.Tensor, k: torch.Tensor):
    """
    Curvature-dependent tangent.

    With ``s = sqrt(sabs(k))`` the result is ``tan(s x) / s`` where ``k`` is
    positive, ``tanh(s x) / s`` where ``k`` is negative, and the first-order
    series ``tan_k_zero_taylor(x, k, order=1)`` where ``k`` is close to zero
    (``isclose`` with default tolerances). ``sabs`` is a project helper,
    presumably a safe absolute value - confirm in ``utils``.

    Parameters
    ----------
    x : tensor
        argument
    k : tensor
        sectional curvature; may mix signs elementwise

    Returns
    -------
    tensor
        ``tan_k(x)``
    """
    zero = torch.zeros((), device=k.device, dtype=k.dtype)
    near_zero = torch.isclose(k, zero)
    # curvatures that are numerically zero are treated as having no sign
    signs = k.sign()
    signs = signs.masked_fill(near_zero, zero.to(signs.dtype))
    if torch.all(near_zero):
        return tan_k_zero_taylor(x, k, order=1)
    root = sabs(k).sqrt()
    arg = x * root
    inv_root = root.reciprocal()
    positive = signs.gt(0)
    # fast paths: every curvature shares one sign
    if torch.all(signs.lt(0)):
        return inv_root * tanh(arg)
    if torch.all(positive):
        return inv_root * arg.clamp_max(1e38).tan()
    # mixed signs: choose per element, then patch in the near-zero entries
    curved = torch.where(positive, arg.clamp_max(1e38).tan(), tanh(arg)) * inv_root
    return torch.where(near_zero, tan_k_zero_taylor(x, k, order=1), curved)
@torch.jit.script
def artan_k(x: torch.Tensor, k: torch.Tensor):
    """
    Curvature-dependent inverse tangent.

    With ``s = sqrt(sabs(k))`` the result is ``atan(s x) / s`` where ``k`` is
    positive, ``artanh(s x) / s`` where ``k`` is negative, and the first-order
    series ``artan_k_zero_taylor(x, k, order=1)`` where ``k`` is close to zero
    (``isclose`` with default tolerances). ``sabs`` is a project helper,
    presumably a safe absolute value - confirm in ``utils``.

    Parameters
    ----------
    x : tensor
        argument
    k : tensor
        sectional curvature; may mix signs elementwise

    Returns
    -------
    tensor
        ``artan_k(x)``
    """
    zero = torch.zeros((), device=k.device, dtype=k.dtype)
    near_zero = torch.isclose(k, zero)
    # curvatures that are numerically zero are treated as having no sign
    signs = k.sign()
    signs = signs.masked_fill(near_zero, zero.to(signs.dtype))
    if torch.all(near_zero):
        return artan_k_zero_taylor(x, k, order=1)
    root = sabs(k).sqrt()
    arg = x * root
    inv_root = root.reciprocal()
    positive = signs.gt(0)
    # fast paths: every curvature shares one sign
    if torch.all(signs.lt(0)):
        return inv_root * artanh(arg)
    if torch.all(positive):
        return inv_root * arg.atan()
    # mixed signs: choose per element, then patch in the near-zero entries
    curved = torch.where(positive, arg.atan(), artanh(arg)) * inv_root
    return torch.where(near_zero, artan_k_zero_taylor(x, k, order=1), curved)
@torch.jit.script
def arsin_k(x: torch.Tensor, k: torch.Tensor):
    """
    Curvature-dependent inverse sine.

    With ``s = sqrt(sabs(k))`` the result is ``asin(s x) / s`` where ``k`` is
    positive, ``arsinh(s x) / s`` where ``k`` is negative, and a Taylor series
    from ``arsin_k_zero_taylor`` where ``k`` is close to zero (``isclose``
    with default tolerances). ``sabs`` is a project helper, presumably a safe
    absolute value - confirm in ``utils``.

    Parameters
    ----------
    x : tensor
        argument
    k : tensor
        sectional curvature; may mix signs elementwise

    Returns
    -------
    tensor
        ``arsin_k(x)``
    """
    zero = torch.zeros((), device=k.device, dtype=k.dtype)
    near_zero = torch.isclose(k, zero)
    # curvatures that are numerically zero are treated as having no sign
    signs = k.sign()
    signs = signs.masked_fill(near_zero, zero.to(signs.dtype))
    if torch.all(near_zero):
        # NOTE(review): default (5th) order here, but order=1 in the mixed
        # branch below - confirm the difference is intended
        return arsin_k_zero_taylor(x, k)
    root = sabs(k).sqrt()
    arg = x * root
    inv_root = root.reciprocal()
    positive = signs.gt(0)
    # fast paths: every curvature shares one sign
    if torch.all(signs.lt(0)):
        return inv_root * arsinh(arg)
    if torch.all(positive):
        # NOTE(review): no clamp here, unlike the mixed branch below
        return inv_root * arg.asin()
    # mixed signs: choose per element; asin input is clamped into (-1, 1)
    curved = (
        torch.where(
            positive,
            arg.clamp(-1 + 1e-7, 1 - 1e-7).asin(),
            arsinh(arg),
        )
        * inv_root
    )
    return torch.where(near_zero, arsin_k_zero_taylor(x, k, order=1), curved)
@torch.jit.script
def sin_k(x: torch.Tensor, k: torch.Tensor):
    """
    Curvature-dependent sine.

    With ``s = sqrt(sabs(k))`` the result is ``sin(s x) / s`` where ``k`` is
    positive, ``sinh(s x) / s`` where ``k`` is negative, and a Taylor series
    from ``sin_k_zero_taylor`` where ``k`` is close to zero (``isclose`` with
    default tolerances). ``sabs`` is a project helper, presumably a safe
    absolute value - confirm in ``utils``.

    Parameters
    ----------
    x : tensor
        argument
    k : tensor
        sectional curvature; may mix signs elementwise

    Returns
    -------
    tensor
        ``sin_k(x)``
    """
    zero = torch.zeros((), device=k.device, dtype=k.dtype)
    near_zero = torch.isclose(k, zero)
    # curvatures that are numerically zero are treated as having no sign
    signs = k.sign()
    signs = signs.masked_fill(near_zero, zero.to(signs.dtype))
    if torch.all(near_zero):
        # NOTE(review): default (5th) order here, but order=1 in the mixed
        # branch below - confirm the difference is intended
        return sin_k_zero_taylor(x, k)
    root = sabs(k).sqrt()
    arg = x * root
    inv_root = root.reciprocal()
    positive = signs.gt(0)
    # fast paths: every curvature shares one sign
    if torch.all(signs.lt(0)):
        return inv_root * torch.sinh(arg)
    if torch.all(positive):
        return inv_root * arg.sin()
    # mixed signs: choose per element, then patch in the near-zero entries
    curved = torch.where(positive, arg.sin(), torch.sinh(arg)) * inv_root
    return torch.where(near_zero, sin_k_zero_taylor(x, k, order=1), curved)
def project(x: torch.Tensor, *, k: torch.Tensor, dim=-1, eps=-1):
    r"""
    Pull points back inside the manifold for numerical stability.

    Thin keyword-only front end for the scripted :func:`_project`.

    Parameters
    ----------
    x : tensor
        point on the Poincare ball
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension used to compute the norm
    eps : float
        stability margin; a negative value selects a dtype-dependent default

    Returns
    -------
    tensor
        projected vector on the manifold
    """
    projected = _project(x, k, dim=dim, eps=eps)
    return projected
@torch.jit.script
def _project(x, k, dim: int = -1, eps: float = -1.0):
    # A negative eps selects the default margin: 4e-3 for float32 inputs,
    # 1e-5 for every other dtype.
    if eps < 0:
        eps = 4e-3 if x.dtype == torch.float32 else 1e-5
    # Largest admissible norm; only enforced where k < 0, elsewhere the bound
    # is set to 1e15.
    radius = (1 - eps) / (sabs(k) ** 0.5)
    radius = torch.where(k < 0, radius, k.new_full((), 1e15))
    # Norm is clamped away from zero so the division below is safe.
    length = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
    rescaled = x / length * radius
    # Keep x where it is already within the bound, otherwise use the rescaled point.
    return torch.where(length > radius, rescaled, x)
def lambda_x(x: torch.Tensor, *, k: torch.Tensor, keepdim=False, dim=-1):
    r"""
    Compute the conformal factor :math:`\lambda^\kappa_x` for a point on the ball.

    .. math::

        \lambda^\kappa_x = \frac{2}{1 + \kappa \|x\|_2^2}

    The denominator is clamped from below at ``1e-15`` by the implementation.

    Parameters
    ----------
    x : tensor
        point on the Poincare ball
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the reduced dim? (default: false)
    dim : int
        reduction dimension

    Returns
    -------
    tensor
        conformal factor
    """
    return _lambda_x(x, k, keepdim, dim)
@torch.jit.script
def _lambda_x(x: torch.Tensor, k: torch.Tensor, keepdim: bool = False, dim: int = -1):
    # Conformal factor 2 / (1 + k * ||x||^2); the denominator is clamped at
    # 1e-15 from below before dividing.
    sq_norm = x.pow(2).sum(dim=dim, keepdim=keepdim)
    denom = (1 + k * sq_norm).clamp_min(1e-15)
    return 2 / denom
def inner(
    x: torch.Tensor, u: torch.Tensor, v: torch.Tensor, *, k, keepdim=False, dim=-1
):
    r"""
    Compute the Riemannian inner product of two tangent vectors at :math:`x`.

    .. math::

        \langle u, v\rangle_x = (\lambda^\kappa_x)^2 \langle u, v \rangle

    Parameters
    ----------
    x : tensor
        point on the Poincare ball
    u : tensor
        tangent vector to :math:`x` on Poincare ball
    v : tensor
        tangent vector to :math:`x` on Poincare ball
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the reduced dim? (default: false)
    dim : int
        reduction dimension

    Returns
    -------
    tensor
        inner product
    """
    return _inner(x, u, v, k, keepdim, dim)
@torch.jit.script
def _inner(
    x: torch.Tensor,
    u: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    keepdim: bool = False,
    dim: int = -1,
):
    # Riemannian inner product: (lambda_x)^2 * <u, v>.
    # The conformal factor must be reduced with the same ``keepdim`` as the
    # dot product (as ``_norm`` does). Reducing it with keepdim=True while the
    # dot product drops the axis makes a (n, 1) factor broadcast against a
    # (n,) dot product into a (n, n) result.
    conformal = _lambda_x(x, k, keepdim=keepdim, dim=dim)
    euclidean = (u * v).sum(dim=dim, keepdim=keepdim)
    return conformal**2 * euclidean
def norm(x: torch.Tensor, u: torch.Tensor, *, k: torch.Tensor, keepdim=False, dim=-1):
    r"""
    Compute the Riemannian norm of a tangent vector at :math:`x`.

    .. math::

        \|u\|_x = \lambda^\kappa_x \|u\|_2

    Parameters
    ----------
    x : tensor
        point on the Poincare ball
    u : tensor
        tangent vector to :math:`x` on Poincare ball
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the reduced dim? (default: false)
    dim : int
        reduction dimension

    Returns
    -------
    tensor
        norm of vector
    """
    return _norm(x, u, k, keepdim, dim)
@torch.jit.script
def _norm(
    x: torch.Tensor,
    u: torch.Tensor,
    k: torch.Tensor,
    keepdim: bool = False,
    dim: int = -1,
):
    # Riemannian norm: conformal factor at x times the Euclidean 2-norm of u,
    # both reduced over ``dim`` with the same ``keepdim``.
    scale = _lambda_x(x, k, keepdim=keepdim, dim=dim)
    length = u.norm(dim=dim, keepdim=keepdim, p=2)
    return scale * length
def mobius_add(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius gyrovector addition.

    .. math::

        x \oplus_\kappa y =
        \frac{
            (1 - 2 \kappa \langle x, y\rangle - \kappa \|y\|^2_2) x +
            (1 + \kappa \|x\|_2^2) y
        }{
            1 - 2 \kappa \langle x, y\rangle + \kappa^2 \|x\|^2_2 \|y\|^2_2
        }

    .. plot:: plots/extended/stereographic/mobius_add.py

    The operation is not commutative in general:

    .. math::

        x \oplus_\kappa y \ne y \oplus_\kappa x

    Commutativity does hold in these cases:

    * one operand is the zero vector

    .. math::

        \mathbf{0} \oplus_\kappa x = x \oplus_\kappa \mathbf{0}

    * zero curvature, where the operation is Euclidean addition

    .. math::

        x \oplus_0 y = y \oplus_0 x

    The left-cancellation law also holds:

    .. math::

        (-x) \oplus_\kappa (x \oplus_\kappa y) = y

    Parameters
    ----------
    x : tensor
        point on the manifold
    y : tensor
        point on the manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius addition
    """
    return _mobius_add(x, y, k, dim)
@torch.jit.script
def _mobius_add(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    # Squared norms and dot product, reduced over ``dim`` and kept broadcastable.
    dot_xy = (x * y).sum(dim=dim, keepdim=True)
    sq_x = x.pow(2).sum(dim=dim, keepdim=True)
    sq_y = y.pow(2).sum(dim=dim, keepdim=True)
    weight_x = 1 - 2 * k * dot_xy - k * sq_y
    weight_y = 1 + k * sq_x
    numerator = weight_x * x + weight_y * y
    denominator = 1 - 2 * k * dot_xy + k**2 * sq_x * sq_y
    # Why the clamp is enough (k omitted to simplify the notation): setting
    # the gradients of the denominator to zero,
    #   y + x * <y, y> = 0  and  x + y * <x, x> = 0,
    # gives x = -y / <y, y> and y = -x / <x, x>, where the denominator equals
    #   1 - 2 <y, y>/<y, y> + <y, y>/<y, y> = 0,
    # so its minimum is zero and only a lower clamp is needed.
    return numerator / denominator.clamp_min(1e-15)
def mobius_sub(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius gyrovector subtraction.

    Subtraction is Möbius addition of the negated second operand:

    .. math::

        x \ominus_\kappa y = x \oplus_\kappa (-y)

    Parameters
    ----------
    x : tensor
        point on manifold
    y : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius subtraction
    """
    return _mobius_sub(x, y, k, dim)
def _mobius_sub(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    # x (-) y is defined as x (+) (-y)
    negated = torch.neg(y)
    return _mobius_add(x, negated, k, dim=dim)
def gyration(
    a: torch.Tensor, b: torch.Tensor, u: torch.Tensor, *, k: torch.Tensor, dim=-1
):
    r"""
    Compute the gyration of :math:`u` by :math:`[a,b]`.

    Gyration is an operation specific to gyrovector spaces. Möbius addition
    :math:`\oplus_\kappa` is not associative (see :func:`mobius_add`), but it
    is gyroassociative:

    .. math::

        u \oplus_\kappa (v \oplus_\kappa w)
        =
        (u\oplus_\kappa v) \oplus_\kappa \operatorname{gyr}[u, v]w,

    where

    .. math::

        \operatorname{gyr}[u, v]w
        =
        \ominus (u \oplus_\kappa v) \oplus (u \oplus_\kappa (v \oplus_\kappa w))

    Substituting the explicit formula for Möbius addition [1] gives the
    closed form that is actually evaluated:

    .. math::

        A = - \kappa^2 \langle u, w\rangle \langle v, v\rangle
            - \kappa \langle v, w\rangle
            + 2 \kappa^2 \langle u, v\rangle \langle v, w\rangle\\
        B = - \kappa^2 \langle v, w\rangle \langle u, u\rangle
            + \kappa \langle u, w\rangle\\
        D = 1 - 2 \kappa \langle u, v\rangle
            + \kappa^2 \langle u, u\rangle \langle v, v\rangle\\

        \operatorname{gyr}[u, v]w = w + 2 \frac{A u + B v}{D}.

    The arguments ``a``, ``b``, ``u`` of this function play the roles of
    :math:`u`, :math:`v`, :math:`w` in the formulas above.

    Parameters
    ----------
    a : tensor
        first point on manifold
    b : tensor
        second point on manifold
    u : tensor
        vector field for operation
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of automorphism

    References
    ----------
    [1] A. A. Ungar (2009), A Gyrovector Space Approach to Hyperbolic Geometry
    """
    return _gyration(a, b, u, k, dim)
@torch.jit.script
def _gyration(
    u: torch.Tensor, v: torch.Tensor, w: torch.Tensor, k: torch.Tensor, dim: int = -1
):
    # Closed form of gyr[u, v]w = -(u (+) v) (+) (u (+) (v (+) w)), which
    # avoids three chained Möbius additions: w + 2 * (A u + B v) / D.
    k_sq = k**2
    dot_uv = (u * v).sum(dim=dim, keepdim=True)
    dot_uw = (u * w).sum(dim=dim, keepdim=True)
    dot_vw = (v * w).sum(dim=dim, keepdim=True)
    sq_u = u.pow(2).sum(dim=dim, keepdim=True)
    sq_v = v.pow(2).sum(dim=dim, keepdim=True)
    coef_u = -k_sq * dot_uw * sq_v - k * dot_vw + 2 * k_sq * dot_uv * dot_vw
    coef_v = -k_sq * dot_vw * sq_u + k * dot_uw
    # same denominator as in Möbius addition, clamped away from zero
    denom = (1 - 2 * k * dot_uv + k_sq * sq_u * sq_v).clamp_min(1e-15)
    return w + 2 * (coef_u * u + coef_v * v) / denom
def mobius_coadd(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius gyrovector coaddition.

    The addition :math:`\oplus_\kappa` is neither associative nor
    commutative. The coaddition :math:`\boxplus_\kappa` (or cooperation), in
    contrast, is commutative. It is defined as

    .. math::

        a \boxplus_\kappa b
        =
        b \boxplus_\kappa a
        =
        a \oplus_\kappa \operatorname{gyr}[a, -b]b\\
        = \frac{
            (1 + \kappa \|b\|^2_2) a + (1 + \kappa \|a\|_2^2) b
        }{
            1 - \kappa^2 \|a\|^2_2 \|b\|^2_2
        },

    where :math:`\operatorname{gyr}[a, b]v = \ominus_\kappa (a \oplus_\kappa b)
    \oplus_\kappa (a \oplus_\kappa (b \oplus_\kappa v))`

    The following right cancellation property holds

    .. math::

        (a \boxplus_\kappa b) \ominus_\kappa b = a\\
        (a \oplus_\kappa b) \boxminus_\kappa b = a

    Parameters
    ----------
    x : tensor
        point on manifold
    y : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius coaddition
    """
    return _mobius_coadd(x, y, k, dim)
# TODO: check numerical stability with Gregor's paper!!!
@torch.jit.script
def _mobius_coadd(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    # Closed form of x [+] y = x (+) gyr[x, -y]y:
    #   ((1 + k |y|^2) x + (1 + k |x|^2) y) / (1 - k^2 |x|^2 |y|^2)
    # The denominator is clamped from below to avoid division by zero.
    sq_x = x.pow(2).sum(dim=dim, keepdim=True)
    sq_y = y.pow(2).sum(dim=dim, keepdim=True)
    weight_x = 1 + k * sq_y
    weight_y = 1 + k * sq_x
    numerator = weight_x * x + weight_y * y
    denominator = (1 - k**2 * sq_x * sq_y).clamp_min(1e-15)
    return numerator / denominator
def mobius_cosub(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius gyrovector cosubtraction.

    Cosubtraction is coaddition of the negated second operand:

    .. math::

        a \boxminus_\kappa b = a \boxplus_\kappa -b

    Parameters
    ----------
    x : tensor
        point on manifold
    y : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius cosubtraction
    """
    return _mobius_cosub(x, y, k, dim)
@torch.jit.script
def _mobius_cosub(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    # x [-] y is defined as x [+] (-y)
    negated = torch.neg(y)
    return _mobius_coadd(x, negated, k, dim=dim)
# TODO: can we make this operation somehow safer by breaking up the
# TODO: scalar multiplication for K>0 when the argument to the
# TODO: tan function gets close to pi/2+k*pi for k in Z?
# TODO: one could use the scalar associative law
# TODO: s_1 (X) s_2 (X) x = (s_1*s_2) (X) x
# TODO: to implement a more stable Möbius scalar mult
def mobius_scalar_mul(r: torch.Tensor, x: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius scalar multiplication.

    .. math::

        r \otimes_\kappa x
        =
        \tan_\kappa(r\tan_\kappa^{-1}(\|x\|_2))\frac{x}{\|x\|_2}

    The operation shares several properties with Euclidean scalar
    multiplication:

    * `n-addition` property

    .. math::

         r \otimes_\kappa x = x \oplus_\kappa \dots \oplus_\kappa x

    * Distributive property

    .. math::

         (r_1 + r_2) \otimes_\kappa x
         =
         r_1 \otimes_\kappa x \oplus r_2 \otimes_\kappa x

    * Scalar associativity

    .. math::

         (r_1 r_2) \otimes_\kappa x = r_1 \otimes_\kappa (r_2 \otimes_\kappa x)

    * Monodistributivity

    .. math::

         r \otimes_\kappa (r_1 \otimes x \oplus r_2 \otimes x) =
         r \otimes_\kappa (r_1 \otimes x) \oplus r \otimes (r_2 \otimes x)

    * Scaling property

    .. math::

        |r| \otimes_\kappa x / \|r \otimes_\kappa x\|_2 = x/\|x\|_2

    Parameters
    ----------
    r : tensor
        scalar for multiplication
    x : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius scalar multiplication
    """
    return _mobius_scalar_mul(r, x, k, dim)
@torch.jit.script
def _mobius_scalar_mul(
    r: torch.Tensor, x: torch.Tensor, k: torch.Tensor, dim: int = -1
):
    # Norm is clamped away from zero so the direction x / |x| is well defined.
    length = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
    direction = x / length
    # Rescale along the same direction: tan_k(r * artan_k(|x|)).
    new_length = tan_k(r * artan_k(length, k), k)
    return new_length * direction
def dist(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, keepdim=False, dim=-1):
    r"""
    Compute the geodesic distance between :math:`x` and :math:`y` on the manifold.

    .. math::

        d_\kappa(x, y) = 2\tan_\kappa^{-1}(\|(-x)\oplus_\kappa y\|_2)

    .. plot:: plots/extended/stereographic/distance.py

    Parameters
    ----------
    x : tensor
        point on manifold
    y : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the reduced dim? (default: false)
    dim : int
        reduction dimension

    Returns
    -------
    tensor
        geodesic distance between :math:`x` and :math:`y`
    """
    return _dist(x, y, k, keepdim, dim)
@torch.jit.script
def _dist(
    x: torch.Tensor,
    y: torch.Tensor,
    k: torch.Tensor,
    keepdim: bool = False,
    dim: int = -1,
):
    # d(x, y) = 2 * artan_k(|(-x) (+) y|)
    offset = _mobius_add(-x, y, k, dim=dim)
    length = offset.norm(dim=dim, p=2, keepdim=keepdim)
    return 2.0 * artan_k(length, k)
def dist_matmul(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor):
    r"""
    Compute all pairwise geodesic distances between two sets of points.

    Each pair is measured with the same formula as :func:`dist`,

    .. math::

        d_\kappa(x, y) = 2\tan_\kappa^{-1}(\|(-x)\oplus_\kappa y\|_2)

    but the inner products are obtained with a single matrix product.

    .. plot:: plots/extended/stereographic/distance.py

    Parameters
    ----------
    x : tensor : (*, n, d)
        points on manifold, one per row
    y : tensor : (*, d, m)
        points on manifold, one per column
    k : tensor
        sectional curvature of manifold

    Returns
    -------
    tensor
        geodesic distances between the rows of :math:`x` and the columns of
        :math:`y`
    """
    return _dist_matmul(x=x, y=y, k=k)
@torch.jit.script
def _dist_matmul(
    x: torch.Tensor,
    y: torch.Tensor,
    k: torch.Tensor,
):
    # Points of x are reduced over the last axis, points of y over the
    # second-to-last axis; the matrix product supplies every <x_i, y_j>.
    sq_x = x.pow(2).sum(dim=-1, keepdim=True)
    sq_y = y.pow(2).sum(dim=-2, keepdim=True)
    gram = torch.matmul(x, y)
    # |(-x) (+) y|^2 = |x - y|^2 / (1 + 2 k <x, y> + k^2 |x|^2 |y|^2)
    sq_gap = sq_x - 2 * gram + sq_y
    scale = (1 + 2 * k * gram + k.pow(2) * sq_x * sq_y).clamp_min(1e-15)
    # clamp before the square root so the argument stays positive
    ratio = (sq_gap / scale).clamp_min(1e-15)
    return 2.0 * artan_k(ratio.sqrt(), k)
def dist0(x: torch.Tensor, *, k: torch.Tensor, keepdim=False, dim=-1):
    r"""
    Compute the geodesic distance from :math:`x` to the manifold's origin.

    Parameters
    ----------
    x : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the reduced dim? (default: false)
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        geodesic distance between :math:`x` and :math:`0`
    """
    return _dist0(x, k, keepdim, dim)
@torch.jit.script
def _dist0(x: torch.Tensor, k: torch.Tensor, keepdim: bool = False, dim: int = -1):
    # d(0, x) = 2 * artan_k(|x|)
    length = x.norm(dim=dim, p=2, keepdim=keepdim)
    return 2.0 * artan_k(length, k)
def geodesic(
t: torch.Tensor, x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1
):
r"""
Compute the point on the path connecting :math:`x` and :math:`y` at time :math:`t`.
The path can also be treated as an extension of the line segment to an
unbounded geodesic that goes through :math:`x` and :math:`y`. The equation
of the geodesic is given as:
.. math::
\gamma_{x\to y}(t)
=
x \oplus_\kappa t \otimes_\kappa ((-x) \oplus_\kappa y)
The properties of the geodesic are the following:
.. math::
\gamma_{x\to y}(0) = x\\
\gamma_{x\to y}(1) = y\\
\dot\gamma_{x\to y}(t) = v