Skip to content

Commit

Permalink
switch to backend-based
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 28, 2024
1 parent 997f222 commit a8382e3
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 56 deletions.
1 change: 1 addition & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:
- name: Install dependencies
run: |
julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
julia -e 'using Pkg; Pkg.add("PythonPlot")'
- name: Setup README
run: cp README.md docs/src/README.md
Expand Down
6 changes: 0 additions & 6 deletions CondaPkg.toml

This file was deleted.

18 changes: 12 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@ version = "1.0.0"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ColorSchemes = "3"
CondaPkg = "0.2"
PythonPlot = "1.0"
Reexport = "1"
Requires = "1.3.0"
PythonPlot = "1"
PyPlot = "2"
julia = "1"

[extensions]
PyPlotSlimExt = "PyPlot"
PythonPlotSlimExt = "PythonPlot"

[weakdeps]
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9"
14 changes: 0 additions & 14 deletions deps/build.jl

This file was deleted.

2 changes: 1 addition & 1 deletion examples/plot_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' This example is converted to a markdown file for the documentation.

#' # Import SlimPlotting, SegyIO to read seismic data, JLD2 for hdf5-like files
using SlimPlotting, SegyIO, JLD2
using PythonPlot, SlimPlotting, SegyIO, JLD2

#' # Initialize all needed data

Expand Down
40 changes: 40 additions & 0 deletions ext/PyPlotSlimExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module PyPlotSlimExt

import SlimPlotting: seiscm, getcmap

isdefined(Base, :get_extension) ? (using PyPlot) : (using ..PyPlot)
isdefined(Base, :get_extension) ? (using PyPlot.PyCall) : (using ..PyPlot.PyCall)

function tryimport(pkg::String)
pyi = try
PyPlot.pyimport(pkg)
catch e
if PyPlot.PyCall.conda
PyPlot.PyCall.Conda.pip_interop(true)
PyPlot.PyCall.Conda.pip("install", pkg)
else
run(PyPlot.PyCall.python_cmd(`-m pip install --user $(pkg)`))
end
PyPlot.pyimport(pkg)
end
return pyi
end

const cc = PyNULL()
const scm = PyNULL()

pypltref = PyPlot

function __init__()
@info "Initializing PyPlotSlimExt"
# Import colorcet
copy!(cc, tryimport("colorcet"))
# Import SeisCM
copy!(scm, tryimport("seiscm"))

end

seiscm(s::String) = py"scm.$s"()
getcmap(s) = ColorMap(s)

end
38 changes: 38 additions & 0 deletions ext/PythonPlotSlimExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module PythonPlotSlimExt

import SlimPlotting: seiscm, getcmap
isdefined(Base, :get_extension) ? (using PythonPlot) : (using ..PythonPlot)

function tryimport(pkg::String)
pyi = try
PythonPlot.pyimport(pkg)
catch e
if get(ENV, "JULIA_CONDAPKG_BACKEND", "conda") == "Null"
pyexe = PythonPlot.PythonCall.python_executable_path()
run(Cmd(`$(pyexe) -m pip install --user $(pkg)`))
else
PythonPlot.CondaPkg.add_pip(pkg)
end
PythonPlot.pyimport(pkg)
end
return pyi
end

const cc = PythonPlot.PythonCall.pynew()
const scm = PythonPlot.PythonCall.pynew()

pypltref = PythonPlot

function __init__()
@info "Initializing PythonPlotSlimExt"
# Import colorcet
PythonPlot.PythonCall.pycopy!(cc, tryimport("colorcet"))
# Import SeisCM
PythonPlot.PythonCall.pycopy!(scm, tryimport("seiscm"))
end

seiscm(s::String) = PythonPlot.pygetattr(scm, s)()
getcmap(s) = ColorMap(s)
getcmap(c::PythonPlot.PythonCall.Py) = c

end
85 changes: 56 additions & 29 deletions src/SlimPlotting.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,56 @@
__precompile__()
module SlimPlotting

using Statistics, ColorSchemes, Reexport
@reexport using PythonPlot
using Statistics, ColorSchemes

const cc = PythonPlot.PythonCall.pynew()
const scm = PythonPlot.PythonCall.pynew()
# Only needed if extension not available (julia < 1.9)
if !isdefined(Base, :get_extension)
using Requires
end

