Commit 1878b3ad authored by Dennis Gläser's avatar Dennis Gläser
Browse files

[python][samplers] use __call__ to be in agreement with c++ impl

Entities are sampled from sampler classes in the C++ code by calling the
operator(). In order to streamline the Python implementation, we used
the name __call__ in the python classes to get the same syntax.
parent 907cc366
...@@ -71,7 +71,7 @@ entities2 = [] ...@@ -71,7 +71,7 @@ entities2 = []
print("\n --- Start entity sampling ---\n") print("\n --- Start entity sampling ---\n")
while not status.finished(): while not status.finished():
# sample a quadrilateral, alternating between sampler 1 and sampler 2 # sample a quadrilateral, alternating between sampler 1 and sampler 2
quad = quadSampler1.sample() if sampleIntoSet1 else quadSampler2.sample() quad = quadSampler1() if sampleIntoSet1 else quadSampler2()
entitySet = entities1 if sampleIntoSet1 else entities2 entitySet = entities1 if sampleIntoSet1 else entities2
otherEntitySet = entities2 if sampleIntoSet1 else entities1 otherEntitySet = entities2 if sampleIntoSet1 else entities1
......
...@@ -78,7 +78,7 @@ entities2 = [] ...@@ -78,7 +78,7 @@ entities2 = []
print("\n --- Start entity sampling ---\n") print("\n --- Start entity sampling ---\n")
while not status.finished(): while not status.finished():
# sample a quadrilateral, alternating between sampler 1 and sampler 2 # sample a quadrilateral, alternating between sampler 1 and sampler 2
quad = quadSampler1.sample() if sampleIntoSet1 else quadSampler2.sample() quad = quadSampler1() if sampleIntoSet1 else quadSampler2()
entitySet = entities1 if sampleIntoSet1 else entities2 entitySet = entities1 if sampleIntoSet1 else entities2
otherEntitySet = entities2 if sampleIntoSet1 else entities1 otherEntitySet = entities2 if sampleIntoSet1 else entities1
......
...@@ -137,7 +137,7 @@ sampleIntoSet1 = True ...@@ -137,7 +137,7 @@ sampleIntoSet1 = True
containedNetworkArea = 0.0; containedNetworkArea = 0.0;
while not status.finished(): while not status.finished():
id = diskSetId if sampleIntoSet1 else quadSetId id = diskSetId if sampleIntoSet1 else quadSetId
geom = diskSampler.sample() if sampleIntoSet1 else quadSampler.sample() geom = diskSampler() if sampleIntoSet1 else quadSampler()
# If the set this geometry belongs to is finished, skip the rest # If the set this geometry belongs to is finished, skip the rest
if status.finished(id): if status.finished(id):
......
...@@ -54,7 +54,7 @@ void registerBoxPointSampler(py::module& module) ...@@ -54,7 +54,7 @@ void registerBoxPointSampler(py::module& module)
"distributionX"_a, "distributionY"_a, "distributionZ"_a); "distributionX"_a, "distributionY"_a, "distributionZ"_a);
// register point sample function // register point sample function
cls.def("sample", &UniformSampler::operator(), "returns a randomly sampled point"); cls.def("__call__", &UniformSampler::operator(), "returns a randomly sampled point");
} }
} // end namespace Frackit::Python } // end namespace Frackit::Python
......
...@@ -56,7 +56,7 @@ void registerCylinderPointSampler(py::module& module) ...@@ -56,7 +56,7 @@ void registerCylinderPointSampler(py::module& module)
"cylinder"_a, "distributionR"_a, "distributionPhi"_a, "distributionH"_a); "cylinder"_a, "distributionR"_a, "distributionPhi"_a, "distributionH"_a);
// register point sample function // register point sample function
cls.def("sample", &UniformSampler::operator(), "returns a randomly sampled point"); cls.def("__ca,,__", &UniformSampler::operator(), "returns a randomly sampled point");
} }
} // end namespace Frackit::Python } // end namespace Frackit::Python
......
...@@ -17,7 +17,7 @@ class BoxPointSampler: ...@@ -17,7 +17,7 @@ class BoxPointSampler:
self.samplerY = samplerY self.samplerY = samplerY
self.samplerZ = samplerZ self.samplerZ = samplerZ
def sample(self): def __call__(self):
x = self.samplerX() x = self.samplerX()
y = self.samplerY() y = self.samplerY()
z = self.samplerZ() z = self.samplerZ()
...@@ -48,7 +48,7 @@ class CylinderPointSampler: ...@@ -48,7 +48,7 @@ class CylinderPointSampler:
self.samplerPhi = samplerPhi self.samplerPhi = samplerPhi
self.samplerZ = samplerZ self.samplerZ = samplerZ
def sample(self): def __call__(self):
from frackit.geometry import Vector_3 from frackit.geometry import Vector_3
a = Vector_3(self.cylinder.bottomFace().majorAxis()); a = Vector_3(self.cylinder.bottomFace().majorAxis());
...@@ -161,7 +161,7 @@ class DiskSampler: ...@@ -161,7 +161,7 @@ class DiskSampler:
self.yAngleSampler = yAngleSampler self.yAngleSampler = yAngleSampler
self.zAngleSampler = zAngleSampler self.zAngleSampler = zAngleSampler
def sample(self): def __call__(self):
a = self.majAxisSampler() a = self.majAxisSampler()
while (a <= 0.0): a = self.majAxisSampler() while (a <= 0.0): a = self.majAxisSampler()
...@@ -193,7 +193,7 @@ class DiskSampler: ...@@ -193,7 +193,7 @@ class DiskSampler:
# sample center point and make disk # sample center point and make disk
from frackit.geometry import Ellipse_3 from frackit.geometry import Ellipse_3
ellipse = Ellipse_3(self.pointSampler.sample(), ellipse = Ellipse_3(self.pointSampler(),
Direction_3(axes[0]), Direction_3(axes[0]),
Direction_3(axes[1]), Direction_3(axes[1]),
a, b) a, b)
...@@ -221,7 +221,7 @@ class QuadrilateralSampler: ...@@ -221,7 +221,7 @@ class QuadrilateralSampler:
self.edgeLengthSampler = edgeLengthSampler self.edgeLengthSampler = edgeLengthSampler
self.minEdgeLength = minEdgeLength self.minEdgeLength = minEdgeLength
def sample(self): def __call__(self):
strike = self.strikeAngleSampler() strike = self.strikeAngleSampler()
dip = self.dipAngleSampler() dip = self.dipAngleSampler()
...@@ -255,7 +255,7 @@ class QuadrilateralSampler: ...@@ -255,7 +255,7 @@ class QuadrilateralSampler:
dyVec2 = Vector_3(deepcopy(axes[1].x()), deepcopy(axes[1].y()), deepcopy(axes[1].z())); dyVec2 *= dy2/2.0 dyVec2 = Vector_3(deepcopy(axes[1].x()), deepcopy(axes[1].y()), deepcopy(axes[1].z())); dyVec2 *= dy2/2.0
# compute corner points # compute corner points
c = self.pointSampler.sample() c = self.pointSampler()
from frackit.geometry import Point_3 from frackit.geometry import Point_3
c1 = Point_3(deepcopy(c.x()), deepcopy(c.y()), deepcopy(c.z())); c1 -= dxVec1; c1 -= dyVec1 c1 = Point_3(deepcopy(c.x()), deepcopy(c.y()), deepcopy(c.z())); c1 -= dxVec1; c1 -= dyVec1
c2 = Point_3(deepcopy(c.x()), deepcopy(c.y()), deepcopy(c.z())); c2 += dxVec1; c2 -= dyVec2 c2 = Point_3(deepcopy(c.x()), deepcopy(c.y()), deepcopy(c.z())); c2 += dxVec1; c2 -= dyVec2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment