pyuni10.Contract() result is wrong
Consider the direct product of two Tensors.
import pyuni10 as uni10
A = uni10.arange(4)
A.reshape_(2,2)            # A = [[0, 1], [2, 3]]
Id = uni10.zeros([2, 2])
Id[0, 0] = 2
Id[1, 1] = 2               # Id = 2 * identity (2x2)
AId = uni10.linalg.Kron(A,Id)
print(AId)
The result is
Shape : (4,4)
[[0.00000e+00 0.00000e+00 2.00000e+00 0.00000e+00 ]
[0.00000e+00 0.00000e+00 0.00000e+00 2.00000e+00 ]
[4.00000e+00 0.00000e+00 6.00000e+00 0.00000e+00 ]
[0.00000e+00 4.00000e+00 0.00000e+00 6.00000e+00 ]]
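As a cross-check that does not depend on pyuni10, a minimal pure-Python Kronecker product reproduces the matrix above. The `kron` helper is hypothetical and written only for this check.

```python
# Pure-Python reference for the Kronecker product, independent of pyuni10.
def kron(a, b):
    """Kronecker product of two matrices given as lists of rows.

    Row index is (i, k) and column index is (j, l), with the index from `a`
    varying slowest, so entry = a[i][j] * b[k][l].
    """
    rows_b, cols_b = len(b), len(b[0])
    return [
        [a[i][j] * b[k][l] for j in range(len(a[0])) for l in range(cols_b)]
        for i in range(len(a)) for k in range(rows_b)
    ]

A = [[0, 1], [2, 3]]    # uni10.arange(4) reshaped to (2, 2)
Id = [[2, 0], [0, 2]]   # zeros([2, 2]) with Id[0, 0] = Id[1, 1] = 2

for row in kron(A, Id):
    print(row)
```

The printed rows are the integer form of the matrix returned by `uni10.linalg.Kron(A, Id)`.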
Now we promote the two Tensors to UniTensors and contract them with uni10.Contract. Since all labels are distinct, no index is summed over, so the result should be the same as the direct product.
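With no shared label, the contraction is an outer product T[i, j, k, l] = A[i, j] * Id[k, l], where i, j, k, l carry labels 0, 1, 2, 3. Which (4, 4) matrix `get_block()` followed by `reshape_(4, 4)` should produce then depends on the label order. The sketch below uses plain Python (no pyuni10); `to_matrix` is a hypothetical helper, and it assumes the reshape is row-major with the first two labels forming the row index.

```python
from itertools import product

A = [[0, 1], [2, 3]]
Id = [[2, 0], [0, 2]]

# Outer product, nothing contracted: T[(i, j, k, l)] = A[i][j] * Id[k][l].
# Axis positions 0..3 carry labels 0, 1, 2, 3.
T = {idx: A[idx[0]][idx[1]] * Id[idx[2]][idx[3]] for idx in product(range(2), repeat=4)}

def to_matrix(tensor, order):
    """Permute the axes into `order`, then reshape to (4, 4) row-major,
    using the first two permuted axes as the row index."""
    m = [[0] * 4 for _ in range(4)]
    for idx, val in tensor.items():
        p = [idx[o] for o in order]   # index values in the permuted axis order
        m[p[0] * 2 + p[1]][p[2] * 2 + p[3]] = val
    return m

unpermuted = to_matrix(T, [0, 1, 2, 3])   # rows (i, j), columns (k, l)
permuted = to_matrix(T, [0, 2, 1, 3])     # rows (i, k), columns (j, l)

for name, m in [("labels [0, 1 | 2, 3]", unpermuted), ("labels [0, 2 | 1, 3]", permuted)]:
    print(name)
    for row in m:
        print(row)
```

Under these assumptions, the order [0, 1, 2, 3] gives the rows [0, 0, 0, 0], [2, 0, 0, 2], [4, 0, 0, 4], [6, 0, 0, 6], and only the order [0, 2, 1, 3] reproduces the direct product.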
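T_A = uni10.UniTensor(A, 1)          # default labels [0, 1]
T_A.print_diagram()
T_Id = uni10.UniTensor(Id, 1)
T_Id.set_labels([2, 3])              # labels distinct from those of T_A
T_Id.print_diagram()
T_A_Id = uni10.Contract(T_A, T_Id)   # no common label, so nothing is summed
T_A_Id.print_diagram()
X = T_A_Id.get_block()
X.reshape_(4,4)
print(X)
T_A_Id.permute_([0,2,1,3],2)         # rows (0, 2), columns (1, 3)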
T_A_Id.print_diagram()
X = T_A_Id.get_block()
X.reshape_(4,4)
print(X)
If I don't permute after the contraction, the labels are in the wrong order ([0, 1 | 2, 3] rather than [0, 2 | 1, 3]), but the content returned by get_block() is correct:
-----------------------
tensor Name :
tensor Rank : 4
block_form : false
is_diag : False
on device : cytnx device: CPU
            -------------
           /             \
     0 ____| 2         2 |____ 2
           |             |
     1 ____| 2         2 |____ 3
           \             /
            -------------
Total elem: 16
type : Double (Float64)
cytnx device: CPU
Shape : (4,4)
[[0.00000e+00 0.00000e+00 2.00000e+00 0.00000e+00 ]
[0.00000e+00 0.00000e+00 0.00000e+00 2.00000e+00 ]
[4.00000e+00 0.00000e+00 6.00000e+00 0.00000e+00 ]
[0.00000e+00 4.00000e+00 0.00000e+00 6.00000e+00 ]]
If I do permute the labels into the correct order, the content is wrong:
-----------------------
tensor Name :
tensor Rank : 4
block_form : false
is_diag : False
on device : cytnx device: CPU
            -------------
           /             \
     0 ____| 2         2 |____ 1
           |             |
     2 ____| 2         2 |____ 3
           \             /
            -------------
Total elem: 16
type : Double (Float64)
cytnx device: CPU
Shape : (4,4)
[[0.00000e+00 0.00000e+00 0.00000e+00 0.00000e+00 ]
[2.00000e+00 0.00000e+00 0.00000e+00 2.00000e+00 ]
[4.00000e+00 0.00000e+00 0.00000e+00 4.00000e+00 ]
[6.00000e+00 0.00000e+00 0.00000e+00 6.00000e+00 ]]
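Comparing the two printed blocks with the outer product T[i, j, k, l] = A[i, j] * Id[k, l] suggests that the data and the labels are out of step: each printed matrix matches the layout implied by the other diagram. This sketch is plain Python; the matrices are copied from the outputs above, `block` is a hypothetical helper, and it assumes a row-major reshape with the first two labels as the row index.

```python
from itertools import product

A = [[0, 1], [2, 3]]
Id = [[2, 0], [0, 2]]

def block(order):
    """(4, 4) matrix of T[i][j][k][l] = A[i][j] * Id[k][l] after permuting
    the axes into `order`; the first two permuted axes form the row index."""
    m = [[0] * 4 for _ in range(4)]
    for idx in product(range(2), repeat=4):
        p = [idx[o] for o in order]
        m[p[0] * 2 + p[1]][p[2] * 2 + p[3]] = A[idx[0]][idx[1]] * Id[idx[2]][idx[3]]
    return m

# Blocks printed in the report, with the label order shown by each diagram.
printed_before_permute = [[0, 0, 2, 0], [0, 0, 0, 2], [4, 0, 6, 0], [0, 4, 0, 6]]  # [0, 1 | 2, 3]
printed_after_permute = [[0, 0, 0, 0], [2, 0, 0, 2], [4, 0, 0, 4], [6, 0, 0, 6]]   # [0, 2 | 1, 3]

print(printed_before_permute == block([0, 1, 2, 3]))  # its own label order
print(printed_before_permute == block([0, 2, 1, 3]))
print(printed_after_permute == block([0, 2, 1, 3]))   # its own label order
print(printed_after_permute == block([0, 1, 2, 3]))
```

This prints False, True, False, True. Under the stated assumption, the mismatch is already visible right after Contract, before permute_ is called.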