Skip to content

Commit

Permalink
Update and apply formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 29, 2023
1 parent 6e0bb25 commit e6328f5
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .dev/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"

[compat]
JuliaFormatter = "0.3"
JuliaFormatter = "1"
8 changes: 4 additions & 4 deletions src/transform_fourier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ function truncate_modes(tr::FourierTransform{N}, coeff) where {N}

# indices for the spectral coefficients that we need to retain
inds = [
vcat(collect(1:(m)), collect((s - m + 1):s))
for (s, m) in zip(size(coeff)[2:(end - 1)], tr.modes)
vcat(collect(1:(m)), collect((s - m + 1):s)) for
(s, m) in zip(size(coeff)[2:(end - 1)], tr.modes)
]

# we need to handle the first dimension of the real Fourier transform
Expand All @@ -63,8 +63,8 @@ function pad_modes(
size_padded = (size(coeff)[1], size_pad..., size(coeff)[end])
coeff_padded = zeros(eltype(coeff), size_padded)
inds = [
vcat(collect(1:(div(m, 2) + 1)), collect((s - div(m, 2) + 2):s))
for (s, m) in zip(size_pad, size(coeff)[2:(end - 1)])
vcat(collect(1:(div(m, 2) + 1)), collect((s - div(m, 2) + 2):s)) for
(s, m) in zip(size_pad, size(coeff)[2:(end - 1)])
]

# we need to handle the first dimension of the real Fourier transform
Expand Down
21 changes: 6 additions & 15 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ mview(c, inds, ::Val{2}) = view(c, :, inds[1], inds[2], :)
mview(c, inds, ::Val{3}) = view(c, :, inds[1], inds[2], inds[3], :)
mview(c, inds, ::Val{4}) = view(c, :, inds[1], inds[2], inds[3], inds[4], :)

tensor_contraction(
A,
B,
::Val{1},
) = @tullio C[o, a, b] := A[i, a, o] * B[i, a, b]
tensor_contraction(A, B, ::Val{1}) =
@tullio C[o, a, b] := A[i, a, o] * B[i, a, b]
tensor_contraction(A, B, ::Val{2}) =
@tullio C[o, a₁, a₂, b] := A[i, a₁, a₂, o] * B[i, a₁, a₂, b]
tensor_contraction(A, B, ::Val{3}) =
Expand All @@ -16,21 +13,15 @@ tensor_contraction(A, B, ::Val{4}) = @tullio C[o, a₁, a₂, a₃, a₄, b] :=
A[i, a₁, a₂, a₃, a₄, o] * B[i, a₁, a₂, a₃, a₄, b]

sparse_mean(w, c, ::Val{1}) = @tullio μ[o, a, b] := w[i, o] * c[i, a, b]
sparse_mean(
w,
c,
::Val{2},
) = @tullio μ[o, a₁, a₂, b] := w[i, o] * c[i, a₁, a₂, b]
sparse_mean(w, c, ::Val{2}) =
@tullio μ[o, a₁, a₂, b] := w[i, o] * c[i, a₁, a₂, b]
sparse_mean(w, c, ::Val{3}) =
@tullio μ[o, a₁, a₂, a₃, b] := w[i, o] * c[i, a₁, a₂, a₃, b]
sparse_mean(w, c, ::Val{4}) =
@tullio μ[o, a₁, a₂, a₃, a₄, b] := w[i, o] * c[i, a₁, a₂, a₃, a₄, b]

sparse_covariance(
w,
c,
::Val{1},
) = @tullio μ[o, a, r, b] := w[i, r, o] * c[i, a, b]
sparse_covariance(w, c, ::Val{1}) =
@tullio μ[o, a, r, b] := w[i, r, o] * c[i, a, b]

Check warning on line 24 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
sparse_covariance(w, c, ::Val{2}) =
@tullio μ[o, a₁, a₂, r, b] := w[i, r, o] * c[i, a₁, a₂, b]
sparse_covariance(w, c, ::Val{3}) =
Expand Down
4 changes: 2 additions & 2 deletions test/transform_fourier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ using LinearAlgebra
trafo = FourierTransform(modes = (12, 5))
c = rand(27, 16, 14, 1)
inds = [
[(1:(div(m, 2) + 1))..., ((s - div(m, 2) + 2):s)...]
for (s, m) in zip((M, 14), size(c)[2:(end - 1)])
[(1:(div(m, 2) + 1))..., ((s - div(m, 2) + 2):s)...] for
(s, m) in zip((M, 14), size(c)[2:(end - 1)])
]
inds[1] = collect(1:16)
@test all(OperatorFlux.pad_modes(trafo, c, (32, 14))[:, inds..., :] .== c)
Expand Down

0 comments on commit e6328f5

Please sign in to comment.