diff --git a/jupyter/gojupyterscaffold/gojupyterscaffold.go b/jupyter/gojupyterscaffold/gojupyterscaffold.go index a42bbac..ad1aadd 100644 --- a/jupyter/gojupyterscaffold/gojupyterscaffold.go +++ b/jupyter/gojupyterscaffold/gojupyterscaffold.go @@ -15,12 +15,18 @@ import ( "io/ioutil" "os" "os/signal" + "os/user" + "path" + "path/filepath" + "sort" "syscall" "github.com/golang/glog" zmq "github.com/pebbe/zmq4" ) +const jupyterStartupDirectoryPath = ".jupyter/startup/" + // ConnectionInfo stores the contents of the kernel connection file created by Jupyter. type connectionInfo struct { StdinPort int `json:"stdin_port"` @@ -123,6 +129,9 @@ func NewServer(connectionFile string, handlers RequestHandlers) (server *Server, if err := hb.Bind(cinfo.getAddr(cinfo.HBPort)); err != nil { return nil, fmt.Errorf("Failed to bind heartbeat socket: %v", err) } + + loadStartupScripts(execQueue, shell) + return &Server{ handlers: handlers, ctx: serverCtx, @@ -137,6 +146,46 @@ func NewServer(connectionFile string, handlers RequestHandlers) (server *Server, }, nil } +func loadStartupScripts(execQueue *executeQueue, shell *shellSocket) { + u, err := user.Current() + if err != nil { + glog.Errorf("Error while fetching current user directory: %v\n", err) + return + } + + profileDirPath := path.Join(u.HomeDir, jupyterStartupDirectoryPath) + fileNames, err := ioutil.ReadDir(profileDirPath) + if err != nil { + glog.Errorf("Error while reading startup files from the profile directory: %v\n", err) + return + } + + var files []string + for _, file := range fileNames { + if !file.IsDir() && filepath.Ext(file.Name()) == ".go" { + files = append(files, file.Name()) + } + } + + if files == nil || len(files) == 0 { + glog.Info("No startup files found, returning") + return + } + + sort.Strings(files) + for _, file := range files { + code, err := ioutil.ReadFile(path.Join(profileDirPath, file)) + if err != nil { + glog.Errorf("Error while loading startup file %s: %+v", file, err) + return + } + + req := &message{Content: &ExecuteRequest{Code: string(code)}} + execQueue.queue <- &executeQueueItem{req, shell} + glog.Infof("Loaded %s\n", file) + } +} + // Context returns the context of the server func (s *Server) Context() context.Context { return s.ctx