diff --git a/demos/brunel_simulation/brunel_example.py b/demos/brunel_simulation/brunel_example.py index 4c59b60dbaa9fd1c84c694131ef9825f736d9a3c..57ed407519b478c9c18047a15409e955db39e294 100644 --- a/demos/brunel_simulation/brunel_example.py +++ b/demos/brunel_simulation/brunel_example.py @@ -11,11 +11,10 @@ is not structured spatially. from copy import deepcopy from math import sqrt import numpy as np -from mpl_toolkits.mplot3d import Axes3D -import matplotlib.pyplot as plt +# from mpl_toolkits.mplot3d import Axes3D +# import matplotlib.pyplot as plt import nest -import nest.raster_plot import nest.topology as tp @@ -23,6 +22,7 @@ class Brunel3D: def __init__(self): self.layer_dict = {} + nest.SetKernelStatus({"local_num_threads": 1}) nest.Install("streamingmodule") # nest.SetKernelStatus({'print_time': True}) @@ -196,70 +196,70 @@ class Brunel3D: def simulate(self): nest.Simulate(1000) - def plot_positions(self): - ex_pos = self.layers[0][1]['positions'] - in_pos = self.layers[1][1]['positions'] - fig = plt.figure() - ax = Axes3D(fig) - for c, m, positions in [('b', 'o', ex_pos), ('r', '^', in_pos)]: - ax.scatter([x for x, y, z in positions], - [y for x, y, z in positions], - [z for x, y, z in positions], - c=c, marker=m) - - def get_results(self): - mm = (self.layer_dict['Multimeter'][0] + 1,) - sd = (self.layer_dict['SpikeDetector'][0] + 1,) - mm_status = nest.GetStatus(mm)[0] - sd_status = nest.GetStatus(sd)[0] - - nest.raster_plot.from_device(sd, hist=True) - - senders = mm_status['events']['senders'] - times = mm_status['events']['times'] - v_m = mm_status['events']['V_m'] - v_th = mm_status['events']['V_th'] - step = int(max(senders)/100 + 1) # Only plot results from some GIDs - - mm_events = [] - for i in range(1, max(senders) + 1, step): - if i in senders: - indices = np.argwhere(senders == i) - mm_events.append({'GID': i, - 'times': [times[n] for n in indices], - 'V_m': [v_m[n] for n in indices], - 'V_th': [v_th[n] for n in indices]}) - - return {'multimeter': mm_events, - 'spike_detector': nest.GetStatus(sd)[0]} - - -if __name__ == '__main__': - nest.ResetKernel() - - print('Making specifications') - brunel = Brunel3D() - brunel.make_layer_specs() - brunel.make_connection_specs() - - print('Making layers') - brunel.make_layers() - nest.topology.DumpLayerNodes([l[0] for l in brunel.layer_dict.values()][:2], - 'brunel_nodes.txt') - - print('Making connections') - brunel.make_connections() - - brunel.simulate() - - print('Getting results') - brunel.plot_positions() - results = brunel.get_results() - - for value in ['V_m', 'V_th']: - plt.figure() - for n in results['multimeter'][::20]: - plt.plot(n['times'], n[value], label='{}'.format(n['GID'])) - plt.legend() - plt.title(value) - plt.show() + # def plot_positions(self): + # ex_pos = self.layers[0][1]['positions'] + # in_pos = self.layers[1][1]['positions'] + # fig = plt.figure() + # ax = Axes3D(fig) + # for c, m, positions in [('b', 'o', ex_pos), ('r', '^', in_pos)]: + # ax.scatter([x for x, y, z in positions], + # [y for x, y, z in positions], + # [z for x, y, z in positions], + # c=c, marker=m) + + # def get_results(self): + # mm = (self.layer_dict['Multimeter'][0] + 1,) + # sd = (self.layer_dict['SpikeDetector'][0] + 1,) + # mm_status = nest.GetStatus(mm)[0] + # sd_status = nest.GetStatus(sd)[0] + + # nest.raster_plot.from_device(sd, hist=True) + + # senders = mm_status['events']['senders'] + # times = mm_status['events']['times'] + # v_m = mm_status['events']['V_m'] + # v_th = mm_status['events']['V_th'] + # step = int(max(senders)/100 + 1) # Only plot results from some GIDs + + # mm_events = [] + # for i in range(1, max(senders) + 1, step): + # if i in senders: + # indices = np.argwhere(senders == i) + # mm_events.append({'GID': i, + # 'times': [times[n] for n in indices], + # 'V_m': [v_m[n] for n in indices], + # 'V_th': [v_th[n] for n in indices]}) + + # return {'multimeter': mm_events, + # 'spike_detector': nest.GetStatus(sd)[0]} + + +# if __name__ == '__main__': +# nest.ResetKernel() + +# print('Making specifications') +# brunel = Brunel3D() +# brunel.make_layer_specs() +# brunel.make_connection_specs() + +# print('Making layers') +# brunel.make_layers() +# nest.topology.DumpLayerNodes([l[0] for l in brunel.layer_dict.values()][:2], +# 'brunel_nodes.txt') + +# print('Making connections') +# brunel.make_connections() + +# brunel.simulate() + +# print('Getting results') +# brunel.plot_positions() +# results = brunel.get_results() + +# for value in ['V_m', 'V_th']: +# plt.figure() +# for n in results['multimeter'][::20]: +# plt.plot(n['times'], n[value], label='{}'.format(n['GID'])) +# plt.legend() +# plt.title(value) +# plt.show() diff --git a/streaming_recording_backend.cpp b/streaming_recording_backend.cpp index f6cf26da6dcd3eb83c68d0473b96bc68a737c5b3..5893486545aebd8f87ae4352e83dea57236fba43 100644 --- a/streaming_recording_backend.cpp +++ b/streaming_recording_backend.cpp @@ -38,6 +38,19 @@ void StreamingRecordingBackend::initialize() { std::cout << "initialize()" << std::endl; } +void StreamingRecordingBackend::prepare() { + std::cout << "prepare()" << std::endl; + std::cout << "Get the number of nodes" << nest::kernel().node_manager.size() << std::endl; +} + +void StreamingRecordingBackend::cleanup() { + std::cout << "cleanup()" << std::endl; +} + +void StreamingRecordingBackend::post_run_cleanup() { + std::cout << "post_run_cleanup()" << std::endl; +} + void StreamingRecordingBackend::enroll( const nest::RecordingDevice &device, const std::vector<Name> &double_value_names, @@ -80,32 +93,16 @@ void StreamingRecordingBackend::write(const nest::RecordingDevice &device, const nest::Event &event, const std::vector<double> &double_values, const std::vector<long> &long_values) { - // Called per thread - // std::lock_guard<std::mutex> lock_guard(write_mutex_); - // std::cout << std::this_thread::get_id() << ' '; - // std::cout << device.get_name() << ' '; - // std::cout << event.get_sender_gid() << ' '; - // std::cout << event.get_stamp() << ' '; - // for (const auto value : double_values) { - // std::cout << value << ' '; - // } - // std::cout << ' '; - // for (const auto value : long_values) { - // std::cout << value << ' '; - // } - // std::cout << std::endl; - const auto thread_devices = devices_.find(std::this_thread::get_id()); if (thread_devices == devices_.end()) { - // std::cout << "Error: no devices assigned to this thread!" << std::endl; + std::cerr << "Error: no devices assigned to this thread!" << std::endl; return; } const auto thread_device = thread_devices->second.find(device.get_name()); if (thread_device == thread_devices->second.end()) { - // std::cout << "Error: device not found in this thread (device = " - // << device.get_name() << ")" << std::endl; - + std::cerr << "Error: device not found in this thread (device = " + << device.get_name() << ")" << std::endl; return; } @@ -121,7 +118,13 @@ void StreamingRecordingBackend::write(const nest::RecordingDevice &device, void StreamingRecordingBackend::synchronize() { // Called per thread - for (const auto &device : devices_.at(std::this_thread::get_id())) { + + const auto thread_devices = devices_.find(std::this_thread::get_id()); + if (thread_devices == devices_.end()) { + return; + } + + for (const auto &device : thread_devices->second) { { std::lock_guard<std::mutex> lock_guard(relay_mutex_); relay_.Send(device.second->node(), false); @@ -131,6 +134,10 @@ void StreamingRecordingBackend::synchronize() { } } +void StreamingRecordingBackend::clear(const nest::RecordingDevice& device) { + std::cout << "clear(" << device.get_name() << ")" << std::endl; +} + void StreamingRecordingBackend::finalize() { // Called once std::cout << "finalize()" << std::endl; diff --git a/streaming_recording_backend.h b/streaming_recording_backend.h index ede20e6b3761b28fb4be86f4d6a153c2fae85778..52643d7fa67e200790f95395de1f5cde5625784a 100644 --- a/streaming_recording_backend.h +++ b/streaming_recording_backend.h @@ -27,6 +27,7 @@ #include <memory> #include <mutex> #include <string> +#include <thread> #include <vector> #include "nest_types.h" @@ -49,9 +50,15 @@ class StreamingRecordingBackend : public nest::RecordingBackend { void initialize() override; + void prepare() override; + void cleanup() override; + void post_run_cleanup() override; + void finalize() override; void synchronize() override; + void clear(const nest::RecordingDevice&) override; + void write(const nest::RecordingDevice &, const nest::Event &, const std::vector<double> &, const std::vector<long> &) override;