Source code for daisy.persistence.mongodb_graph_provider

from __future__ import absolute_import
from ..coordinate import Coordinate
from ..roi import Roi
from .shared_graph_provider import\
    SharedGraphProvider, SharedSubGraph
from ..graph import Graph, DiGraph
from pymongo import MongoClient, ASCENDING, ReplaceOne, UpdateOne
from pymongo.errors import BulkWriteError, WriteError
import logging
import numpy as np
import networkx as nx

logger = logging.getLogger(__name__)


class MongoDbGraphProvider(SharedGraphProvider):
    '''Provides shared graphs stored in a MongoDB.

    Nodes are assumed to have at least an attribute ``id``. If they have a
    position attribute (set via argument ``position_attribute``, defaults to
    ``position``), it will be used for geometric slicing (see
    ``__getitem__``).

    Edges are assumed to have at least attributes ``u``, ``v``.

    Arguments:

        db_name (``string``):

            The name of the MongoDB database.

        host (``string``, optional):

            The URL of the MongoDB host.

        mode (``string``, optional):

            One of ``r``, ``r+``, or ``w``. Defaults to ``r+``. ``w`` drops
            the node, edge, and meta collections.

        directed (``bool``):

            True if the graph is directed, false otherwise. If None, attempts
            to read the value from an existing database. If not found,
            defaults to false.

        total_roi (``daisy.Roi``, optional):

            The total region of interest covered by the graph. If given, it
            is stored in the meta collection and checked for consistency
            against any value already stored there.

        nodes_collection (``string``):
        edges_collection (``string``):
        meta_collection (``string``):

            Names of the nodes, edges, and meta collections, should they
            differ from ``nodes``, ``edges``, and ``meta``.

        endpoint_names (``list`` or ``tuple`` with two elements):

            What keys to use for the start and end of an edge. Default is
            ``['u', 'v']``.

        position_attribute (``string`` or ``list`` of ``string``s, optional):

            The node attribute(s) that contain position information. This
            will be used for slicing subgraphs via ``__getitem__``. If a
            single string, the attribute is assumed to be an array. If a
            list, each entry denotes the position coordinates in order
            (e.g., ``position_z``, ``position_y``, ``position_x``).
    '''

    def __init__(
            self,
            db_name,
            host=None,
            mode='r+',
            directed=None,
            total_roi=None,
            nodes_collection='nodes',
            edges_collection='edges',
            endpoint_names=None,
            meta_collection='meta',
            position_attribute='position'):

        self.db_name = db_name
        self.host = host
        self.mode = mode
        self.directed = directed
        self.total_roi = total_roi
        self.nodes_collection_name = nodes_collection
        self.edges_collection_name = edges_collection
        self.endpoint_names = ['u', 'v'] if endpoint_names is None \
            else endpoint_names
        self.meta_collection_name = meta_collection
        self.client = None
        self.database = None
        self.nodes = None
        self.edges = None
        self.meta = None
        self.position_attribute = position_attribute

        try:
            self.__connect()
            if mode != 'w':
                if self.db_name not in self.client.list_database_names():
                    logger.warning(
                        "Opened with read mode %s, but no db with name "
                        "%s found in client at %s",
                        mode, self.db_name, self.host)
            self.__open_db()

            if mode == 'w':
                logger.info(
                    "dropping collections %s, %s, and %s",
                    self.nodes_collection_name,
                    self.edges_collection_name,
                    self.meta_collection_name)
                self.__open_collections()
                self.nodes.drop()
                self.edges.drop()
                self.meta.drop()

            collection_names = self.database.list_collection_names()

            if meta_collection in collection_names:
                metadata = self.__get_metadata()
                if metadata:
                    self.__check_metadata(metadata)
                else:
                    self.__set_metadata()
            else:
                self.__set_metadata()

            if nodes_collection not in collection_names:
                self.__create_node_collection()
            if edges_collection not in collection_names:
                self.__create_edge_collection()
        except Exception as e:
            self.__disconnect()
            raise e

    def __del__(self):
        self.__disconnect()
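
    # A minimal construction sketch (database name, host URL, and attribute
    # names below are hypothetical; a reachable MongoDB instance is assumed):
    #
    #   provider = MongoDbGraphProvider(
    #       'example_graph_db',
    #       host='mongodb://localhost:27017',
    #       mode='r+',
    #       directed=False,
    #       position_attribute=['z', 'y', 'x'])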

    def read_nodes(self, roi, attr_filter=None, read_attrs=None):
        '''Return a list of nodes within roi.

        Arguments:

            roi (``daisy.Roi``):

                Get nodes that fall within this roi.

            attr_filter (``dict``):

                Only return nodes that have attribute=value for each
                attribute=value pair in attr_filter.

            read_attrs (``list`` of ``string``):

                Attributes to return. Others will be ignored.
        '''

        logger.debug("Querying nodes in %s", roi)
        if attr_filter is None:
            attr_filter = {}

        try:
            self.__connect()
            self.__open_db()
            self.__open_collections()

            pos_query = self.__pos_query(roi)
            query_list = [pos_query]
            for attr, value in attr_filter.items():
                query_list.append({attr: value})

            projection = {'_id': False}
            if read_attrs is not None:
                projection['id'] = True
                if isinstance(self.position_attribute, list):
                    for a in self.position_attribute:
                        projection[a] = True
                else:
                    projection[self.position_attribute] = True
                for attr in read_attrs:
                    projection[attr] = True

            nodes = self.nodes.find({'$and': query_list}, projection)
            nodes = list(nodes)
        except Exception as e:
            self.__disconnect()
            raise e

        for node in nodes:
            node['id'] = np.uint64(node['id'])

        return nodes
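
    # A minimal read_nodes sketch ('provider' and the attribute names are
    # hypothetical): fetch only the 'score' attribute of nodes that carry
    # selected=True inside a region of interest.
    #
    #   nodes = provider.read_nodes(
    #       Roi((0, 0, 0), (100, 100, 100)),
    #       attr_filter={'selected': True},
    #       read_attrs=['score'])
    #   # each returned dict contains 'id', the position attribute(s),
    #   # and 'score'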

    def num_nodes(self, roi):
        '''Return the number of nodes in the roi.'''

        try:
            self.__connect()
            self.__open_db()
            self.__open_collections()

            # Collection.count() was deprecated and later removed from
            # pymongo; count_documents is the supported equivalent
            num = self.nodes.count_documents(self.__pos_query(roi))
        except Exception as e:
            self.__disconnect()
            raise e

        return num

    def has_edges(self, roi):
        '''Returns true if there is at least one edge in the roi.'''

        try:
            self.__connect()
            self.__open_db()
            self.__open_collections()

            nodes = list(self.nodes.find(self.__pos_query(roi)))

            # no nodes -> no edges
            if len(nodes) == 0:
                return False

            node_ids = [int(np.int64(n['id'])) for n in nodes]

            # limit query to 1M node IDs (otherwise we might exceed the 16MB
            # BSON document size limit)
            length = len(node_ids)
            query_size = 1000000
            num_chunks = (length - 1)//query_size + 1
            for i in range(num_chunks):
                i_b = i*query_size
                i_e = min((i + 1)*query_size, len(node_ids))
                assert i_b < len(node_ids)
                query = {self.endpoint_names[0]: {'$in': node_ids[i_b:i_e]}}
                if self.edges.find_one(query) is not None:
                    return True
            if num_chunks > 0:
                assert i_e == len(node_ids)
        except Exception as e:
            self.__disconnect()
            raise e

        return False
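
    # Chunking arithmetic used above (and by read_edges below), worked for a
    # concrete size: with length = 2,500,000 and query_size = 1,000,000,
    # num_chunks is (2500000 - 1)//1000000 + 1 == 3, giving the ID slices
    # [0:1000000], [1000000:2000000], and [2000000:2500000].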

    def read_edges(self, roi, nodes=None, attr_filter=None, read_attrs=None):
        '''Returns a list of edges within roi.

        Arguments:

            roi (``daisy.Roi``):

                Get edges whose source node falls within this roi.

            nodes (``list`` of ``dict``):

                Return edges with sources in this list of nodes. If None,
                reads nodes in roi using read_nodes. Each dictionary maps
                string attributes to values and includes 'id' as an
                attribute.

            attr_filter (``dict``):

                Only return edges that have attribute=value for each
                attribute=value pair in attr_filter.

            read_attrs (``list`` of ``string``):

                Attributes to return. Others will be ignored.
        '''

        if nodes is None:
            nodes = self.read_nodes(roi)
        node_ids = [int(np.int64(n['id'])) for n in nodes]
        logger.debug("found %d nodes", len(node_ids))
        logger.debug("looking for edges with u in %s", node_ids[:100])

        u, v = self.endpoint_names
        edges = []
        if attr_filter is None:
            attr_filter = {}

        try:
            self.__connect()
            self.__open_db()
            self.__open_collections()

            # limit query to 1M node IDs (otherwise we might exceed the 16MB
            # BSON document size limit)
            length = len(node_ids)
            query_size = 1000000
            num_chunks = (length - 1)//query_size + 1

            filters = []
            for attr, value in attr_filter.items():
                filters.append({attr: value})

            projection = {'_id': False}
            if read_attrs is not None:
                projection[u] = True
                projection[v] = True
                for attr in read_attrs:
                    projection[attr] = True

            for i in range(num_chunks):
                i_b = i*query_size
                i_e = min((i + 1)*query_size, len(node_ids))
                assert i_b < len(node_ids)
                endpoint_query = {u: {'$in': node_ids[i_b:i_e]}}
                if attr_filter:
                    query = {'$and': filters + [endpoint_query]}
                else:
                    query = endpoint_query
                edges += self.edges.find(query, projection)
            if num_chunks > 0:
                assert i_e == len(node_ids)

            logger.debug("found %d edges", len(edges))
            logger.debug("first 100 edges read: %s", edges[:100])
        except Exception as e:
            self.__disconnect()
            raise e

        for edge in edges:
            edge[u] = np.uint64(edge[u])
            edge[v] = np.uint64(edge[v])

        return edges
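
    # A minimal read_edges sketch ('provider', 'roi', and the attribute name
    # are hypothetical): reuse already-read nodes to avoid a second node
    # query, and restrict the returned edge attributes.
    #
    #   nodes = provider.read_nodes(roi)
    #   edges = provider.read_edges(
    #       roi, nodes=nodes, read_attrs=['merge_score'])
    #   # each returned dict contains the endpoint keys (default 'u', 'v')
    #   # plus 'merge_score'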

    def __getitem__(self, roi):
        return self.get_graph(roi)

    def get_graph(
            self,
            roi,
            nodes_filter=None,
            edges_filter=None,
            node_attrs=None,
            edge_attrs=None):
        '''Return a graph within roi, optionally filtering by node and edge
        attributes.

        Arguments:

            roi (``daisy.Roi``):

                Get nodes that fall within this roi and edges whose source
                node falls within it.

            nodes_filter (``dict``):
            edges_filter (``dict``):

                Only return nodes/edges that have attribute=value for each
                attribute=value pair in nodes/edges_filter.

            node_attrs (``list`` of ``string``):

                Only return these attributes for nodes. Other attributes
                will be ignored, but id and position attribute(s) will
                always be included. If None (default), return all attrs.

            edge_attrs (``list`` of ``string``):

                Only return these attributes for edges. Other attributes
                will be ignored, but source and target will always be
                included. If None (default), return all attrs.
        '''

        nodes = self.read_nodes(
            roi,
            attr_filter=nodes_filter,
            read_attrs=node_attrs)
        edges = self.read_edges(
            roi,
            nodes=nodes,
            attr_filter=edges_filter,
            read_attrs=edge_attrs)
        u, v = self.endpoint_names
        node_list = [
            (n['id'], self.__remove_keys(n, ['id']))
            for n in nodes]
        edge_list = [
            (e[u], e[v], self.__remove_keys(e, [u, v]))
            for e in edges]

        # create the subgraph
        if self.directed:
            graph = MongoDbSubDiGraph(self, roi)
        else:
            graph = MongoDbSubGraph(self, roi)
        graph.add_nodes_from(node_list)
        graph.add_edges_from(edge_list)
        return graph
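
    # get_graph (or the provider[roi] shorthand) returns a MongoDbSubGraph
    # or MongoDbSubDiGraph. A minimal sketch with hypothetical filter and
    # attribute names:
    #
    #   graph = provider.get_graph(
    #       roi,
    #       nodes_filter={'selected': True},
    #       edge_attrs=['merge_score'])
    #   print(graph.number_of_nodes(), graph.number_of_edges())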

    def __remove_keys(self, dictionary, keys):
        '''Removes given keys from dictionary.'''

        for key in keys:
            del dictionary[key]
        return dictionary

    def __connect(self):
        '''Connects to Mongo client'''

        if not self.client:
            self.client = MongoClient(self.host)

    def __open_db(self):
        '''Opens Mongo database'''

        if not self.database:
            self.database = self.client[self.db_name]

    def __open_collections(self):
        '''Opens the node, edge, and meta collections'''

        if not self.nodes:
            self.nodes = self.database[self.nodes_collection_name]
            self.edges = self.database[self.edges_collection_name]
            self.meta = self.database[self.meta_collection_name]

    def __get_metadata(self):
        '''Gets metadata out of the meta collection and returns it as a
        dictionary.'''

        self.__open_collections()
        metadata = self.meta.find_one({}, {"_id": False})
        return metadata

    def __disconnect(self):
        '''Closes the mongo client and removes references to all collections
        and databases'''

        self.nodes = None
        self.edges = None
        self.meta = None
        self.database = None
        if self.client:
            self.client.close()
            self.client = None

    def __create_node_collection(self):
        '''Creates the node collection, including indexes'''

        self.__open_db()
        self.__open_collections()

        # index on the configured position attribute(s)
        if isinstance(self.position_attribute, list):
            self.nodes.create_index(
                [
                    (key, ASCENDING)
                    for key in self.position_attribute
                ],
                name='position')
        else:
            self.nodes.create_index(
                [
                    (self.position_attribute, ASCENDING)
                ],
                name='position')

        self.nodes.create_index(
            [
                ('id', ASCENDING)
            ],
            name='id',
            unique=True)

    def __create_edge_collection(self):
        '''Creates the edge collection, including indexes'''

        self.__open_db()
        self.__open_collections()

        u, v = self.endpoint_names
        self.edges.create_index(
            [
                (u, ASCENDING),
                (v, ASCENDING)
            ],
            name='incident',
            unique=True)

    def __check_metadata(self, metadata):
        '''Checks if the provided metadata matches the existing metadata in
        the meta collection'''

        if self.directed is None:
            assert metadata['directed'] is not None, \
                "Meta collection exists but does not contain " \
                "directed information"
            self.directed = metadata['directed']
        elif metadata['directed'] != self.directed:
            raise ValueError((
                "Input parameter directed={} does not match "
                "directed value {} already in stored metadata")
                .format(self.directed, metadata['directed']))
        if self.total_roi is None:
            if 'total_roi_offset' in metadata \
                    and 'total_roi_shape' in metadata:
                offset = metadata['total_roi_offset']
                shape = metadata['total_roi_shape']
                self.total_roi = Roi(offset, shape)
        else:
            offset = self.total_roi.get_offset()
            if list(offset) != metadata['total_roi_offset']:
                raise ValueError((
                    "Input total_roi offset {} does not match "
                    "total_roi offset {} already stored in metadata")
                    .format(
                        self.total_roi.get_offset(),
                        metadata['total_roi_offset']))
            if list(self.total_roi.get_shape()) \
                    != metadata['total_roi_shape']:
                raise ValueError((
                    "Input total_roi shape {} does not match "
                    "total_roi shape {} already stored in metadata")
                    .format(
                        self.total_roi.get_shape(),
                        metadata['total_roi_shape']))

    def __set_metadata(self):
        '''Sets the metadata in the meta collection to the provided values'''

        if not self.directed:
            # default is false
            self.directed = False
        meta_data = {'directed': self.directed}
        # if total_roi not specified, don't write it
        if self.total_roi:
            meta_data['total_roi_offset'] = self.total_roi.get_offset()
            meta_data['total_roi_shape'] = self.total_roi.get_shape()

        self.__open_collections()
        # It's possible that another worker has already inserted the
        # metadata - upsert to keep only one document in the collection
        self.meta.replace_one(meta_data, meta_data, upsert=True)

    def __pos_query(self, roi):
        '''Generates a mongo query for position'''

        begin = roi.get_begin()
        end = roi.get_end()

        if isinstance(self.position_attribute, list):
            assert len(self.position_attribute) == roi.dims, (
                'Number of position attributes does not match number of '
                'dimensions')
            return {
                key: {
                    k: v
                    for k, v in zip(
                        ["$gte", "$lt"],
                        [
                            b if b is not None else float("-inf"),
                            e if e is not None else float("inf"),
                        ],
                    )
                }
                for key, b, e in zip(self.position_attribute, begin, end)
            }
        else:
            return {
                "%s.%d" % (self.position_attribute, d): {
                    k: v
                    for k, v in zip(
                        ["$gte", "$lt"],
                        [
                            b if b is not None else float("-inf"),
                            e if e is not None else float("inf"),
                        ],
                    )
                }
                for d, (b, e) in enumerate(zip(begin, end))
            }
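
# For illustration: with position_attribute=['z', 'y', 'x'] and
# roi = Roi((0, 0, 0), (10, 10, 10)) (hypothetical values), __pos_query
# produces the MongoDB filter
#
#   {
#       'z': {'$gte': 0, '$lt': 10},
#       'y': {'$gte': 0, '$lt': 10},
#       'x': {'$gte': 0, '$lt': 10},
#   }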


class MongoDbSharedSubGraph(SharedSubGraph):

    def __init__(
            self,
            graph_provider,
            roi):
        super().__init__()

        self.provider = graph_provider
        self.roi = roi

        self.client = MongoClient(self.provider.host)
        self.database = self.client[self.provider.db_name]
        self.nodes_collection = self.database[
            self.provider.nodes_collection_name]
        self.edges_collection = self.database[
            self.provider.edges_collection_name]

    def write_nodes(
            self,
            roi=None,
            attributes=None,
            fail_if_exists=False,
            fail_if_not_exists=False,
            delete=False):
        assert not delete, "Delete not implemented"
        assert not (fail_if_exists and fail_if_not_exists), \
            "Cannot have fail_if_exists and fail_if_not_exists simultaneously"
        if self.provider.mode == 'r':
            raise RuntimeError("Trying to write to read-only DB")
        if roi is None:
            roi = self.roi

        logger.debug("Writing nodes")
        nodes = []
        for node_id, data in self.nodes(data=True):
            if not self.__contains(roi, node_id):
                logger.debug(
                    "Skipping node {} with data {} because not in roi {}"
                    .format(node_id, data, roi))
                continue
            node = {
                'id': int(np.int64(node_id))
            }
            if not attributes:
                node.update(data)
            else:
                for key in data:
                    if key in attributes:
                        node[key] = data[key]
            nodes.append(node)
        if len(nodes) == 0:
            return

        try:
            self.__write(
                self.nodes_collection, ['id'], nodes,
                fail_if_exists=fail_if_exists,
                fail_if_not_exists=fail_if_not_exists,
                delete=delete)
        except BulkWriteError as e:
            logger.error(e.details)
            raise

    def write_edges(
            self,
            roi=None,
            attributes=None,
            fail_if_exists=False,
            fail_if_not_exists=False,
            delete=False):
        assert not delete, "Delete not implemented"
        assert not (fail_if_exists and fail_if_not_exists), \
            "Cannot have fail_if_exists and fail_if_not_exists simultaneously"
        if self.provider.mode == 'r':
            raise RuntimeError("Trying to write to read-only DB")
        if roi is None:
            roi = self.roi

        logger.debug("Writing edges in %s", roi)
        edges = []
        u_name, v_name = self.provider.endpoint_names
        for u, v, data in self.edges(data=True):
            if not self.is_directed():
                u, v = min(u, v), max(u, v)
            if not self.__contains(roi, u):
                logger.debug(
                    "Skipping edge with u {}, v {}, and data {} "
                    "because u not in roi {}"
                    .format(u, v, data, roi))
                continue
            edge = {
                u_name: int(np.int64(u)),
                v_name: int(np.int64(v)),
            }
            if not attributes:
                edge.update(data)
            else:
                for key in data:
                    if key in attributes:
                        edge[key] = data[key]
            edges.append(edge)

        if len(edges) == 0:
            logger.debug("No edges to insert in %s", roi)
            return

        try:
            self.__write(
                self.edges_collection, [u_name, v_name], edges,
                fail_if_exists=fail_if_exists,
                fail_if_not_exists=fail_if_not_exists,
                delete=delete)
        except BulkWriteError as e:
            logger.error(e.details)
            raise

    def update_node_attrs(
            self,
            roi=None,
            attributes=None):
        if self.provider.mode == 'r':
            raise RuntimeError("Trying to write to read-only DB")
        if roi is None:
            roi = self.roi

        logger.debug("Updating node attributes")
        updates = []
        for node_id, data in self.nodes(data=True):
            if not self.__contains(roi, node_id):
                logger.debug(
                    "Skipping node {} with data {} because not in roi {}"
                    .format(node_id, data, roi))
                continue
            _filter = {
                'id': int(np.int64(node_id))
            }
            if not attributes:
                update = {'$set': data}
            else:
                update = {}
                for key in data:
                    if key in attributes:
                        update[key] = data[key]
                if not update:
                    logger.info(
                        "Skipping node %s with data %s"
                        " - no attributes to update", node_id, data)
                    continue
                update = {'$set': update}
            updates.append(UpdateOne(_filter, update))
        if len(updates) == 0:
            return

        try:
            self.nodes_collection.bulk_write(updates, ordered=False)
        except BulkWriteError as e:
            logger.error(e.details)
            raise

    def update_edge_attrs(
            self,
            roi=None,
            attributes=None):
        if self.provider.mode == 'r':
            raise RuntimeError("Trying to write to read-only DB")
        if roi is None:
            roi = self.roi

        logger.debug("Updating edge attributes")
        updates = []
        u_name, v_name = self.provider.endpoint_names
        for u, v, data in self.edges(data=True):
            if not self.is_directed():
                u, v = min(u, v), max(u, v)
            if not self.__contains(roi, u):
                logger.debug(
                    "Skipping edge with u {}, v {}, and data {} "
                    "because u not in roi {}"
                    .format(u, v, data, roi))
                continue
            _filter = {
                u_name: int(np.int64(u)),
                v_name: int(np.int64(v)),
            }
            if not attributes:
                update = {'$set': data}
            else:
                update = {}
                for key in data:
                    if key in attributes:
                        update[key] = data[key]
                if not update:
                    logger.info(
                        "Skipping edge %s -> %s with data %s"
                        " - no attributes to update", u, v, data)
                    continue
                update = {'$set': update}
            updates.append(UpdateOne(_filter, update))
        if len(updates) == 0:
            logger.info("No updates in roi %s", roi)
            return

        try:
            self.edges_collection.bulk_write(updates, ordered=False)
        except BulkWriteError as e:
            logger.error(e.details)
            raise

    def get_connected_components(self):
        '''Returns a list of connected components as networkx (di)graphs'''

        subgraphs = []
        if self.is_directed():
            node_set_generator = nx.weakly_connected_components(self)
        else:
            node_set_generator = nx.connected_components(self)
        for node_set in node_set_generator:
            edge_set = self.edges(node_set, data=True)
            if self.is_directed():
                g = nx.DiGraph()
            else:
                g = nx.Graph()
            g.add_nodes_from([(node, self.nodes[node]) for node in node_set])
            g.add_edges_from(edge_set)
            subgraphs.append(g)
        return subgraphs

    def __write(
            self, collection, match_fields, docs,
            fail_if_exists=False, fail_if_not_exists=False, delete=False):
        '''Writes documents to the provided mongo collection, checking for
        restrictions.

        Args:

            collection (``pymongo.collection``):

                The collection to write the documents into.

            match_fields (``list`` of ``string``):

                The set of fields to match to be considered the same
                document.

            docs (``dict`` or ``bson``):

                The documents to insert into the collection.

            fail_if_exists, fail_if_not_exists, delete (``bool``):

                See write_nodes or write_edges for explanations of these
                flags.
        '''
        assert not delete, "Delete not implemented"
        match_docs = []
        for doc in docs:
            match_doc = {}
            for field in match_fields:
                match_doc[field] = doc[field]
            match_docs.append(match_doc)

        if fail_if_exists:
            self.__write_fail_if_exists(collection, match_docs, docs)
        elif fail_if_not_exists:
            self.__write_fail_if_not_exists(collection, match_docs, docs)
        else:
            self.__write_no_flags(collection, match_docs, docs)

    def __write_no_flags(self, collection, old_docs, new_docs):
        bulk_query = [
            ReplaceOne(old, new, upsert=True)
            for old, new in zip(old_docs, new_docs)]
        collection.bulk_write(bulk_query, ordered=False)

    def __write_fail_if_exists(self, collection, old_docs, new_docs):
        for old in old_docs:
            # find_one rather than find: pymongo cursors do not support
            # truth-value testing
            if collection.find_one(old) is not None:
                raise WriteError(
                    "Found existing doc %s and fail_if_exists set to True."
                    " Aborting write for all docs." % old)
        collection.insert_many(new_docs)

    def __write_fail_if_not_exists(self, collection, old_docs, new_docs):
        for old in old_docs:
            if collection.find_one(old) is None:
                raise WriteError(
                    "Did not find existing doc %s and fail_if_not_exists "
                    "set to True. Aborting write for all docs." % old)
        bulk_query = [
            ReplaceOne(old, new, upsert=False)
            for old, new in zip(old_docs, new_docs)]
        result = collection.bulk_write(bulk_query, ordered=False)
        assert len(new_docs) == result.matched_count, (
            "Supposed to replace %s docs, but only replaced %s"
            % (len(new_docs), result.matched_count))

    def __contains(self, roi, node):
        '''Determines if the given node is inside the given roi'''

        node_data = self.nodes[node]

        # Some nodes are outside of the originally requested ROI (they have
        # been pulled in by edges leaving the ROI). These nodes have no
        # attributes, so we can't perform an inclusion test. However, we
        # know they are outside of the subgraph ROI, and therefore also
        # outside of 'roi', whatever it is.
        coordinate = []
        if isinstance(self.provider.position_attribute, list):
            for pos_attr in self.provider.position_attribute:
                if pos_attr not in node_data:
                    return False
                coordinate.append(node_data[pos_attr])
        else:
            if self.provider.position_attribute not in node_data:
                return False
            coordinate = node_data[self.provider.position_attribute]

        logger.debug(
            "Checking if coordinate {} is inside roi {}"
            .format(coordinate, roi))
        return roi.contains(Coordinate(coordinate))

    def is_directed(self):
        raise RuntimeError("not implemented in %s" % self.name())


class MongoDbSubGraph(MongoDbSharedSubGraph, Graph):

    def __init__(self, graph_provider, roi):
        # this calls the init function of the MongoDbSharedSubGraph,
        # because left parents come before right parents
        super().__init__(graph_provider, roi)

    def is_directed(self):
        return False


class MongoDbSubDiGraph(MongoDbSharedSubGraph, DiGraph):

    def __init__(self, graph_provider, roi):
        # this calls the init function of the MongoDbSharedSubGraph,
        # because left parents come before right parents
        super().__init__(graph_provider, roi)

    def is_directed(self):
        return True
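

# A minimal end-to-end sketch of a write/read round trip. The database name
# and host URL below are hypothetical, and a reachable MongoDB instance is
# assumed; this demo block is an illustration, not part of the provider API.
if __name__ == '__main__':

    provider = MongoDbGraphProvider(
        'example_graph_db',
        host='mongodb://localhost:27017',
        mode='w',
        position_attribute=['z', 'y', 'x'])

    roi = Roi((0, 0, 0), (10, 10, 10))

    # build a small graph in memory and persist it
    graph = provider[roi]
    graph.add_node(1, z=1, y=1, x=1)
    graph.add_node(2, z=2, y=2, x=2)
    graph.add_edge(1, 2)
    graph.write_nodes()
    graph.write_edges()

    # re-read the same region and verify the round trip
    reread = provider[roi]
    assert reread.number_of_nodes() == 2
    assert reread.number_of_edges() == 1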