# Adopted from https://github.com/kLabUM/rrcf
from typing import Literal
import numpy as np
from dtaianomaly.anomaly_detection._BaseDetector import BaseDetector, Supervision
from dtaianomaly.type_validation import (
BoolAttribute,
FloatAttribute,
IntegerAttribute,
LiteralAttribute,
NoneAttribute,
WindowSizeAttribute,
)
from dtaianomaly.windowing import (
WINDOW_SIZE_TYPE,
compute_window_size,
reverse_sliding_window,
sliding_window,
)
__all__ = ["RobustRandomCutForestAnomalyDetector"]
[docs]
class RobustRandomCutForestAnomalyDetector(BaseDetector):
    """
    Detect anomalies using robust random cut forest :cite:`guha2016robust`.

    A random cut tree is a binary tree similar to an isolation tree. The main
    difference is how the dimension to split on is selected, and how the anomaly
    scores are computed. For a random cut tree, the dimension is chosen based on
    the difference between the minimum and maximum value along that dimension, to
    prioritise dimensions with wider spread. The anomaly score is computed based
    on collusive displacement: a sample is more anomalous if the height of other
    samples within the tree is substantially smaller when the sample is removed
    from the tree. This is based on the assumption that anomalies make it harder
    to explain the dataset as a whole. A robust random cut forest then consists
    of multiple random cut trees. Because the samples can be dynamically removed
    from and added to the tree, the model can also deal with streaming data.

    Parameters
    ----------
    window_size : int or str
        The window size to use for extracting sliding windows from the time series. This
        value will be passed to :py:meth:`~dtaianomaly.anomaly_detection.compute_window_size`.
    stride : int, default=1
        The stride, i.e., the step size for extracting sliding windows from the time series.
    online_learning : bool, default=False
        Whether to perform online learning, i.e., update the trees when detecting anomalies.
    n_estimators : int, default=100
        The number of base trees in the ensemble.
    max_samples : int, float or 'auto', default='auto'
        The number of samples to draw for training each base estimator:

        - if ``int``: Draw at most ``max_samples`` samples.
        - if ``float``: Draw at most ``max_samples`` percentage of the samples
          (computed on the number of observations in the time series).
        - if ``'auto'``: Set ``max_samples=min(256, n_windows)``.

        In all cases the value is capped at the number of available windows
        and is at least 1.
    precision : int, default=9
        Floating-point precision for distinguishing duplicate points.
    random_state : int, default=None
        The seed used to create a random number generator from numpy.

    Attributes
    ----------
    window_size_ : int
        The effectively used window size for this anomaly detector
    max_samples_ : int
        The effectively used maximum number of samples.
    forest_ : list of RCTree
        The trees which are used to isolate the observations.

    Examples
    --------
    >>> from dtaianomaly.anomaly_detection import RobustRandomCutForestAnomalyDetector
    >>> from dtaianomaly.data import demonstration_time_series
    >>> x, y = demonstration_time_series()
    >>> rrcf = RobustRandomCutForestAnomalyDetector(10, random_state=0).fit(x)
    >>> rrcf.decision_function(x)  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
    array([3.72400345, 3.92008391, 4.05253107, ..., 3.74515887, 3.64998068,
           3.26297901]...)
    """

    window_size: WINDOW_SIZE_TYPE
    stride: int
    online_learning: bool
    n_estimators: int
    max_samples: int | float | Literal["auto"]
    precision: int
    random_state: int | None
    window_size_: int
    max_samples_: int
    forest_: list["RCTree"]

    attribute_validation = {
        "window_size": WindowSizeAttribute(),
        "stride": IntegerAttribute(1),
        "online_learning": BoolAttribute(),
        "n_estimators": IntegerAttribute(1),
        "max_samples": IntegerAttribute(minimum=1)
        | FloatAttribute(0.0, 1.0, inclusive_minimum=False)
        | LiteralAttribute("auto"),
        "precision": IntegerAttribute(1),
        "random_state": IntegerAttribute() | NoneAttribute(),
    }

    def __init__(
        self,
        window_size: WINDOW_SIZE_TYPE,
        stride: int = 1,
        online_learning: bool = False,
        n_estimators: int = 100,
        max_samples: int | float | Literal["auto"] = "auto",
        precision: int = 9,
        random_state: int | None = None,
    ):
        super().__init__(Supervision.UNSUPERVISED)
        self.window_size = window_size
        self.stride = stride
        self.online_learning = online_learning
        self.n_estimators = n_estimators
        self.max_samples = max_samples
        self.precision = precision
        self.random_state = random_state

    def _fit(self, X: np.ndarray, y: np.ndarray = None, **kwargs) -> None:
        """
        Build ``n_estimators`` random cut trees, each on a random subset of windows.

        Parameters
        ----------
        X : np.ndarray
            The time series to fit on.
        y : np.ndarray, default=None
            Ignored; the detector is unsupervised.
        **kwargs
            Forwarded to ``compute_window_size``.
        """
        self.window_size_ = compute_window_size(X, self.window_size, **kwargs)
        windows = sliding_window(X, self.window_size_, self.stride)
        n_windows = windows.shape[0]

        # Resolve the requested number of samples per tree.
        if isinstance(self.max_samples, float):
            requested = int(X.shape[0] * self.max_samples)
        elif isinstance(self.max_samples, int):
            requested = self.max_samples
        else:  # self.max_samples == "auto"
            requested = 256

        # Sampling is done without replacement from the windows, so the sample
        # size may never exceed the number of windows (otherwise ``rng.choice``
        # raises), and a tree needs at least one point.
        self.max_samples_ = max(1, min(requested, n_windows))

        rng = np.random.default_rng(self.random_state)
        self.forest_ = []
        for _ in range(self.n_estimators):
            indexes = rng.choice(n_windows, size=self.max_samples_, replace=False)
            indexes.sort()
            self.forest_.append(
                RCTree(
                    X=windows[indexes],
                    index_labels=np.arange(self.max_samples_),
                    precision=self.precision,
                    # NOTE: every tree receives the same seed; the trees differ
                    # because each one is built on a different sub-sample.
                    random_state=self.random_state,
                )
            )

    def _decision_function(self, X: np.ndarray) -> np.ndarray:
        """
        Score every window by its average collusive displacement over the forest.

        Each window is temporarily inserted in every tree under a fresh index and
        scored. Without online learning the window is removed again, leaving the
        tree unchanged; with online learning the window stays and the leaf with
        the smallest index is removed instead.

        Parameters
        ----------
        X : np.ndarray
            The time series to score.

        Returns
        -------
        np.ndarray
            The per-window scores mapped back to the time series by
            ``reverse_sliding_window``.
        """
        windows = sliding_window(X, self.window_size_, self.stride)
        # One score per extracted window (the count depends on the stride).
        decision_scores = np.zeros(shape=windows.shape[0])

        for i, window in enumerate(windows):
            for tree in self.forest_:
                # The next valid index for the window
                index = max(tree.leaves) + 1

                # Add the window and compute the anomaly score
                tree.insert_point(window, index=index)
                decision_scores[i] += tree.codisp(index)

                if self.online_learning:
                    # Remove the oldest window from the tree
                    tree.forget_point(min(tree.leaves))
                else:
                    # Remove the window from the tree
                    tree.forget_point(index)

        decision_scores /= len(self.forest_)

        # The windows were extracted with ``self.stride``, so the same stride
        # must be used to map the scores back onto the time series.
        return reverse_sliding_window(
            per_window_anomaly_scores=decision_scores,
            window_size=self.window_size_,
            stride=self.stride,
            length_time_series=X.shape[0],
        )
