/*
 *              __  ____________        ____         __    
 *             / / / /_  __/ __/ ____  / __/______ _/ /__ _
 *            / /_/ / / / _\ \  /___/ _\ \/ __/ _ `/ / _ `/
 *            \____/ /_/ /___/       /___/\__/\_,_/_/\_,_/ 
 * 
 * This file is part of an implementation of the Universe Type System for
 * Scala.
 * 
 * Copyright (C) 2007-2008  Swiss Federal Institute of Technology, Zurich
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 * 
 * 
 * $Id: DefaultTypeRules.scala 908 2008-02-08 19:20:21Z ms $
 */
package ch.ethz.inf.sct.uts.plugin.staticcheck.rules.default

import scala.tools.nsc._

import ch.ethz.inf.sct.uts.annotation._
import ch.ethz.inf.sct.uts.plugin.common._

/**
 * Default Universe Type System type rules: field and method lookup with
 * viewpoint adaptation, plus the checks for assignments, method invocations,
 * object creation, class/method/value definitions and the various expression
 * forms (if-then-else, match, try, this, field selects, typed expressions).
 * Owner-as-modifier (OAM) restrictions are only enforced when enabled through
 * the <code>oam</code> option or <code>setOAMCheck</code>.
 * @param global Reference to the <code>Global</code> instance the plugin got instantiated with.
 * @param logger Reference to a <code>UTSLogger</code> object.
 * 
 * @author  Manfred Stock
 * @version $Revision: 908 $
 */
