/* 
   This file is part of libodbc++.
   
   Copyright (C) 1999 Manush Dodunekov <manush@litecom.net>
   
   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public
   License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.
   
   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.
   
   You should have received a copy of the GNU Library General Public License
   along with this library; see the file COPYING.  If not, write to
   the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
   Boston, MA 02111-1307, USA.
*/

/*
  There are three cases to handle here:
  1. The driver supports SQLNumParams and SQLDescribeParam and reports correct
  types for all parameters.
  2. The driver supports SQLNumParams and claims to support SQLDescribeParam,
  but actually reports every parameter as VARCHAR(255).
  3. The driver supports neither SQLNumParams nor SQLDescribeParam. In this
  case parameters have to be added dynamically as they are set.
*/

#include <odbc++/preparedstatement.h>
#include "datahandler.h"
#include "driverinfo.h"
#include "dtconv.h"

using namespace odbc;
using namespace std;

/*
  Prepares `sql` on the statement handle `hstmt` and sets up the parameter
  rowset. Parameters are described up front when the driver supports
  SQLNumParams; otherwise they are added on demand as they are set.
  `defaultDirection` is the ODBC parameter direction recorded for every
  parameter that gets created.
*/
PreparedStatement::PreparedStatement(Connection* con,
				     SQLHSTMT hstmt,
				     const string& sql,
				     int resultSetType,
				     int resultSetConcurrency,
				     int defaultDirection)
  :Statement(con,hstmt,resultSetType,resultSetConcurrency),
   sql_(sql),
   rowset_(new Rowset(1,ODBC3_DC(true,false))), //always one row for now
   numParams_(0),
   defaultDirection_(defaultDirection),
   paramsBound_(false)
{
  // If preparing or describing the parameters throws, the exception leaves
  // the constructor and ~PreparedStatement() is never run. The rowset
  // allocated above would then leak, so free it here and let the
  // exception propagate to the caller.
  try {
    this->_prepare();
    this->_setupParams();
  } catch(...) {
    delete rowset_;
    throw;
  }
}

/*
  Releases any parameter bindings still held by the statement handle, then
  frees the parameter rowset.
*/
PreparedStatement::~PreparedStatement()
{
  if(paramsBound_) {
    // _unbindParams() reports ODBC failures through _checkStmtError and may
    // throw. An exception escaping a destructor can terminate the program,
    // and it would also skip the delete below. Unbinding here is
    // best-effort cleanup, so any failure is deliberately ignored.
    try {
      this->_unbindParams();
    } catch(...) {
    }
  }

  delete rowset_;
}

void PreparedStatement::_prepare()
{
  SQLRETURN r=SQLPrepare(hstmt_,
			 (SQLCHAR*)sql_.data(),
			 sql_.length());

  string msg="Error preparing "+sql_;
  this->_checkStmtError(hstmt_,r,msg.c_str());
}

void PreparedStatement::_checkParam(int idx, int sqlType,
				    int defPrec, int defScale)
{
    // we put a restriction when using drivers that don't support
    // SQLNumParams: All parameters have to be set in increasing order,
    // starting from 1

  if(idx<=0 || idx>numParams_+1) {
    throw SQLException
      ("[libodbc++]: PreparedStatement: parameter index "+
       intToString(idx)+" out of bounds");
  }

  if(numParams_<(unsigned int)idx) {

    if(paramsBound_) {
      this->_unbindParams();
    }

    // just add a column
    rowset_->addColumn(sqlType,defPrec,defScale);
    directions_.push_back(defaultDirection_);
    numParams_++;
  }

  assert(idx<=numParams_ && idx>0);

  if(rowset_->getCurrentRow()>0) {
    return;
  }

  // check if we have to replace the datahandler
  // this happens if the driver reports for example
  // VARCHAR(255) and application calls setDate
  DataHandler* dh=rowset_->getColumn(idx);
  int newType;
  int newPrec=0;
  int newScale=0;
  
  bool replace=false;

  switch(sqlType) {
  case Types::DATE:
    if(dh->getSQLType()!=Types::DATE) {
      newType=Types::DATE;
      replace=true;
    }
    break;

  case Types::TIME:
    if(dh->getSQLType()!=Types::TIME) {
      newType=Types::TIME;
      replace=true;
    }
    break;

  case Types::TIMESTAMP:
    if(dh->getSQLType()!=Types::TIMESTAMP) {
      newType=Types::TIMESTAMP;
      replace=true;
    }
    break;

  case Types::LONGVARCHAR:
    if(dh->getSQLType()!=Types::LONGVARCHAR) {
      newType=Types::LONGVARCHAR;
      replace=true;
    }
    break;

  case Types::LONGVARBINARY:
    if(dh->getSQLType()!=Types::LONGVARBINARY) {
      newType=Types::LONGVARBINARY;
      replace=true;
    }
    break;

  case Types::VARBINARY:
    {
      switch(dh->getSQLType()) {
      case Types::BINARY:
      case Types::VARBINARY:
      case Types::LONGVARBINARY:
	break;

      default:
	replace=true;
	newType=Types::VARBINARY;
	newPrec=255;
      }
    }
    break;
  }

  if(replace) {
    if(paramsBound_) {
      // we are changing a buffer address, unbind 
      // the parameters
      this->_unbindParams();
    }
    rowset_->replaceColumn(idx,newType,newPrec,newScale);
  }
}


