Skip to content

Commit

Permalink
works #13
Browse files Browse the repository at this point in the history
  • Loading branch information
fnrizzi committed May 27, 2021
1 parent 5197f5b commit caaa683
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 34 deletions.
58 changes: 45 additions & 13 deletions meshing_scripts/create_full_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from one_dim_mesh import OneDimMesh
from natural_order_mesh import NatOrdMeshRow
from mesh_utils import *
from remove_cells_step import *

# ------------------------------------------------
def checkDomainBoundsCoordinates(xValues, yValues):
Expand Down Expand Up @@ -44,6 +45,8 @@ def main(workDir, debug, Nx, Ny, \

# figure out if domain is a plain rectangle or has a step in it
plainDomain = True
print(len(xBd))
print(xBd)
if dim!=1 and len(xBd) > 4:
plainDomain = False

Expand All @@ -59,22 +62,39 @@ def main(workDir, debug, Nx, Ny, \
axFM = figFM.gca()

# ---------------------------------------------
x = None
y = None
G = None
gids = None
if dim == 1:
meshObj = OneDimMesh(Nx, dx, xBd[0], xBd[1], stencilSize, enablePeriodicBc)
# get mesh coordinates, gids and graph
[x, y] = meshObj.getCoordinates()
gids = meshObj.getGIDs()
G = meshObj.getGraph()

elif dim==2 and plainDomain:
meshObj = NatOrdMeshRow(Nx,Ny, dx,dy,\
xBd[0], xBd[1], yBd[0], yBd[1], \
stencilSize, enablePeriodicBc)

else:
print("invalid domain: with step not impl yet")
sys.exit(1)
# get mesh coordinates, gids and graph
[x, y] = meshObj.getCoordinates()
gids = meshObj.getGIDs()
G = meshObj.getGraph()

elif dim==2 and not plainDomain:
minX, maxX = min(xBd), max(xBd)
minY, maxY = min(yBd), max(yBd)
L = [maxX-minX, maxY-minY]
dx,dy = L[0]/Nx, L[1]/Ny

meshObj = NatOrdMeshRow(Nx, Ny, dx, dy,\
minX, maxX, minY, maxY, \
stencilSize, False)

x,y,G,gids = removeStepCells(meshObj, xBd, yBd)

# get mesh coordinates, gids and graph
[x, y] = meshObj.getCoordinates()
gids = meshObj.getGIDs()
G = meshObj.getGraph()
if debug:
print("full mesh connectivity")
printDicPretty(G)
Expand Down Expand Up @@ -136,12 +156,13 @@ def main(workDir, debug, Nx, Ny, \
f.write("stencilSize %2d\n" % stencilSize)
f.write("nx %8d\n" % Nx)
f.close()

else:
f.write("dim %1d\n" % 2)
f.write("xMin %.14f\n" % xBd[0])
f.write("xMax %.14f\n" % xBd[1])
f.write("yMin %.14f\n" % yBd[0])
f.write("yMax %.14f\n" % yBd[1])
f.write("xMin %.14f\n" % min(xBd))
f.write("xMax %.14f\n" % max(xBd))
f.write("yMin %.14f\n" % min(yBd))
f.write("yMax %.14f\n" % max(yBd))
f.write("dx %.14f\n" % dx)
f.write("dy %.14f\n" % dy)
f.write("sampleMeshSize %8d\n" % meshSize)
Expand Down Expand Up @@ -174,9 +195,9 @@ def main(workDir, debug, Nx, Ny, \
if plotting != "none":
plotLabels(x, y, dx, dy, gids, axFM, fontSz=plotFontSize)
axFM.set_aspect(1.0)
axFM.set_xlim(xBd[0], xBd[1])
axFM.set_xlim(min(xBd), max(xBd))
if dim!=1:
axFM.set_ylim(yBd[0], yBd[1])
axFM.set_ylim(min(yBd), max(yBd))

if plotting == "show":
plt.show()
Expand Down Expand Up @@ -273,6 +294,17 @@ def main(workDir, debug, Nx, Ny, \
yCoords = [args.bounds[2], args.bounds[3]]
dim = 2

elif nx!=1 and ny!=1 and len(args.bounds) == 12:
xCoords = []
yCoords = []
print(args.bounds)
for i,v in enumerate(args.bounds):
if i % 2 == 0:
xCoords.append(v)
else:
yCoords.append(v)
dim = 2

# other things
plotFontSize = 0
if len(args.plottingInfo) == 2:
Expand Down
44 changes: 23 additions & 21 deletions meshing_scripts/create_sample_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,9 @@ def readFullMeshConnec(fullMeshDir):
lineList = line.split()
numNeigh = len(lineList[1:])
connections = [np.int64(i) for i in lineList[1:]]
#np.zeros(numNeigh, dtype=np.int64)
#for i in range(numNeigh):
#connections[i] = np.int64(lineList[1+i])
pt = np.int64(lineList[0])
gids.append(pt)
G[pt] = np.copy(connections)

line = fp.readline()
cnt += 1

Expand All @@ -51,6 +47,7 @@ def readFullMeshConnec(fullMeshDir):

#====================================================================
def readFullMeshInfo(fullMeshDir):
dim=2
bounds=[None]*4
dx,dy = 0., 0.
sampleMeshSize = 0
Expand Down Expand Up @@ -104,6 +101,7 @@ def readFullMeshCoordinates(fullMeshDir):
cnt += 1
return np.array(x), np.array(y)

#====================================================================
#====================================================================
def main(workDir, debug, fullMeshDir, tilingDir, plotting, plotFontSize):
dx,dy,numCells,domainBounds,stencilSize = readFullMeshInfo(fullMeshDir)
Expand All @@ -119,8 +117,8 @@ def main(workDir, debug, fullMeshDir, tilingDir, plotting, plotFontSize):
if plotting != "none":
plotLabels(x, y, dx, dy, gids, axFM, fontSz=plotFontSize)
axFM.set_aspect(1.0)
axFM.set_xlim(np.min(x)-0.5,np.max(x)+0.5)
axFM.set_ylim(np.min(y)-0.5,np.max(y)+0.5)
axFM.set_xlim(np.min(x)-dx*0.5, np.max(x)+dx*0.5)
axFM.set_ylim(np.min(y)-dy*0.5, np.max(y)+dy*0.5)

if debug:
print("natural order full mesh connectivity")
Expand All @@ -141,7 +139,7 @@ def main(workDir, debug, fullMeshDir, tilingDir, plotting, plotFontSize):
# -----------------------------------------------------
smGraph0 = collections.OrderedDict()
for rPt in sampleMeshGIDs:
smGraph0[rPt] = G[rPt]
smGraph0[rPt] = G[rPt]
print("\n")
if debug:
print("sample mesh graph0 (IDs wrt full mesh)")
Expand All @@ -150,10 +148,11 @@ def main(workDir, debug, fullMeshDir, tilingDir, plotting, plotFontSize):
stencilMeshGIDs = []
# loop over target cells wherw we want residual
for k,v in smGraph0.items():
# append the GID of this cell
stencilMeshGIDs.append(k)
# append GID of stencils/neighborin cells
for j in v:
# append the GID of this cell
stencilMeshGIDs.append(k)
# append GID of stencils/neighborin cells
for j in v:
if j != -1:
stencilMeshGIDs.append(np.int64(j))

# remove duplicates and sort
Expand All @@ -180,6 +179,7 @@ def main(workDir, debug, fullMeshDir, tilingDir, plotting, plotFontSize):

if tilingDir!=None:
gidsfiles = glob.glob(tilingDir+"/cell_gids_p_*.txt")

# sort based on the ID, so need to extract ID which is last of dir name
def func(elem): return int(elem.split('_')[-1].split('.')[0])
gidsfiles = sorted(gidsfiles, key=func)
Expand Down Expand Up @@ -211,14 +211,14 @@ def func(elem): return int(elem.split('_')[-1].split('.')[0])
if debug:
for k,v in fm_to_sm_map.items(): print(k, v)


print("doing now the sm -> fm gids mapping ")
sm_to_fm_map = collections.OrderedDict()
for k,v in fm_to_sm_map.items():
sm_to_fm_map[v] = k
print("Done with sm_to_fm_map")
if debug:
for k,v in sm_to_fm_map.items(): print(k, v)
for k,v in sm_to_fm_map.items():
print(k, v)

# -----------------------------------------------------
# Here we have a list of unique GIDs for the sample mesh.
Expand All @@ -227,12 +227,13 @@ def func(elem): return int(elem.split('_')[-1].split('.')[0])
# -----------------------------------------------------
sampleMeshGraph = collections.OrderedDict()
for rGidFM, v in smGraph0.items():
smGID = fm_to_sm_map[rGidFM]
smStencilGIDs = v
for i in range(len(smStencilGIDs)):
thisGID = smStencilGIDs[i]
smStencilGIDs[i] = fm_to_sm_map[thisGID]
sampleMeshGraph[smGID] = smStencilGIDs
smGID = fm_to_sm_map[rGidFM]
smStencilGIDs = v
for i in range(len(smStencilGIDs)):
thisGID = smStencilGIDs[i]
if thisGID != -1:
smStencilGIDs[i] = fm_to_sm_map[thisGID]
sampleMeshGraph[smGID] = smStencilGIDs

print("\n")
print("Done with sampleMeshGraph")
Expand All @@ -245,8 +246,9 @@ def func(elem): return int(elem.split('_')[-1].split('.')[0])
x2, y2 = x[ list(fm_to_sm_map.keys()) ], y[ list(fm_to_sm_map.keys()) ]
plotLabels(x2, y2, dx, dy, gids_sm, axSM, 's', 'r', 0, fontSz=plotFontSize)
axSM.set_aspect(1.0)
axSM.set_xlim(np.min(x)-0.5,np.max(x)+0.5)
axSM.set_ylim(np.min(y)-0.5,np.max(y)+0.5)
axSM.set_xlim(np.min(x)-dx*0.5, np.max(x)+dx*0.5)
axSM.set_ylim(np.min(y)-dy*0.5, np.max(y)+dy*0.5)


# -----------------------------------------------------
sampleMeshSize = len(sampleMeshGraph)
Expand Down
3 changes: 3 additions & 0 deletions meshing_scripts/natural_order_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def __init__(self, Nx, Ny, dx, dy, \
self.buildGraph(enablePeriodicBc)
#self.spMat_ = convertGraphDicToSparseMatrix(self.G_)

def getTotCells(self):
return self.Nx_ * self.Ny_

def getCoordinates(self):
return [self.x_, self.y_]

Expand Down
60 changes: 60 additions & 0 deletions meshing_scripts/remove_cells_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

import matplotlib.pyplot as plt
import sys, os, time
import numpy as np
from numpy import linspace, meshgrid
from matplotlib import cm
import collections
from argparse import ArgumentParser
import random
import scipy.sparse as sp
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import reverse_cuthill_mckee

from mesh_utils import printDicPretty

def removeStepCells(meshObj, xBd, yBd):

stepXb = [xBd[2], xBd[3]]
stepYb = [yBd[1], yBd[2]]
print("stepXb ", stepXb)
print("stepYb ", stepYb)

totNumCells = meshObj.getTotCells()
[x, y] = meshObj.getCoordinates()
gids = meshObj.getGIDs()
G = meshObj.getGraph()

a = np.where(x>stepXb[0])
b = np.where(x<stepXb[1])
c = np.where(y>stepYb[0])
d = np.where(y<stepYb[1])
r1 = np.intersect1d(a,b)
r2 = np.intersect1d(r1,c)
stepCellsGids = np.intersect1d(r2,d)
domainGids = list(set(np.arange(totNumCells)).difference(set(stepCellsGids)))
#print("stepCellsGids = ", stepCellsGids)
#print("domainGids = ", domainGids)

G2 = {}
for k,v in G.items():
if k not in stepCellsGids:
v2 = [i if i not in stepCellsGids else -1 for i in v]
G2[k] = v2
#printDicPretty(G2)

# reindex
newIds = {}
count = 0
for k in G2.keys():
newIds[k] = count
count+=1

# fix graph
G3 = {}
for k,v in G2.items():
v2 = [newIds[i] if i != -1 else -1 for i in v]
G3[newIds[k]] = v2

gids = np.array(list(G3.keys()))
return x[domainGids], y[domainGids], G3, gids

0 comments on commit caaa683

Please sign in to comment.