function __init__()
ccall(:jl_generating_output, Cint, ()) == 1 && return nothing
# Import colorcet
PythonPlot.PythonCall.pycopy!(cc, PythonPlot.pyimport("colorcet"))
# Import SeisCM
PythonPlot.PythonCall.pycopy!(scm, PythonPlot.pyimport("seiscm"))
# Optional dependencies
@static if !isdefined(Base, :get_extension)
@require PyPlot="d330b81b-6aea-500a-939a-2ce795aea3ee" begin
@info "PyPlot compat enabled"
include("../ext/PyPlotSlimExt.jl")
end

@require PythonPlot="274fc56d-3b97-40fa-a1cd-1b4a50311bf9" begin
@info "PythonPlot compat enabled"
include("../ext/PythonPlotSlimExt.jl")
end
end
end

export plot_fslice, plot_velocity, plot_simage, plot_sdata, wiggle_plot, compare_shots
export colorschemes, seiscm

function getcmap end

function get_extension()
# get backend
ext = Base.get_extension(@__MODULE__, :PythonPlotSlimExt)
if isnothing(ext)
ext = Base.get_extension(@__MODULE__, :PyPlotSlimExt)
end
isnothing(ext) && throw(MissingException("No plotting backend found, either PyPlot or PythonPlot need to be loaded in the script"))

return ext.pypltref
end

# String conversion in case of legacy :symbol input for PyPlot
to_string(x::Symbol) = string(x)
to_string(x) = x
to_symbol(x::String) = Symbol(x)
to_symbol(x) = x


"""
seiscm(name)
Return the colormap `name` for seiscm. These colormap are preimported as a dictionnary
"""
seiscm(s::Symbol) = seiscm(to_string(s))
seiscm(s::String) = PythonPlot.pygetattr(scm, s)()

"""
_plot_with_units(image, spacing; perc=95, cmap=:cet_CET_L1,
Expand Down Expand Up @@ -90,14 +110,18 @@ function _plot_with_units(

# color map
cmap = try
ColorMap(to_string(cmap))
getcmap(to_string(cmap))
catch
ColorMap(colorschemes[to_symbol(cmap)].colors)
getcmap(colorschemes[to_symbol(cmap)].colors)
end
new_fig && figure()

backend = get_extension()

# Create new figure
new_fig && backend.figure()
# Plot
if !isnothing(alpha)
imshow(
backend.imshow(
scaled,
vmin = ma,
vmax = a,
Expand All @@ -108,7 +132,7 @@ function _plot_with_units(
alpha = alpha,
)
else
imshow(
backend.imshow(
scaled,
vmin = ma,
vmax = a,
Expand All @@ -118,14 +142,14 @@ function _plot_with_units(
extent = extent,
)
end
xlabel("$(labels[1]) [$(units[1])]")
ylabel("$(labels[2]) [$(units[2])]")
title("$name")
cbar && colorbar(fraction = 0.046, pad = 0.04)
backend.xlabel("$(labels[1]) [$(units[1])]")
backend.ylabel("$(labels[2]) [$(units[2])]")
backend.title("$name")
cbar && backend.colorbar(fraction = 0.046, pad = 0.04)

if ~isnothing(save)
save == true ? filename = name : filename = save
savefig(filename, bbox_inches = "tight", dpi = 150)
backend.savefig(filename, bbox_inches = "tight", dpi = 150)
end
end

Expand Down Expand Up @@ -332,7 +356,7 @@ function compare_shots(image, image2, spacing; chunksize = 20, kw...)
else
cmap2 = pop!(kwd, :cmap, :cet_CET_D1A)
if cmap2 == cmap1
cmap2 = ColorMap(cmap1).reversed()
cmap2 = getcmap(cmap1).reversed()
end
end
# Zero out to alternate
Expand Down Expand Up @@ -449,19 +473,22 @@ function wiggle_plot(
error("time_axis must be the same length as the number of rows in data")
# Time gain
tg = time_axis .^ t_scale
new_fig && figure()

ylim(maximum(time_axis), minimum(time_axis))
xlim(minimum(xrec), maximum(xrec))
backend = get_extension()

new_fig && backend.figure()

backend.ylim(maximum(time_axis), minimum(time_axis))
backend.xlim(minimum(xrec), maximum(xrec))
for (i, xr) enumerate(xrec)
x = tg .* data[:, i]
x = dx[i] * x ./ maximum(x) .+ xr
# rescale to avoid large spikes
plot(x, time_axis, "k-")
fill_betweenx(time_axis, xr, x, where = (x .> xr), color = "k")
backend.plot(x, time_axis, "k-")
backend.fill_betweenx(time_axis, xr, x, where = (x .> xr), color = "k")
end
xlabel("X")
ylabel("Time")
backend.xlabel("X")
backend.ylabel("Time")
end

end # module

0 comments on commit a8382e3

Please sign in to comment.