void PreparedStatement::_setupParams()
{
  if(!this->_getDriverInfo()->supportsFunction(SQL_API_SQLNUMPARAMS)) {
    return;
  }

  SQLSMALLINT np;
  SQLRETURN r=SQLNumParams(hstmt_,&np);
  this->_checkStmtError(hstmt_,r,"Error fetching number of parameters");
  numParams_=np;
  
  SQLSMALLINT sqlType;
  SQLUINTEGER prec;
  SQLSMALLINT scale;
  SQLSMALLINT nullable;

  if(this->_getDriverInfo()->supportsFunction(SQL_API_SQLDESCRIBEPARAM)) {
   
    for(size_t i=0; i<numParams_; i++) {
      r=SQLDescribeParam(hstmt_,
			 i+1,
			 &sqlType,
			 &prec,
			 &scale,
			 &nullable);
      
      this->_checkStmtError(hstmt_,r,"Error obtaining parameter information");
      rowset_->addColumn(sqlType,prec,scale);
      directions_.push_back(defaultDirection_);
    }
  } else {
    // default all parameters to VARCHAR(255)
    for(size_t i=0; i<numParams_; i++) {
      rowset_->addColumn(Types::VARCHAR,255,0);
      directions_.push_back(defaultDirection_);
    }
  }
}


void PreparedStatement::_bindParams()
{
  SQLRETURN r;
  for(size_t i=1; i<=numParams_; i++) {
    DataHandler* dh=rowset_->getColumn(i);

    //simple bind
    if(!dh->isStreamed_) {
      r=SQLBindParameter(hstmt_,
			 (SQLUSMALLINT)i,
			 (SQLSMALLINT)directions_[i-1],
			 (SQLSMALLINT)dh->cType_,
			 (SQLSMALLINT)dh->sqlType_,
			 (SQLUINTEGER)dh->precision_,
			 (SQLSMALLINT)dh->scale_,
			 (SQLCHAR*)dh->data(),
			 dh->bufferSize_,
			 dh->dataStatus_);

    } else {
      //we send in dh->dataStatus_, as it contains
      //SQL_LEN_DATA_AT_EXEC(size) after setStream, or SQL_NULL_DATA
      r=SQLBindParameter(hstmt_,
			 (SQLUSMALLINT)i,
			 (SQLSMALLINT)directions_[i-1],
			 (SQLSMALLINT)dh->cType_,
			 (SQLSMALLINT)dh->sqlType_,
			 0, //doesn't apply to streamed types
			 0, //same here
			 (SQLPOINTER)i, //our column index
			 0, //doesn't apply
			 dh->dataStatus_);
    }
    this->_checkStmtError(hstmt_,r,"Error binding parameter");
  }

  paramsBound_=true;
}

void PreparedStatement::_unbindParams()
{
  SQLRETURN r=SQLFreeStmt(hstmt_,SQL_RESET_PARAMS);
  this->_checkStmtError(hstmt_,r,"Error unbinding parameters");
  
  //notify our parameters (should this go into execute()?)
  for(size_t i=1; i<=numParams_; i++) {
    rowset_->getColumn(i)->afterUpdate();
  }
  
  paramsBound_=false;
}