class RCTree:
    """
    Robust random cut tree data structure as described in:

    S. Guha, N. Mishra, G. Roy, & O. Schrijvers. Robust random cut forest based anomaly
    detection on streams, in Proceedings of the 33rd International conference on machine
    learning, New York, NY, 2016 (pp. 2712-2721).

    Parameters
    ----------
    X : np.ndarray (n x d), optional
        Array containing n data points, each with dimension d.
        If no data is provided (``None`` or zero rows), an empty tree is created.
    index_labels : sequence of length n, optional
        Labels for data points provided in X.
        Defaults to [0, 1, ... n-1].
    precision : int, default=9
        Number of decimals the data is rounded to, used for distinguishing
        duplicate points.
    random_state : int, RandomState, Generator or None, default=None
        If an integer (Python or NumPy), it seeds a new ``RandomState``;
        if a ``RandomState`` or ``Generator`` instance, it is used directly;
        otherwise the global ``np.random`` module is used.

    Attributes
    ----------
    root : Branch or Leaf instance
        Pointer to root of tree.
    leaves : dict
        Dict containing pointers to all leaves in tree.
    ndim : int
        Dimension of points in the tree.

    Methods
    -------
    insert_point: inserts a new point into the tree.
    forget_point: removes a point from the tree.
    disp: compute displacement associated with the removal of a leaf.
    codisp: compute collusive displacement associated with the removal of a leaf
        (anomaly score).
    map_leaves: traverses all nodes in the tree and executes a user-specified
        function on the leaves.
    query: finds nearest point in tree.
    get_bbox: find bounding box of points under a given node.
    find_duplicate: finds duplicate points in the tree.
    """

    def __init__(self, X=None, index_labels=None, precision=9, random_state=None):
        # Random number generation with provided seed. NumPy integers are
        # accepted as seeds too, so that they are not silently ignored.
        if isinstance(random_state, (int, np.integer)):
            self.rng = np.random.RandomState(random_state)
        elif isinstance(random_state, (np.random.RandomState, np.random.Generator)):
            self.rng = random_state
        else:
            self.rng = np.random

        # Initialize dict for leaves
        self.leaves = {}
        # Initialize tree root
        self.root = None
        self.ndim = None

        if X is None:
            return

        # Round data to avoid sorting errors
        X = np.around(X, decimals=precision)
        # Initialize index labels, if they exist
        if index_labels is None:
            index_labels = np.arange(X.shape[0], dtype=int)
        self.index_labels = index_labels
        # Set node above to None in case of bottom-up search
        self.u = None

        # Nothing to build a tree from: keep the tree empty.
        if X.shape[0] == 0:
            return

        # Check for duplicates
        U, I, N = np.unique(X, return_inverse=True, return_counts=True, axis=0)
        if N.max() > 1:
            # Duplicates exist: build the tree on the unique points only
            n, d = U.shape
            X = U
        else:
            n, d = X.shape
            N = np.ones(n, dtype=int)
            I = None

        # Store dimension of dataset
        self.ndim = d

        # A single (unique) point cannot be cut: it is the root leaf itself.
        if n == 1:
            leaf = Leaf(i=0, d=0, x=X[0, :], n=N[0])
            self.root = leaf
            self._register_leaf(leaf, 0, I)
            return

        # Create RRC Tree; the tree acts as temporary parent of the root,
        # such that the first cut is stored in ``self.root``.
        S = np.ones(n, dtype=bool)
        self._mktree(X, S, N, I, parent=self)
        # Remove parent of root
        self.root.u = None
        # Count all leaves under each branch
        self._count_all_top_down(self.root)
        # Set bboxes of all branches
        self._get_bbox_top_down(self.root)

    def __repr__(self):
        """Return a textual drawing of the tree (empty string for an empty tree)."""
        pieces = []

        def render(node, prefix):
            # ``prefix`` is the indentation of the subtree; passing it down
            # explicitly keeps the widths of nested levels consistent.
            if isinstance(node, Leaf):
                pieces.append("({})\n".format(node.i))
            elif isinstance(node, Branch):
                pieces.append("{0}{1}\n".format(chr(9472), "+"))
                pieces.append("{0} {1}{2}{2}".format(prefix, chr(9500), chr(9472)))
                render(node.l, prefix + " {}  ".format(chr(9474)))
                pieces.append("{0} {1}{2}{2}".format(prefix, chr(9492), chr(9472)))
                render(node.r, prefix + "    ")

        render(self.root, "")
        return "".join(pieces)

    def _cut(self, X, S, parent=None, side="l"):
        """
        Create a random cut through the points selected by the boolean mask S.

        Returns the masks of the points on the left and on the right of the cut,
        and the new Branch (which is also attached to ``parent.<side>``).
        """
        # Find max and min over all d dimensions
        selected = X[S]
        xmax = selected.max(axis=0)
        xmin = selected.min(axis=0)
        # Probability of each dimension is proportional to its span. A true
        # division (not in-place) is used such that integer data works as well.
        spans = xmax - xmin
        probabilities = spans / spans.sum()
        # Determine dimension to cut
        q = self.rng.choice(self.ndim, p=probabilities)
        # Determine value for split
        p = self.rng.uniform(xmin[q], xmax[q])
        # Determine subset of points to left
        S1 = (X[:, q] <= p) & S
        # Determine subset of points to right
        S2 = (~S1) & S
        # Create new child node
        child = Branch(q=q, p=p, u=parent)
        # Link child node to parent
        if parent is not None:
            setattr(parent, side, child)
        return S1, S2, child

    def _register_leaf(self, leaf, i, I):
        """
        Add ``leaf`` (built from row ``i`` of the data) to the leaves dict.

        If ``I`` is given (the inverse mapping of ``np.unique``), all labels of
        the points that collapsed onto row ``i`` point to the same leaf.
        """
        if I is not None:
            positions = np.flatnonzero(I == i)
            for label in np.asarray(self.index_labels)[positions]:
                self.leaves[label] = leaf
        else:
            self.leaves[self.index_labels[i]] = leaf

    def _mktree(self, X, S, N, I, parent=None, side="root", depth=0):
        """
        Recursively build the tree on the points selected by the boolean mask S.

        The left side is always handled before the right side, which fixes the
        order in which random numbers are consumed.
        """
        # Increment depth as we traverse down
        depth += 1
        # Create a cut according to definition 1
        S1, S2, branch = self._cut(X, S, parent=parent, side=side)
        for subset, child_side in ((S1, "l"), (S2, "r")):
            if subset.sum() > 1:
                # More than one point left: recursively construct the subtree
                self._mktree(
                    X, subset, N, I, parent=branch, side=child_side, depth=depth
                )
            else:
                # Create a leaf node from the isolated point
                i = np.flatnonzero(subset).item()
                leaf = Leaf(i=i, d=depth, u=branch, x=X[i, :], n=N[i])
                # Link leaf node to parent
                setattr(branch, child_side, leaf)
                self._register_leaf(leaf, i, I)

    def map_leaves(self, node, op=(lambda x: None), *args, **kwargs):
        """
        Traverse tree recursively, calling operation given by op on leaves

        Parameters
        ----------
        node: node in RCTree
        op: function to call on each leaf
        *args: positional arguments to op
        **kwargs: keyword arguments to op

        Returns
        -------
        None
        """
        if isinstance(node, Branch):
            # ``op`` is passed positionally: combining ``op=op`` with ``*args``
            # would bind the first extra argument to ``op`` as well.
            if node.l is not None:
                self.map_leaves(node.l, op, *args, **kwargs)
            if node.r is not None:
                self.map_leaves(node.r, op, *args, **kwargs)
        else:
            op(node, *args, **kwargs)

    def forget_point(self, index):
        """
        Delete leaf from tree

        Parameters
        ----------
        index: (Hashable type)
            Index of leaf in tree

        Returns
        -------
        leaf: Leaf instance
            Deleted leaf

        Raises
        ------
        KeyError
            If ``index`` is not a key of ``self.leaves``.
        """
        try:
            # Get leaf from leaves dict
            leaf = self.leaves[index]
        except KeyError:
            raise KeyError("Leaf must be a key to self.leaves") from None

        # If duplicate points exist...
        if leaf.n > 1:
            # Simply decrement the number of points in the leaf and for all branches above
            self._update_leaf_count_upwards(leaf, inc=-1)
            return self.leaves.pop(index)

        # If leaf is the root, the tree becomes empty
        if leaf is self.root:
            self.root = None
            self.ndim = None
            return self.leaves.pop(index)

        parent = leaf.u
        sibling = self._sibling(leaf)

        # If parent is the root, the sibling becomes the new root
        if parent is self.root:
            sibling.u = None
            self.root = sibling
            # Update depths
            if isinstance(sibling, Leaf):
                sibling.d = 0
            else:
                self.map_leaves(sibling, op=self._increment_depth, inc=-1)
            return self.leaves.pop(index)

        # Short-circuit grandparent to sibling
        grandparent = parent.u
        sibling.u = grandparent
        if parent is grandparent.l:
            grandparent.l = sibling
        else:
            grandparent.r = sibling

        # Update depths
        self.map_leaves(sibling, op=self._increment_depth, inc=-1)
        # Update leaf counts under each branch
        self._update_leaf_count_upwards(grandparent, inc=-1)
        # Update bounding boxes
        self._relax_bbox_upwards(grandparent, leaf.x)
        return self.leaves.pop(index)

    def _update_leaf_count_upwards(self, node, inc=1):
        """
        Called after inserting or removing leaves. Updates the stored count of leaves
        beneath each branch (branch.n).
        """
        while node is not None:
            node.n += inc
            node = node.u

    def insert_point(self, point, index, tolerance=None):
        """
        Inserts a point into the tree, creating a new leaf

        Parameters
        ----------
        point: np.ndarray (1 x d)
        index: (Hashable type)
            Identifier for new leaf in tree
        tolerance: float
            Tolerance for determining duplicate points

        Returns
        -------
        leaf: Leaf
            New leaf in tree (or the existing leaf if the point is a duplicate)

        Raises
        ------
        ValueError
            If the point has a different dimension than the points in the tree.
        KeyError
            If ``index`` already exists in the tree.
        """
        if not isinstance(point, np.ndarray):
            point = np.asarray(point)
        point = point.ravel()

        # An empty tree: the point becomes the root
        if self.root is None:
            leaf = Leaf(x=point, i=index, d=0)
            self.root = leaf
            self.ndim = point.size
            self.leaves[index] = leaf
            return leaf

        # Explicit checks instead of ``assert``: assertions raise
        # AssertionError (not the intended error) and are removed with ``-O``.
        if point.size != self.ndim:
            raise ValueError("Point must be same dimension as existing points in tree.")
        if index in self.leaves:
            raise KeyError("Index already exists in leaves dict.")

        # Check for duplicate points
        duplicate = self.find_duplicate(point, tolerance=tolerance)
        if duplicate is not None:
            self._update_leaf_count_upwards(duplicate, inc=1)
            self.leaves[index] = duplicate
            return duplicate

        # If tree has points and point is not a duplicate, continue with main algorithm...
        node = self.root
        parent = node.u
        side = None
        maxdepth = max(leaf.d for leaf in self.leaves.values())
        depth = 0
        branch = None
        for _ in range(maxdepth + 1):
            bbox = node.b
            cut_dimension, cut = self._insert_point_cut(point, bbox)
            if cut <= bbox[0, cut_dimension]:
                # The cut separates the point (left) from the subtree (right)
                leaf = Leaf(x=point, i=index, d=depth)
                branch = Branch(
                    q=cut_dimension, p=cut, l=leaf, r=node, n=(leaf.n + node.n)
                )
                break
            elif cut >= bbox[-1, cut_dimension]:
                # The cut separates the subtree (left) from the point (right)
                leaf = Leaf(x=point, i=index, d=depth)
                branch = Branch(
                    q=cut_dimension, p=cut, l=node, r=leaf, n=(leaf.n + node.n)
                )
                break
            else:
                # The cut falls inside the bounding box: descend one level
                depth += 1
                parent = node
                if point[node.q] <= node.p:
                    node = node.l
                    side = "l"
                else:
                    node = node.r
                    side = "r"

        if branch is None:
            raise AssertionError("Error with program logic: a cut was not found.")

        # Set parent of new leaf and old branch
        node.u = branch
        leaf.u = branch
        # Set parent of new branch
        branch.u = parent
        if parent is not None:
            # Set child of parent to new branch
            setattr(parent, side, branch)
        else:
            # If a new root was created, assign the attribute
            self.root = branch
        # Increment depths below branch
        self.map_leaves(branch, op=self._increment_depth, inc=1)
        # Increment leaf count above branch
        self._update_leaf_count_upwards(parent, inc=1)
        # Update bounding boxes
        self._tighten_bbox_upwards(branch)
        # Add leaf to leaves dict
        self.leaves[index] = leaf
        # Return inserted leaf for convenience
        return leaf

    def query(self, point, node=None):
        """
        Search for leaf nearest to point

        Parameters
        ----------
        point: np.ndarray (1 x d)
            Point to search for
        node: Branch instance
            Defaults to root node

        Returns
        -------
        nearest: Leaf
            Leaf nearest to queried point in the tree
        """
        if not isinstance(point, np.ndarray):
            point = np.asarray(point)
        point = point.ravel()
        if node is None:
            node = self.root
        return self._query(point, node)

    def _resolve_leaf(self, leaf):
        """Return ``leaf`` if it is a Leaf, otherwise look it up in ``self.leaves``."""
        if isinstance(leaf, Leaf):
            return leaf
        try:
            return self.leaves[leaf]
        except KeyError:
            raise KeyError(
                "leaf must be a Leaf instance or key to self.leaves"
            ) from None

    @staticmethod
    def _sibling(node):
        """Return the other child of the parent of ``node``."""
        parent = node.u
        return parent.r if node is parent.l else parent.l

    def _collusive_ratios(self, leaf):
        """
        Walk from ``leaf`` towards the root. For every visited node, collect the
        ratio between the size of its sibling subtree and its own size, together
        with the cut dimension of the parent.
        """
        ratios = []
        cut_dimensions = []
        node = leaf
        for _ in range(leaf.d):
            parent = node.u
            if parent is None:
                break
            ratios.append(self._sibling(node).n / node.n)
            cut_dimensions.append(parent.q)
            node = parent
        return ratios, cut_dimensions

    def disp(self, leaf):
        """
        Compute displacement at leaf

        Parameters
        ----------
        leaf: index of leaf or Leaf instance

        Returns
        -------
        displacement: int
            Displacement if leaf is removed
        """
        leaf = self._resolve_leaf(leaf)
        # Handle case where leaf is root
        if leaf is self.root:
            return 0
        # Count number of nodes in sibling subtree
        return self._sibling(leaf).n

    def codisp(self, leaf):
        """
        Compute collusive displacement at leaf

        Parameters
        ----------
        leaf: index of leaf or Leaf instance

        Returns
        -------
        codisplacement: float
            Collusive displacement if leaf is removed.
        """
        leaf = self._resolve_leaf(leaf)
        # Handle case where leaf is root
        if leaf is self.root:
            return 0
        ratios, _ = self._collusive_ratios(leaf)
        return max(ratios)

    def codisp_with_cut_dimension(self, leaf):
        """
        Compute collusive displacement at leaf and the dimension of the cut.
        This method can be used to find the most important features that determined
        the CoDisp.

        Parameters
        ----------
        leaf: index of leaf or Leaf instance

        Returns
        -------
        codisplacement: float
            Collusive displacement if leaf is removed.
        cut_dimension: int
            Dimension of the cut

        Notes
        -----
        If the leaf is the root of the tree, a bare ``0`` is returned instead
        of a tuple (kept for backward compatibility).
        """
        leaf = self._resolve_leaf(leaf)
        # Handle case where leaf is root
        if leaf is self.root:
            return 0
        ratios, cut_dimensions = self._collusive_ratios(leaf)
        argmax = np.argmax(ratios)
        return ratios[argmax], cut_dimensions[argmax]

    def get_bbox(self, branch=None):
        """
        Compute bounding box of all points underneath a given branch.

        Parameters
        ----------
        branch: Branch instance
            Starting branch. Defaults to root of tree.

        Returns
        -------
        bbox: np.ndarray (2 x d)
            Bounding box of all points underneath branch
        """
        if branch is None:
            branch = self.root
        mins = np.full(self.ndim, np.inf)
        maxes = np.full(self.ndim, -np.inf)
        # ``mins`` and ``maxes`` are updated in place by every visited leaf
        self.map_leaves(branch, op=self._get_bbox, mins=mins, maxes=maxes)
        return np.vstack([mins, maxes])

    def find_duplicate(self, point, tolerance=None):
        """
        If point is a duplicate of existing point in the tree, return the leaf
        containing the point, else return None.

        Parameters
        ----------
        point: np.ndarray (1 x d)
            Point to query in the tree.
        tolerance: float
            Relative tolerance for determining whether or not point is a duplicate.
            If None, the points must be exactly equal.

        Returns
        -------
        duplicate: Leaf or None
            If point is a duplicate, returns the leaf containing the point.
            If point is not a duplicate, return None.
        """
        # An empty tree cannot contain a duplicate
        if self.root is None:
            return None
        nearest = self.query(point)
        if tolerance is None:
            is_duplicate = np.all(nearest.x == point)
        else:
            is_duplicate = np.isclose(nearest.x, point, rtol=tolerance).all()
        return nearest if is_duplicate else None

    def _lr_branch_bbox(self, node):
        """
        Compute bbox of node based on bboxes of node's children.
        """
        left, right = node.l.b, node.r.b
        return np.vstack(
            [
                np.minimum(left[0, :], right[0, :]),
                np.maximum(left[-1, :], right[-1, :]),
            ]
        )

    def _get_bbox_top_down(self, node):
        """
        Recursively compute bboxes of all branches from root to leaves.
        """
        if isinstance(node, Branch):
            # The children must be known before the box of the node itself
            if node.l is not None:
                self._get_bbox_top_down(node.l)
            if node.r is not None:
                self._get_bbox_top_down(node.r)
            node.b = self._lr_branch_bbox(node)

    def _count_all_top_down(self, node):
        """
        Recursively compute number of leaves below each branch from
        root to leaves.
        """
        if isinstance(node, Branch):
            if node.l is not None:
                self._count_all_top_down(node.l)
            if node.r is not None:
                self._count_all_top_down(node.r)
            node.n = node.l.n + node.r.n

    def _query(self, point, node):
        """
        Follow the cuts from ``node`` downwards until a leaf is reached.
        """
        while not isinstance(node, Leaf):
            node = node.l if point[node.q] <= node.p else node.r
        return node

    def _increment_depth(self, x, inc=1):
        """
        Primitive function for incrementing the depth attribute of a leaf.
        """
        x.d += inc

    def _get_bbox(self, x, mins, maxes):
        """
        Primitive function for computing the bbox of a point: widens ``mins``
        and ``maxes`` in place such that they contain the point of leaf ``x``.
        """
        lt = x.x < mins
        gt = x.x > maxes
        mins[lt] = x.x[lt]
        maxes[gt] = x.x[gt]

    def _tighten_bbox_upwards(self, node):
        """
        Called when new point is inserted. Expands bbox of all nodes above new point
        if point is outside the existing bbox.
        """
        bbox = self._lr_branch_bbox(node)
        node.b = bbox
        node = node.u
        while node is not None:
            lt = bbox[0, :] < node.b[0, :]
            gt = bbox[-1, :] > node.b[-1, :]
            lt_any = lt.any()
            gt_any = gt.any()
            # An ancestor that already contains the box implies that all
            # ancestors above it do as well.
            if not (lt_any or gt_any):
                break
            if lt_any:
                node.b[0, :][lt] = bbox[0, :][lt]
            if gt_any:
                node.b[-1, :][gt] = bbox[-1, :][gt]
            node = node.u

    def _relax_bbox_upwards(self, node, point):
        """
        Called when point is deleted. Contracts bbox of all nodes above deleted point
        if the deleted point defined the boundary of the bbox.
        """
        while node is not None:
            # Stop as soon as the point does not touch the boundary of the box
            if not np.any((node.b[0, :] == point) | (node.b[-1, :] == point)):
                break
            bbox = self._lr_branch_bbox(node)
            node.b[0, :] = bbox[0, :]
            node.b[-1, :] = bbox[-1, :]
            node = node.u

    def _insert_point_cut(self, point, bbox):
        """
        Generates the cut dimension and cut value based on the InsertPoint algorithm.

        Parameters
        ----------
        point: np.ndarray (1 x d)
            New point to be inserted.
        bbox: np.ndarray(2 x d)
            Bounding box of point set S.

        Returns
        -------
        cut_dimension: int
            Dimension to cut over.
        cut: float
            Value of cut.

        Raises
        ------
        ValueError
            If no cut dimension could be determined.
        """
        # Bounding box of the existing points extended with the new point
        bbox_hat = np.empty((2, bbox.shape[1]))
        bbox_hat[0, :] = np.minimum(bbox[0, :], point)
        bbox_hat[-1, :] = np.maximum(bbox[-1, :], point)
        b_span = bbox_hat[-1, :] - bbox_hat[0, :]
        b_range = b_span.sum()
        # Draw a position along the concatenated spans of all dimensions
        r = self.rng.uniform(0, b_range)
        span_sum = np.cumsum(b_span)
        # The cut dimension is the first one whose cumulative span reaches r
        for j, cumulative_span in enumerate(span_sum):
            if cumulative_span >= r:
                cut_dimension = j
                break
        else:
            raise ValueError("Cut dimension is not finite.")
        cut = bbox_hat[0, cut_dimension] + span_sum[cut_dimension] - r
        return cut_dimension, cut