class DefaultTypeRules[G <: Global](override val global: G, override val logger: UTSLogger) 
  extends TypeRules[G](global,logger) with DefaultTypeAbstraction[G] {
  
  import global._
  import UTSDefaults._
  import extendedType._
  import Utils._
  
  /**
   * Check owner-as-modifier restrictions as well? Disabled by default.
   */
  private var checkoam = false
  
  /**
   * Process options which were not accepted by the plugin itself and 
   * may therefore be meant for the used typerules.
   * Currently only <code>oam</code> is understood; it enables the OAM checks.
   * @param o The option.
   * @return if the option was accepted.
   */
  def processOption(o: String) : Boolean = o match {
    case "oam" =>
      checkoam = true
      true
    case _ =>
      false
  }
  
  /**
   * Get help about the options implemented by the given typerules.
   * @param name Name of the plugin which got the option.
   * @return the string with help on the options.
   */
  def getOptionsHelp(name: String) : String = {
    "-P:"+name+":oam                Check owner-as-modifier constraints.\n"
  }
    
  /**
   * Get a list with information about the selected options.
   * @return a list with descriptions about the selected options (empty if none is active).
   */
  def getActiveOptions : List[String] = {
    if (checkoam) List("Check owner as modifier property.") else Nil
  }
  
  /**
   * Set if owner-as-modifier restrictions should be checked as well.
   * @param b If the owner-as-modifier restrictions should be applied, too.
   */
  def setOAMCheck(b: Boolean) : Unit = {
    checkoam = b
  }  

  /**
   * Field lookup to get the type of a field of a class.
   * Symbols accepted as "fields" are variables, values which are not methods,
   * modules and types.
   * @param c Class which is supposed to contain the field <code>f</code>.
   * @param f Field one is looking for.
   * @return the type of the field if found, an <code>InvalidType</code> otherwise.
   */
  def fType(c: Type, f: Symbol) : UType = {
    if (!c.typeSymbol.isClass) {
      InvalidType("Given type "+c+" is not a class.")
    }
    else if (f.isVariable || (f.isValue && !f.isMethod) || f.isModule || f.isType) {
      UType(f.tpe)
    }
    else {
      InvalidType("Given symbol "+f+" not a field.")
    }
  }
  
  /**
   * Field lookup to get the type of a field of a nonvariable type, viewpoint adapted to it.
   * @param nt Nonvariable type which is supposed to contain the field <code>f</code>.
   * @param f  Field one is looking for.
   * @return the type of the field adapted to <code>nt</code> if found, otherwise the lookup error.
   */
  def fType(nt: NType, f: Symbol) : UType = {
    fType(nt.tpe.getClassType,f) match {
      case et: ErroneousType => et
      case t: UType          => nt |> t
    }
  }
  
  /**
   * Field identifiers defined in or inherited by c.
   * NOTE(review): only members for which <code>isVariable</code> holds are returned,
   * whereas <code>fType</code> also treats non-method values and modules as fields -
   * confirm this difference is intended.
   * @param c Type whose fields are requested.
   * @return a list of the (fully qualified) field names.
   */
  def fields(c: Type) : List[String] = {
    c.members.filter(_.isVariable).map(_.fullNameString)
  }
  
  /**
   * Get the type of a method.
   * The lookup is done purely via the method symbol; <code>c</code> is not used.
   * @param c The type which should contain the method.
   * @param m The method symbol as found in the tree.
   * @return the type of the method. 
   */
  def mType(c: Type, m: Symbol) : ROption[MethodSignature] = {
    method(m)
  }
  
  /**
   * Get the type of a method, adapt it to a nonvariable type.
   * Type parameters, return type and parameter types are viewpoint adapted
   * to <code>n</code>; purity and the method symbol are copied unchanged.
   * @param n The type which should contain the method, and the signature should be adapted to.
   * @param m The method symbol as found in the tree.
   * @return the type of the method. 
   */
  def mType(n: NType,m: Symbol) : ROption[MethodSignature] = {
    mType(n.tpe.getClassType, m) match {
      case RSome(msig) =>
        // Method not owned by n, so try to resolve 'higher' type variables
        val mtype: MethodSignature =
          if (n.tpe.typeSymbol != m.owner) msig.resolveTypeVarsFrom(n)
          else msig
        RSome(methodSignature(
            // Adapt type parameters
            mtype.typeParams.map(n |> _),
            // Copy purity
            mtype.isPure,
            // Adapt return type
            n |> mtype.returnType,
            // Adapt parameter types
            mtype.paramTypes.map(n |> _),
            // Copy method symbol
            mtype.sym
        ))
      case nonsig@RNone(reason) => nonsig
    }
  }
  
  /**
   * Check if a given method of a class overrides correctly - implements WFM-2.
   * Only the first overridden symbol found in a base class other than the
   * method's own class is checked.
   * @param m Method to check.
   * @return if the given method successfully overrides the one in the parent class.
   */
  def ovrride(m: Symbol) : ROption[Boolean] = {
    method(m) match {
      case RSome(ms) if m.owner.isClass =>
        // Search base classes for symbols overridden by symbol m
        val candidates = m.owner.tpe.baseClasses map m.overriddenSymbol filter {NoSymbol !=}
        val overriddenSymbol = (candidates filter {_.owner != m.owner}) match {
          case sym :: _ => sym
          case Nil      => NoSymbol
        }
        // Check if overriding is done correctly
        overriddenSymbol match {
          case NoSymbol => RSome(true)
          case sym      =>
            method(sym) match {
              // Overriding method may be pure even if the overridden method was not
              case RSome(superms) if (ms overrides superms) =>
                RSome(true)
              case RSome(superms) =>
                RNone(mismatchedMethods(
                    ms,
                    superms,
                    "Signature of "+m+" found in "+m.owner+
                      " does not match the expected signature of "+sym+" it overrides in "+sym.owner+".")
                )
              case RNone(reason) =>
                RNone(reason)
            }
        }
      // There is no superSymbol since owner is not a class
      case RSome(_)      => RSome(true)
      // Failed to initialize method information
      case RNone(reason) => RNone(reason)
    }
  }

  /**
   * Check if an assignment is ok, possibly respecting the type 
   * of the owner of the reference some value is assigned to.
   * With OAM checking enabled and a known owner type (ie. not <code>NoType</code>),
   * the owner must allow modification.
   * @param vtype           The <code>UType</code> of the field which some value gets assigned to.
   * @param lhsRefOwnerType The <code>UType</code> of the object which contains the reference to the field.
   * @param atype           The <code>UType</code> of the value which is assigned.
   * @return the type of field which got assigned a value.
   */
  def checkAssignment(vtype: UType, lhsRefOwnerType: UType, atype: UType) : UType = {
    if (checkoam && lhsRefOwnerType.tpe != NoType && !lhsRefOwnerType.modificationOK) {
      InvalidType("OAM: Cannot modify objects in arbitrary contexts: "+lhsRefOwnerType)
    }
    else {
      checkAssignment(vtype,atype)
    }
  }
    
  /**
   * Check if an assignment is ok concerning the types.
   * Targets containing @lost cannot be assigned to. A right hand side whose
   * underlying type is <code>NoType</code> is accepted without a subtype check.
   * @param lt The <code>UType</code> which some value would be assigned to.
   * @param rt The <code>UType</code> which would be assigned.
   * @return the type of lt.
   */
  def checkAssignment(lt: UType, rt: UType) : UType = {
    if (lt.contains(lost())) {
      InvalidType("Cannot assign a value to a variable of type "+lt+" since it contains a @lost ownership modifier.")
    }
    else if (rt.tpe != NoType && !(rt <:< lt)) {
      InvalidType(typeError(rt,lt,"Actual type which was found is not a subtype of the formal type in the assignment."))
    }
    else {
      lt
    }
  }
  
  /**
   * Check the invocation of a method (GT-Invk).
   * @param receiver Type of the receiver of the method call.
   * @param method   Symbol of the called method.
   * @param args     Arguments to the method call.
   * @param targs    Type arguments to the call which should be checked.
   * @return the return type of the method call, if any.
   */
  def checkInvocation(receiver: UType, method: Symbol, args: List[UType], targs: List[UType]) : UType = {
    // Calls on type variables are checked on the variable's upper bound
    val N0 = receiver match {
      case tv: TVarID => tv.ubgamma
      case ot         => ot
    }
    
    if (method == null) {
      // Happens eg. in call to constructor of case-class Foo in case Foo() => ...
      N0
    }
    else {
      N0 match {
        case nt: NType  => mType(nt, method) match {
          case RSome(msig: MethodSignature) =>
            val ms = msig.substTypeArgs(targs,args)
            if (ms.paramTypes.contain(lost())) {
              if (method.isSetter) {                
                InvalidType("Viewpoint adapted type of updated field contains @lost: "+ms.paramTypes.head)
              }
              else {
                InvalidType("Formal method parameters must not contain @lost in viewpoint adapted call to "+ms+" on "+nt+".")
              }
            }
            else if (ms.typeParams.contain(lost())) {
              InvalidType("Formal method type parameters must not contain @lost: "+ms.typeParams+".")
            }
            else {
              // Collects the errors reported by the argument checks below
              var errors = InvalidTypes(Nil)
              // '&&' short-circuits: parameter types are only checked if all type
              // arguments conform. forall2 presumably stops at the first failing pair.
              val argsOK =
                List.forall2(targs, ms.typeParams)((a,b) => { // Check of type parameters
                  (a <:< b) || { errors = errors + InvalidType(typeError(a,b,"in type parameters of the call to "+ms+".")); false }
                }) &&
                List.forall2(args, ms.paramTypes)((a,b) => { // Check of parameter types
                  (a <:< b) || { errors = errors + InvalidType(typeError(a,b,"in parameter types of the call to "+ms+".")); false }
                })
              if (!argsOK) {
                errors
              }
              else if (checkoam && !nt.modificationOK && !ms.isPure) {
                if (method.isSetter) {
                  InvalidType("Cannot assign a value to a field of a non-modifiable reference (ie. "+nt+") when enforcing owner-as-modifier.")
                }
                else {
                  InvalidType("Cannot call non-pure methods on non-modifiable references (ie. "+nt+") when enforcing owner-as-modifier.")
                }
              }
              else {
                ms.returnType
              }
            }
          case RNone(reason) => InvalidType(reason)
        }
        case tv: TVarID => InvalidType("Invocation on type variable: "+tv)
        case et         => et
      }
    }
  }
  
  /**
   * Check if a given Type can be instantiated.
   * Instances of any references, of types containing @some or @lost, and of
   * type variables cannot be created.
   * @param t Type to be checked.
   * @return <code>t</code> if it can be instantiated, an error otherwise.
   */
  def checkNew(t: UType) : UType = {
    t match {
      case nt: NType  => 
        nt.om match { 
          case any() =>
            InvalidType("Cannot create a new instance of an any reference.")
          case _ if nt.contains(some()) || nt.contains(lost()) =>
            InvalidType("Nonvariable type must not contain @some or @lost.")
          case _ =>
            nt
        }
      case tv: TVarID => InvalidType("Cannot create new instance of a type variable.")
      case et         => et
    }
  }

  /**
   * Get the type of a field read.
   * @param nt Class which contains the field.
   * @param f  Symbol of the field.
   * @return the type of the field, viewpoint adapted to <code>nt</code>. 
   */
  def fieldRead(nt: NType, f: Symbol) : UType = {
    fType(nt, f)
  }
  
  /**
   * Check a class definition (WFC).
   * @param cls    Symbol of the class.
   * @param params Type parameters together with their positions.
   * @return the list of type parameters, possibly containing <code>InvalidType</code>
   *         instances (positioned at the parameter) if a parameter was invalid.
   */
  def checkClassDefinition(cls: Symbol, params: List[(util.Position,UType)]) : List[UType] = {
    // Check WFC
    params map {
      // Syntactically allowed by Scala: class Foo[Int] {} is valid, but does not make a
      // lot of sense, so report it as an error here.
      case (t,nt: NType)         => InvalidType("Nonvariable type parameter detected.").setPosition(t)
      case (t,tv: TVarID)        =>
        tv.ubgamma.check match {
          case nt: NType         => 
            if (nt.contains(rep())) {
              InvalidType("Type parameter contains forbidden @rep annotation in upper bound.").setPosition(t)  
            }
            else {
              nt
            }
          // Upper-bound of type variable should never be a type variable itself.
          case tvb: TVarID       => InvalidType("Found type variable "+tvb+" as upper bound for "+tv+", which should not happen.").setPosition(t)
          case it: InvalidType   => it.setPosition(t)
          case et: ErroneousType => et
        }
      case (t,it: InvalidType)   => it.setPosition(t)
      case (t,et: ErroneousType) => et
    }
  }
  
  /**
   * Check the definition of a value, ie. field or local variable.
   * This is treated like an assignment of the initializer to the symbol.
   * @param sym    Symbol which is defined.
   * @param rhstpe Type of the initializer.
   * @return the declared type of the symbol, or an error.
   */
  def checkValueDefinition(sym: Symbol, rhstpe: UType) : UType = {
    checkAssignment(UType(sym.tpe),rhstpe)
  }

  /**
   * Check a method definition (WFM-1 and WFM-2).
   * Also logs a warning if the method is assumed to be pure without being annotated as pure.
   * @param method   The symbol of the defined method.
   * @param name     Name of the method.
   * @param tparams  Type parameters of the method.
   * @param vparamss Value parameter types.
   * @param frtype   Formal return type.
   * @param artype   Actual return type.
   * @return <code>errors orElse frtype</code>: presumably the collected errors if there
   *         are any, otherwise the formal return type.
   */
  def checkMethodDefinition(method: Symbol, name: String, tparams: List[UType], vparamss: List[List[UType]], frtype: UType, artype: UType) : UType = {
    // Errors found during the check
    var errors = InvalidTypes(Nil)
    // Check WFM-1:
    // Check type parameters (ie. their upper bounds)
    tparams map {
      case tv: TVarID => tv.ubgamma
      case t          => t
    } foreach {
      case et: ErroneousType => errors = errors + et
      case _                 => ()
    }
    
    // Check value parameters
    vparamss foreach { ps => ps foreach { 
      case et: ErroneousType => errors = errors + et
      case _                 => ()
    }}
    
    // Check if actual return type is subtype of formal type. Special
    // handling for methods "returning" Unit and constructor methods 
    // which usually don't return anything.
    if (! ((artype <:< frtype) || frtype.isUnit || method.isConstructor)) {
      errors = errors + InvalidType(typeError(artype,frtype,"Found return type of method "+name
          +" does not match/subtype expected formal return type." ))
    }
    
    // Check if overriding is ok. WFM-2.
    ovrride(method) match {
      case RNone(reason) => errors = errors + InvalidType(reason)
      case _             => ()
    }
  
    if (!isMethodPureAnnotated(method) && methodAssumableAsPure(method)) {
      logger.warn("Method "+name+" is being assumed to be pure, but is not annotated as really being pure.")
    }
    
    errors orElse frtype
  }
  
  /**
   * Lift type <code>a</code> up to the type of <code>b</code>, possibly by going to upper 
   * bounds of type variables (on either side).
   * @param a Source type.
   * @param b Target type where the source type should be lifted to.
   * @return the lifted type, or <code>a</code> unchanged if it cannot be lifted.
   */
  private def lift(a: UType, b: UType) : UType = {
    a match {
      case t0: NType => b match {
        case t1: NType  => t0.liftTo(t1)
        case t1: TVarID => lift(t0,t1.ubgamma)
        case _          => a
      }
      case t: TVarID => lift(t.ubgamma,b)
      case _         => a
    }
  }
  
  /**
   * Adapt ownership modifiers of <code>it</code> with the goal that it finally is a 
   * common supertype of <code>a</code> and <code>b</code>. The main modifier becomes
   * <code>any</code> if the main modifiers of <code>a</code> and <code>b</code> differ;
   * type arguments are adapted recursively.
   * @param it Type which should get adapted ownership modifiers
   * @param a  Future subtype 0 of <code>it</code>.
   * @param b  Future subtype 1 of <code>it</code>.
   * @return <code>it</code> with adapted ownership modifiers, or the collected errors
   *         if one of the (upper bound) types is not a nonvariable type.
   */
  private def liftOm(it: UType, a: UType, b: UType) : UType = {
    (it.ubgamma,a.ubgamma,b.ubgamma) match {
      case (nit: NType, na: NType, nb: NType) =>
        val mainom = if (na.om != nb.om) any() else na.om
        nType(
            mainom,
            nit.tpe,
            List.map3(nit.typeParams,na.typeParams,nb.typeParams)(
                (itp,atp,btp) => liftOm(itp,atp,btp)
            )
        )
      case (eit,ea,eb)  => 
        var errors = new InvalidTypes(Nil) 
        List(eit,ea,eb) foreach {
          case et: ErroneousType => errors = errors + et
          case _                 => ()
        } 
        errors
    }              
  }
   
  /**
   * Check if-then-else statement.
   * If the if-expression's type is ownership annotated, both branches must subtype it.
   * Otherwise the result is the branch type which is the supertype of the other one,
   * or, if the branches are unrelated, an upper bound computed from both.
   * @param iftpe     Type of the if expression.
   * @param cond      Type of the condition (not checked here).
   * @param then      Type of the <code>then</code> branch.
   * @param thenpos   Position of then block.
   * @param otherwise Type of the <code>else</code> branch.
   * @param elsepos   Position of else block.
   * @return the type of the if-then-else statement.
   */
  def checkIfThenElse(iftpe: UType, cond: UType, then: UType, thenpos: util.Position, otherwise: UType, elsepos: util.Position) : UType = {
    // NOTE(review): in the annotated case an else branch whose type is NoType is
    // rejected, while the unannotated case simply accepts it - confirm this is intended.
    if (iftpe.isUnderlyingOwnershipAnnotated) {
      // Simple case: If-expression's type already contained ownership modifiers
      if (!(then <:< iftpe)) {
        new InvalidType(typeError(then,iftpe,"Type found for then block not a subtype of the if-statement's expected type."),thenpos)
      }
      else if (otherwise.tpe != NoType && otherwise <:< iftpe) {
        iftpe
      }
      else {
        new InvalidType(typeError(otherwise,iftpe,"Type found for else block not a subtype of the if-statement's expected type."),elsepos)
      }
    }
    else if (otherwise.tpe == NoType) {
      then
    }
    // Check if either one of the branches' results is a subtype of the other, return type which is higher
    else if (then <:< otherwise) {
      otherwise
    }
    else if (otherwise <:< then) {
      then
    }
    else {
      // Less simple case: No ownership modifiers in if-expression, branches don't
      // subtype each other -> find some sort of an upper bound for the types

      // Bring underlying types of then and else branch on the level of the if-expression's type
      val lthen = lift(then,iftpe)
      val lelse = lift(otherwise,iftpe)
      // Adapt ownership modifiers of the if-expression's type to the branches
      liftOm(iftpe,lthen,lelse)
    }
  }
  
  /**
   * Check a pattern match.
   * If the match type is ownership annotated, every case must subtype it. Otherwise
   * all case types are lifted to the level of the match type and their ownership
   * modifiers are combined into a common upper bound.
   * @param matchtype Type of the match expression as given by the Scala compiler.
   * @param casetypes Types of the different case statements.
   * @return the type of the pattern match.
   */
  def checkMatch(matchtype: UType, casetypes: List[(util.Position,UType)]) : UType = {
    if (matchtype.isUnderlyingOwnershipAnnotated) {
      val errors = casetypes filter {p => !(snd(p) <:< matchtype)}
      if (errors.isEmpty) {
        matchtype
      }
      else {
        InvalidTypes(
            errors map {
              case (p,t) => InvalidType(typeError(t,matchtype,"Type of a case block does not subtype expected type of match.")).setPosition(p)
            }
        )
      }
    }
    else {
      // Bring the underlying types of all cases on the level of the match expression's type
      val lifted = casetypes map {case (_,t) => lift(t,matchtype)}
      lifted match {
        case Nil           => matchtype
        case single :: Nil => single
        case first :: second :: rest =>
          // Combine the ownership modifiers of all lifted case types
          rest.foldRight(liftOm(matchtype,first,second))((cur,acc) => liftOm(matchtype,cur,acc))
      }
    }
  }
 
  /**
   * Check a <code>try</code> statement.
   * The try block and all catch blocks must subtype the type of the statement;
   * the finalizer's type is not checked.
   * @param trytype       Type of the try statement.
   * @param blocktype     Type of the block in the try statement.
   * @param blockpos      Position of the block.
   * @param catchestypes  Positions and types of the <code>catch</code> statements.
   * @param finalizertype Type of the <code>finalize</code> statement.
   * @return the type of the try block if all checks pass, an error otherwise.
   */
  def checkTry(trytype: UType, blocktype: UType, blockpos: util.Position, catchestypes: List[(util.Position,UType)],finalizertype: UType) : UType = {
    if (!(blocktype <:< trytype)) {
      InvalidType(typeError(blocktype,trytype,"Found type of try-block does not subtype formal type of the try-catch-finally statement.")).setPosition(blockpos)
    }
    else {
      val errors = catchestypes filter {case (_,t) => !(t <:< trytype)}
      if (errors.isEmpty) {
        blocktype
      }
      else {
        InvalidTypes(
            errors map { 
              case (p,t) => InvalidType(typeError(t,trytype,"Type of catch-block does not subtype formal type of the try-catch-finally statement")).setPosition(p)
            }
        )
      }
    }
  }
  
  /**
   * Check a <code>this</code> access.
   * @param tpe Type of <code>this</code>
   * @return the type of this with its ownership modifier set to <code>this</code>.
   */
  def checkThis(tpe: UType) : UType = {
    tpe match {
      case et: ErroneousType => et
      case nt: NType         => nt.setOwnershipModifier(thiz())
      case t                 => InvalidType("No NType for 'this' found. Should not really happen... Only found "+t)
    }
  }
 
  /**
   * Check a field read (GT-Read).
   * For type variables, the read is checked on the variable's upper bound.
   * @param target Type of the target on which the select takes place.
   * @param field  The field which is selected.
   * @return the type of the field, after viewpoint adaptation etc.
   */
  def checkSelect(target: UType, field: Symbol) : UType = {
    target match {
      case nt: NType  => fieldRead(nt, field)
      case tv: TVarID => tv.ubgamma match {
        case nt: NType => fieldRead(nt, field)
        case _         => InvalidType("Fieldreads on type variables not allowed.")
      }
      case t          => InvalidType("Type of n0 unknown in GT-Read: "+t)
    }
  }
  
  /**
   * Check a typed expression (ie. something like <code>expr: tpt</code>).
   * @param expr Type of the expression.
   * @param tpe  Type the expression should have.
   * @return <code>tpe</code> if the expression's type subtypes it, an error otherwise.
   */
  def checkTyped(expr: UType, tpe: UType) : UType = {
    if (!(expr <:< tpe)) {
      InvalidType(typeError(expr,tpe,"found type not a subtype of expected type in Typed expression."))
    }
    else {
      tpe
    }
  }
}