(* 
    XenBus for OCaml XenStore Daemon.
    Copyright (C) 2008 Patrick Colp University of British Columbia

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

(* Size in bytes of each shared ring (2^10 = 1024).  Must stay a power of
   two: ring_buffer reduces its free-running indexes to an offset with
   [land (pred ring_size)]. *)
let ring_size = 1 lsl 10

(* Abstract handles to C-side objects.  Their representation is owned by the
   C stubs named below and is not visible from OCaml. *)
type xenbus_t
type ring_t
type ring_index_t

(* Accessors for the request and response rings of a mapped xenbus handle:
   each returns either an index handle (consumer/producer) or a ring data
   handle.  NOTE(review): presumably these point into the shared xenstore
   interface page of the mapping; confirm against the C stubs. *)
external init_req_cons : xenbus_t -> ring_index_t = "init_req_cons_c"
external init_req_prod : xenbus_t -> ring_index_t = "init_req_prod_c"
external init_req_ring : xenbus_t -> ring_t = "init_req_ring_c"
external init_rsp_cons : xenbus_t -> ring_index_t = "init_rsp_cons_c"
external init_rsp_prod : xenbus_t -> ring_index_t = "init_rsp_prod_c"
external init_rsp_ring : xenbus_t -> ring_t = "init_rsp_ring_c"
(* [read_ring ring start buffer offset length] / [write_ring ...] copy
   [length] bytes between the ring at position [start] and [buffer] at
   [offset] (argument roles as used by ring_buffer#read/#write).
   NOTE(review): assume these do no bounds checking of their own; callers
   must keep [offset + length] within [buffer] and [start + length] within
   the ring — confirm against the C stubs. *)
external read_ring : ring_t -> int -> string -> int -> int -> unit = "read_ring_c"
external write_ring : ring_t -> int -> string -> int -> int -> unit = "write_ring_c"
(* Read / store the 32-bit value behind an index handle. *)
external get_index : ring_index_t -> int32 = "get_index_c"
external set_index : ring_index_t -> int32 -> unit = "set_index_c"
(* Obtain a xenbus handle.  NOTE(review): the meaning of the int arguments
   (e.g. file descriptor, domain id, frame number) is defined by the C stubs;
   "xc_map_foreign_range_c" presumably wraps libxc's xc_map_foreign_range —
   confirm. *)
external mmap : int -> xenbus_t = "mmap_c"
external map_foreign : int -> int -> int -> xenbus_t = "xc_map_foreign_range_c"
(* Release a xenbus handle (called by xenbus_interface#destroy). *)
external munmap : xenbus_t -> unit = "munmap_c"
(* Barrier used by ring_buffer to order index accesses against data copies;
   presumably a full memory barrier — confirm in the C stub. *)
external mb : unit -> unit = "mb_c"

(* Ring buffer over a shared ring area whose consumer and producer indexes
   are free-running 32-bit counters: they are stored unmasked and only
   reduced to a ring offset by [mask_index], so [producer - consumer]
   (modulo 2^32) is the number of unread bytes. *)
class ring_buffer ring consumer producer =
  (* [diff_is_valid d] holds when [d = producer - consumer] is a plausible
     count of unread bytes, i.e. lies in [0, ring_size].  The indexes live in
     memory the other side can write, so a negative or oversized difference
     must be rejected before it is turned into a copy length.  The test is
     done on the Int32 value so that [Int32.to_int] truncation on 32-bit
     hosts cannot make a bogus difference look small. *)
  let diff_is_valid d =
    Int32.compare d 0l >= 0 && Int32.compare d (Int32.of_int ring_size) <= 0
  in
object (self)
  val m_consumer = consumer
  val m_producer = producer
  val m_ring = ring
  method private advance_consumer amount = set_index m_consumer (Int32.add self#consumer (Int32.of_int amount))
  method private advance_producer amount = set_index m_producer (Int32.add self#producer (Int32.of_int amount))
  method private check_indexes = diff_is_valid (Int32.sub self#producer self#consumer)
  method private consumer = get_index m_consumer
  method private diff = Int32.to_int (Int32.sub self#producer self#consumer)
  method private mask_index index = (Int32.to_int index) land (pred ring_size)
  method private producer = get_index m_producer
  method private ring = m_ring
  method private set_producer index = set_index m_producer index

  (* True when at least one unread byte is available. *)
  method can_read = self#diff <> 0

  (* True when the ring is not completely full. *)
  method can_write = self#diff <> ring_size

  (* [read buffer offset length] copies up to [length] unread bytes into
     [buffer] at [offset], advances the consumer index by the number of bytes
     copied, and returns that number.  A single call never wraps past the end
     of the ring, so it may return fewer bytes than are available; a
     non-positive [length] copies nothing.
     @raise Constants.Xs_error with EIO if the indexes are inconsistent. *)
  method read buffer offset length =
    (* Snapshot the indexes once so the values validated are exactly the
       values used to size the copy. *)
    let cons = self#consumer in
    let pending = Int32.sub self#producer cons in
    if not (diff_is_valid pending) then
      raise (Constants.Xs_error (Constants.EIO, "ring_buffer#read_ring", "could not check indexes"));
    (* Barrier between reading the indexes and reading the data. *)
    mb ();
    let start = self#mask_index cons in
    let read_length =
      min (min (Int32.to_int pending) (max 0 length)) (ring_size - start) in
    read_ring self#ring start buffer offset read_length;
    (* Barrier between copying the data and publishing the new index. *)
    mb ();
    self#advance_consumer read_length;
    read_length

  (* [write buffer offset length] copies up to [length] bytes from [buffer]
     at [offset] into the free space of the ring, advances the producer
     index by the number of bytes copied, and returns that number.  A single
     call never wraps past the end of the ring, so it may write fewer bytes
     than there is room for; a non-positive [length] copies nothing.
     @raise Constants.Xs_error with EIO if the indexes are inconsistent. *)
  method write buffer offset length =
    (* Snapshot the indexes once so the values validated are exactly the
       values used to size the copy. *)
    let prod = self#producer in
    let pending = Int32.sub prod self#consumer in
    if not (diff_is_valid pending) then
      raise (Constants.Xs_error (Constants.EIO, "ring_buffer#write_ring", "could not check indexes"));
    (* Barrier between reading the indexes and writing the data. *)
    mb ();
    let start = self#mask_index prod in
    let free = ring_size - Int32.to_int pending in
    let write_length = min (min free (max 0 length)) (ring_size - start) in
    write_ring self#ring start buffer offset write_length;
    (* Barrier between copying the data and publishing the new index. *)
    mb ();
    self#advance_producer write_length;
    write_length
end

(* XenBus interface: reads come from the request ring, writes go to the
   response ring, and every read or write is followed by a notification on
   the event channel [port]. *)
class xenbus_interface port xenbus =
  (* [transfer_length name buffer offset length] bounds a transfer so it fits
     in [buffer] from [offset] onwards and is never negative.  Previously the
     bound ignored [offset], letting the copy run past the end of [buffer].
     NOTE(review): assumes the read_ring/write_ring C stubs do not
     bounds-check on their own.
     @raise Invalid_argument [name] if [offset] lies outside [buffer]. *)
  let transfer_length name buffer offset length =
    let size = String.length buffer in
    if offset < 0 || offset > size then invalid_arg name;
    max 0 (min length (size - offset))
  in
object (self)
  inherit Interface.interface as super
  val m_port = port
  val m_request_ring = new ring_buffer (init_req_ring xenbus) (init_req_cons xenbus) (init_req_prod xenbus)
  val m_response_ring = new ring_buffer (init_rsp_ring xenbus) (init_rsp_cons xenbus) (init_rsp_prod xenbus)
  val m_xenbus = xenbus
  method private port = m_port
  method private request_ring = m_request_ring
  method private response_ring = m_response_ring
  method can_read = self#request_ring#can_read
  method can_write = self#response_ring#can_write

  (* Unbind the event channel; unmap the xenbus handle only when
     [Eventchan.unbind] returns true. *)
  method destroy = if Eventchan.unbind self#port then munmap m_xenbus

  (* Read up to [length] bytes from the request ring into [buffer] at
     [offset], notify the event channel, and return the byte count.
     @raise Invalid_argument if [offset] is outside [buffer]. *)
  method read buffer offset length =
    let length = transfer_length "xenbus_interface#read" buffer offset length in
    let bytes_read = self#request_ring#read buffer offset length in
    Eventchan.notify self#port;
    bytes_read

  (* Write up to [length] bytes from [buffer] at [offset] to the response
     ring, notify the event channel, and return the byte count.
     @raise Invalid_argument if [offset] is outside [buffer]. *)
  method write buffer offset length =
    let length = transfer_length "xenbus_interface#write" buffer offset length in
    let bytes_written = self#response_ring#write buffer offset length in
    Eventchan.notify self#port;
    bytes_written
end