class Branch:
    """
    Internal node of an RCTree: a cut with two children and at most one parent.

    Attributes
    ----------
    q: Dimension of cut
    p: Value of cut
    l: Pointer to left child
    r: Pointer to right child
    u: Pointer to parent
    n: Number of leaves under branch
    b: Bounding box of points under branch (2 x d)
    """

    __slots__ = [
        "q",
        "p",
        "l",
        "r",
        "u",
        "n",
        "b",
    ]

    def __init__(self, q, p, l=None, r=None, u=None, n=0, b=None):
        # The cut itself
        self.q, self.p = q, p
        # Links to the neighbouring nodes
        self.l, self.r, self.u = l, r, u
        # Statistics of the subtree
        self.n, self.b = n, b

    def __repr__(self):
        return f"Branch(q={self.q}, p={self.p:.2f})"
class Leaf:
    """
    Leaf of RCTree containing no children and at most one parent.

    Attributes
    ----------
    i: Index of leaf (user-specified)
    d: Depth of leaf
    u: Pointer to parent
    x: Original point (1 x d), or None if no point was given
    n: Number of points in leaf (1 if no duplicates)
    b: Bounding box of point (1 x d), or None if no point was given
    """

    __slots__ = ["i", "d", "u", "x", "n", "b"]

    def __init__(self, i, d=None, u=None, x=None, n=1):
        # Array-likes are converted, such that the point can be reshaped and
        # indexed with boolean masks; arrays are stored as they are.
        if x is not None and not isinstance(x, np.ndarray):
            x = np.asarray(x)
        self.u = u
        self.i = i
        self.d = d
        self.x = x
        self.n = n
        # ``x`` defaults to None, in which case there is no bounding box
        self.b = None if x is None else x.reshape(1, -1)

    def __repr__(self):
        return "Leaf({0})".format(self.i)