bool PreparedStatement::execute()
{
#if 0
  unsigned int nc=rowset_->getColumns();
  cout << "Entering PreparedStatement::execute()" << endl
       << "Parameters: " << numParams_ << ", rowset columns: " << nc << endl;

  for(int i=1; i<=nc; i++) {
    DataHandler* dh=rowset_->getColumn(i);
    cout << "Parameter " << i << ":" << endl;
    cout << "SQLType   : " << dh->getSQLType() << endl
	 << "Value     : " << (dh->isNull()?"<NULL>":dh->getString()) << endl
	 << "Direction : " << directions_[i-1] << endl
	 << endl;
  }
#endif

  this->_beforeExecute();
  
  if(!paramsBound_) {
    this->_bindParams();
  }
  
  SQLRETURN r=SQLExecute(hstmt_);

  string msg="Error executing \""+sql_+"\"";
  this->_checkStmtError(hstmt_,r,msg.c_str());

  //the following should maybe join with ResultSet::updateRow/insertRow in
  //some way

  if(r==SQL_NEED_DATA) {
    char buf[PUTDATA_CHUNK_SIZE];
    
    while(r==SQL_NEED_DATA) {
      SQLPOINTER currentCol;
      r=SQLParamData(hstmt_,&currentCol);
      this->_checkStmtError(hstmt_,r,"SQLParamData failure");
      if(r==SQL_NEED_DATA) {
	DataHandler* dh=rowset_->getColumn((int)currentCol);
	
	if(dh->isNull()) {
	  continue;
	}
	
	std::istream* s=dh->getStream();
	if(s==NULL || !(*s)) {
	  //hmm. if this occurs, I do something wrong
	  continue;
	}
	
	do {
	  s->read(buf,PUTDATA_CHUNK_SIZE);
	  if(s->gcount()>0) {
	    SQLPutData(hstmt_,(SQLCHAR*)buf,s->gcount());
	  }
	} while(*s);
      }
    }
  }

  this->_afterExecute();

  return this->_checkForResults();
}


/*
  Executes the statement and returns the result set it produced, as
  reported by getResultSet(). execute()'s boolean result is not consulted.
*/
ResultSet* PreparedStatement::executeQuery()
{
  (void)this->execute();
  ResultSet* rs=this->getResultSet();
  return rs;
}


/*
  Executes the statement and returns the update count reported by
  getUpdateCount(). execute()'s boolean result is not consulted.
*/
int PreparedStatement::executeUpdate()
{
  (void)this->execute();
  int affected=this->getUpdateCount();
  return affected;
}


void PreparedStatement::clearParameters()
{
  if(paramsBound_) {
    this->_unbindParams();
  }

  for(size_t i=1; i<=numParams_; i++) {
    rowset_->getColumn(i)->setNull();
  }
}

/*
  Sets parameter `idx` (1-based) to NULL.

  The parameter is first checked, and created or retyped if needed, as
  `sqlType` using DataHandler::defaultPrecisionFor(sqlType) and scale 0.
  Throws SQLException if `idx` is out of range.
*/
void PreparedStatement::setNull(int idx, int sqlType)
{
  this->_checkParam(idx,sqlType,DataHandler::defaultPrecisionFor(sqlType),0);
  rowset_->getColumn(idx)->setNull();
}

/*
  IMPLEMENT_SET(TYPE,FUNCSUFFIX,SQLTYPE,DEFPREC) defines
  PreparedStatement::set<FUNCSUFFIX>(int idx, TYPE val).

  The generated setter does two things:
  1. Checks parameter idx via _checkParam, creating or retyping it as
     SQLTYPE with default precision DEFPREC and scale 0.
  2. Stores val through the data handler's set<FUNCSUFFIX>.

  Each expansion is a complete function definition, so uses are NOT
  followed by ';'. A stray ';' at namespace scope is an empty declaration,
  which is ill-formed before C++11 and flagged by -pedantic.
*/
#define IMPLEMENT_SET(TYPE,FUNCSUFFIX,SQLTYPE,DEFPREC)		\
void PreparedStatement::set##FUNCSUFFIX(int idx, TYPE val)	\
{								\
  this->_checkParam(idx,SQLTYPE,DEFPREC,0);			\
  rowset_->getColumn(idx)->set##FUNCSUFFIX(val);		\
}


IMPLEMENT_SET(double,Double,Types::DOUBLE,0)
IMPLEMENT_SET(bool,Boolean,Types::BIT,0)
IMPLEMENT_SET(signed char,Byte,Types::TINYINT,0)
IMPLEMENT_SET(float, Float,Types::REAL,0)
IMPLEMENT_SET(int,Int,Types::INTEGER,0)
IMPLEMENT_SET(Long,Long,Types::BIGINT,0)
IMPLEMENT_SET(short,Short,Types::SMALLINT,0)
IMPLEMENT_SET(const string&, String,Types::VARCHAR,255)
IMPLEMENT_SET(const Date&,Date,Types::DATE,0)
IMPLEMENT_SET(const Time&,Time,Types::TIME,0)
IMPLEMENT_SET(const Timestamp&, Timestamp,Types::TIMESTAMP,0)
IMPLEMENT_SET(const Bytes&, Bytes, Types::VARBINARY,0)

void PreparedStatement::setAsciiStream(int idx, std::istream* s, int len)
{
  this->_checkParam(idx,Types::LONGVARCHAR,0,0);
  rowset_->getColumn(idx)->setStream(s,len);
}

void PreparedStatement::setBinaryStream(int idx, std::istream* s, int len)
{
  this->_checkParam(idx,Types::LONGVARBINARY,0,0);
  rowset_->getColumn(idx)->setStream(s,len);
}
