(* 
    Watches for OCaml XenStore Daemon.
    Copyright (C) 2008 Patrick Colp University of British Columbia

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

(* A single registered watch. *)
type t =
  {
    (* Domain that owns the watch; fire_watch queues the resulting event on
       this domain's output messages. *)
    domain : Domain.domain;
    (* Watched store path. fire_watch strips the domain base prefix from it
       when [relative] is set, so it is presumably stored in absolute form --
       TODO confirm against the code that creates watches. *)
    path : string;
    (* Token appended (null-terminated) to every event sent for this watch. *)
    token : string;
    (* When true, paths reported in events have the domain's base path
       (Store.domain_root followed by the domain id) and the character after
       it removed. *)
    relative : bool
  }

(* Build a watch record from its owning domain, watched path, token and
   relative flag. *)
let make domain path token relative =
  { domain; path; token; relative }

(* Two watches are the same registration when they have the same owning
   domain id, the same token and the same path. The [relative] flag is not
   part of the comparison. *)
let equal a b =
  a.domain#id = b.domain#id
  && a.token = b.token
  && a.path = b.path

(* Strip the owning domain's base path (Store.domain_root followed by the
   domain id) and the single character that follows it from [path].
   Returns "" when [path] is not longer than that prefix -- e.g. when it is
   the base path itself -- instead of letting String.sub raise
   Invalid_argument. The prefix is removed by length only; it is not checked
   that [path] actually starts with the base path. *)
let relative_watch_path watch path =
  let base = Store.domain_root ^ (string_of_int watch.domain#id) in
  (* Skip the base path plus the separator character after it. *)
  let skip = succ (String.length base) in
  let length = String.length path in
  if length > skip then String.sub path skip (length - skip) else ""

(* Queue one event for [watch] on its owning domain. The payload is the
   reported path followed by the watch token, each null-terminated. The path
   is made domain-relative first when the watch is relative. *)
let send_watch_event watch path =
  let watch_path = if watch.relative then relative_watch_path watch path else path in
  watch.domain#add_output_message (Message.event ((Utils.null_terminate watch_path) ^ (Utils.null_terminate watch.token)))

(* Fire a watch.
   - If [Store.is_child path watch.path] holds, an event reporting [path] is
     sent.
   - Otherwise, if [recurse] is set and [Store.is_child watch.path path]
     holds, an event reporting the watch's own path is sent.
   - Otherwise nothing happens.
   NOTE(review): Store.is_child is defined elsewhere; presumably
   [is_child a b] means "a lies at or below b" -- confirm against Store. *)
let fire_watch path recurse watch =
  if Store.is_child path watch.path
  then send_watch_event watch path
  else if recurse && Store.is_child watch.path path
  then send_watch_event watch watch.path

(* Registry of all watches known to the daemon.
   Watches are indexed twice: by watched path (used when firing) and by the
   id of the owning domain (used for counting and bulk removal). The public
   methods keep the two indexes consistent. *)
class watches =
object(self)
  (* domain id -> watches registered by that domain *)
  val m_domain_watches = Hashtbl.create 16
  (* watched path -> watches registered on that path *)
  val m_watches = Hashtbl.create 32

  method private domain_watches = m_domain_watches
  method private watches = m_watches

  (* Record [watch] in the per-domain index. *)
  method private add_domain_watch watch =
    let existing = try Hashtbl.find self#domain_watches watch.domain#id with Not_found -> [] in
    Hashtbl.replace self#domain_watches watch.domain#id (watch :: existing)

  (* Drop [watch] from the per-domain index. The domain's entry is removed
     altogether once its last watch is gone, rather than left bound to an
     empty list. *)
  method private remove_domain_watch watch =
    let existing = try Hashtbl.find self#domain_watches watch.domain#id with Not_found -> [] in
    match List.filter (fun w -> not (equal watch w)) existing with
    | [] -> Hashtbl.remove self#domain_watches watch.domain#id
    | remaining -> Hashtbl.replace self#domain_watches watch.domain#id remaining

  (* Register [watch]. Returns false, changing nothing, when an equal watch
     (same domain id, token and path) is already registered; true
     otherwise. *)
  method add (watch : t) =
    let path_watches = try Hashtbl.find self#watches watch.path with Not_found -> [] in
    if List.exists (equal watch) path_watches
    then false
    else (
      Hashtbl.replace self#watches watch.path (watch :: path_watches);
      self#add_domain_watch watch;
      true
    )

  (* Run fire_watch for [path] against every registered watch. Nothing is
     fired while [in_transaction] is true. *)
  method fire_watches path in_transaction recursive =
    if not in_transaction then Hashtbl.iter (fun _ watches -> List.iter (fire_watch path recursive) watches) self#watches

  (* Number of watches currently registered by the domain with id
     [domain_id]; 0 when the domain has none. *)
  method num_watches_for_domain domain_id = try List.length (Hashtbl.find self#domain_watches domain_id) with Not_found -> 0

  (* Unregister [watch]. Returns true only when an equal watch was actually
     registered and has been removed; returns false, leaving both indexes
     untouched, when there was nothing to remove. *)
  method remove (watch : t) =
    let path_watches = try Hashtbl.find self#watches watch.path with Not_found -> [] in
    let (removed, remaining) = List.partition (equal watch) path_watches in
    match removed with
    | [] -> false
    | _ :: _ ->
      (* Drop the path's entry entirely once no watch is left on it. *)
      (match remaining with
       | [] -> Hashtbl.remove self#watches watch.path
       | _ :: _ -> Hashtbl.replace self#watches watch.path remaining);
      self#remove_domain_watch watch;
      true

  (* Unregister every watch owned by [domain], calling Trace.destroy for
     each one removed, then forget the domain in the per-domain index. Does
     nothing when the domain has no watches. *)
  method remove_watches (domain : Domain.domain) =
    (* The list is captured before iterating, so the index updates done by
       [remove] do not disturb the traversal. *)
    let owned = try Hashtbl.find self#domain_watches domain#id with Not_found -> [] in
    List.iter (fun watch -> if self#remove watch then Trace.destroy watch.domain#id "watch") owned;
    Hashtbl.remove self#domain_watches domain#id
end
